From 53ecf08210df2d79e3e9a30c6a1bf6c37f0d8c1b Mon Sep 17 00:00:00 2001
From: jeanluciano
Date: Thu, 12 Oct 2023 14:34:28 -0500
Subject: [PATCH 01/22] cluster client

---
 arq/connections.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/arq/connections.py b/arq/connections.py
index d4fc4434..dcae98c5 100644
--- a/arq/connections.py
+++ b/arq/connections.py
@@ -11,7 +11,7 @@
 from redis.asyncio import ConnectionPool, Redis
 from redis.asyncio.sentinel import Sentinel
 from redis.exceptions import RedisError, WatchError
-
+from redis.asyncio.cluster import RedisCluster
 from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
 from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
 from .utils import timestamp_ms, to_ms, to_unix_ms
@@ -43,7 +43,7 @@ class RedisSettings:
     conn_timeout: int = 1
     conn_retries: int = 5
     conn_retry_delay: int = 1
-
+    cluster_mode: bool = False
     sentinel: bool = False
     sentinel_master: str = 'mymaster'
 
@@ -74,7 +74,7 @@ def __repr__(self) -> str:
 if TYPE_CHECKING:
     BaseRedis = Redis[bytes]
 else:
-    BaseRedis = Redis
+    BaseRedis = RedisCluster
 
 
 class ArqRedis(BaseRedis):

From 9b7b32f2872d5088f367f15c9bd2c917806e6a5f Mon Sep 17 00:00:00 2001
From: jeanluciano
Date: Thu, 12 Oct 2023 15:26:14 -0500
Subject: [PATCH 02/22] connection initialization

---
 arq/connections.py | 5 +++--
 arq/jobs.py        | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/arq/connections.py b/arq/connections.py
index dcae98c5..98a82cf5 100644
--- a/arq/connections.py
+++ b/arq/connections.py
@@ -72,7 +72,7 @@ def __repr__(self) -> str:
 
 
 if TYPE_CHECKING:
-    BaseRedis = Redis[bytes]
+    BaseRedis = RedisCluster[bytes]
 else:
     BaseRedis = RedisCluster
 
@@ -255,12 +255,13 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
     while True:
         try:
             pool = pool_factory(
-                db=settings.database, username=settings.username, password=settings.password, encoding='utf8'
+                db=settings.database, password=settings.password, encoding='utf8'
             )
             pool.job_serializer = job_serializer
             pool.job_deserializer = job_deserializer
             pool.default_queue_name = default_queue_name
             pool.expires_extra_ms = expires_extra_ms
+            await pool.initialize()
             await pool.ping()
 
         except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e:
diff --git a/arq/jobs.py b/arq/jobs.py
index 8028cbe7..c2ae796b 100644
--- a/arq/jobs.py
+++ b/arq/jobs.py
@@ -7,8 +7,8 @@
 from enum import Enum
 from typing import Any, Callable, Dict, Optional, Tuple
 
-from redis.asyncio import Redis
+from redis.asyncio.cluster import RedisCluster
 
 from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
 from .utils import ms_to_datetime, poll, timestamp_ms
@@ -73,7 +73,7 @@ class Job:
     def __init__(
         self,
         job_id: str,
-        redis: 'Redis[bytes]',
+        redis: 'RedisCluster[bytes]',
         _queue_name: str = default_queue_name,
         _deserializer: Optional[Deserializer] = None,
     ):

From 8329cdf7a24b4b06214b1eecab985d386192fa66 Mon Sep 17 00:00:00 2001
From: jeanluciano
Date: Mon, 23 Oct 2023 09:35:02 -0500
Subject: [PATCH 03/22] pipelines kicked off

---
 .gitignore         |   1 +
 arq/connections.py |  56 +++---
 arq/jobs.py        |   6 +-
 arq/worker.py      |  14 +-
 pyproject.toml     |   2 +-
 test.py            |  87 ++++++++
 tests/conftest.py  |  32 ++-
 tests/test_main.py | 484 ++++++++++++++++++++++-----------------
 8 files changed, 379 insertions(+), 303 deletions(-)
 create mode 100644 test.py

diff --git a/.gitignore b/.gitignore
index e2d3e183..a287f54f 100644
---
a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /env*/ +/venv*/ /.idea __pycache__/ *.py[cod] diff --git a/arq/connections.py b/arq/connections.py index 98a82cf5..36303301 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -8,7 +8,7 @@ from urllib.parse import parse_qs, urlparse from uuid import uuid4 -from redis.asyncio import ConnectionPool, Redis +from redis.asyncio import ConnectionPool from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError from redis.asyncio.cluster import RedisCluster @@ -29,8 +29,7 @@ class RedisSettings: host: Union[str, List[Tuple[str, int]]] = 'localhost' port: int = 6379 - unix_socket_path: Optional[str] = None - database: int = 0 + username: Optional[str] = None password: Optional[str] = None ssl: bool = False @@ -51,20 +50,13 @@ class RedisSettings: def from_dsn(cls, dsn: str) -> 'RedisSettings': conf = urlparse(dsn) assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme' - query_db = parse_qs(conf.query).get('db') - if query_db: - # e.g. redis://localhost:6379?db=1 - database = int(query_db[0]) - else: - database = int(conf.path.lstrip('/')) if conf.path else 0 + + return RedisSettings( host=conf.hostname or 'localhost', port=conf.port or 6379, ssl=conf.scheme == 'rediss', - username=conf.username, password=conf.password, - database=database, - unix_socket_path=conf.path if conf.scheme == 'unix' else None, ) def __repr__(self) -> str: @@ -143,8 +135,10 @@ async def enqueue_job( defer_by_ms = to_ms(_defer_by) expires_ms = to_ms(_expires) - async with self.pipeline(transaction=True) as pipe: - await pipe.watch(job_key) + + async with self.pipeline() as pipe: + logger.debug("insides pipeline---------------------------") + if await pipe.exists(job_key, result_key_prefix + job_id): await pipe.reset() return None @@ -160,7 +154,7 @@ async def enqueue_job( expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) - pipe.multi() + pipe.psetex(job_key, expires_ms, job) # type: ignore[no-untyped-call] pipe.zadd(_queue_name, {job_id: score}) # type: ignore[unused-coroutine] try: @@ -241,7 +235,6 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: ArqRedis, host=settings.host, port=settings.port, - unix_socket_path=settings.unix_socket_path, socket_connect_timeout=settings.conn_timeout, ssl=settings.ssl, ssl_keyfile=settings.ssl_keyfile, @@ -254,15 +247,15 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: while True: try: - pool = pool_factory( - db=settings.database, password=settings.password, encoding='utf8' + pool = await pool_factory( + password=settings.password, encoding='utf8' ) pool.job_serializer = job_serializer pool.job_deserializer = job_deserializer pool.default_queue_name = default_queue_name pool.expires_extra_ms = expires_extra_ms - await pool.initialize() - await pool.ping() + + except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e: if retry < settings.conn_retries: @@ -284,21 +277,22 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: return pool -async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any]) -> None: - async with redis.pipeline(transaction=False) as pipe: - pipe.info(section='Server') # type: ignore[unused-coroutine] - pipe.info(section='Memory') # type: ignore[unused-coroutine] - pipe.info(section='Clients') # type: ignore[unused-coroutine] - pipe.dbsize() # type: ignore[unused-coroutine] - 
info_server, info_memory, info_clients, key_count = await pipe.execute() +async def log_redis_info(redis: 'RedisCluster[bytes]', log_func: Callable[[str], Any]) -> None: + # async with redis.pipeline() as pipe: + # pipe.info(section='Server') + # # type: ignore[unused-coroutine] + # pipe.info(section='Memory') # type: ignore[unused-coroutine] + # pipe.info(section='Clients') # type: ignore[unused-coroutine] + # pipe.dbsize() # type: ignore[unused-coroutine] + # info_server, info_memory, info_clients, key_count = await pipe.execute() - redis_version = info_server.get('redis_version', '?') - mem_usage = info_memory.get('used_memory_human', '?') - clients_connected = info_clients.get('connected_clients', '?') + redis_version = "info_server.get('redis_version', '?')" + mem_usage = "info_memory.get('used_memory_human', '?')" + clients_connected =" info_clients.get('connected_clients', '?')" log_func( f'redis_version={redis_version} ' f'mem_usage={mem_usage} ' f'clients_connected={clients_connected} ' - f'db_keys={key_count}' + f'db_keys={88}' ) diff --git a/arq/jobs.py b/arq/jobs.py index c2ae796b..cce65a9d 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -102,7 +102,7 @@ async def result( poll_delay = pole_delay async for delay in poll(poll_delay): - async with self._redis.pipeline(transaction=True) as tr: + async with self._redis.pipeline() as tr: tr.get(result_key_prefix + self.job_id) # type: ignore[unused-coroutine] tr.zscore(self._queue_name, self.job_id) # type: ignore[unused-coroutine] v, s = await tr.execute() @@ -153,7 +153,7 @@ async def status(self) -> JobStatus: """ Status of the job. """ - async with self._redis.pipeline(transaction=True) as tr: + async with self._redis.pipeline() as tr: tr.exists(result_key_prefix + self.job_id) # type: ignore[unused-coroutine] tr.exists(in_progress_key_prefix + self.job_id) # type: ignore[unused-coroutine] tr.zscore(self._queue_name, self.job_id) # type: ignore[unused-coroutine] @@ -179,7 +179,7 @@ async def abort(self, *, timeout: Optional[float] = None, poll_delay: float = 0. """ job_info = await self.info() if job_info and job_info.score and job_info.score > timestamp_ms(): - async with self._redis.pipeline(transaction=True) as tr: + async with self._redis.pipeline() as tr: tr.zrem(self._queue_name, self.job_id) # type: ignore[unused-coroutine] tr.zadd(self._queue_name, {self.job_id: 1}) # type: ignore[unused-coroutine] await tr.execute() diff --git a/arq/worker.py b/arq/worker.py index 81afd5b7..a62009ad 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -397,7 +397,7 @@ async def _cancel_aborted_jobs(self) -> None: """ Go through job_ids in the abort_jobs_ss sorted set and cancel those tasks. 
""" - async with self.pool.pipeline(transaction=True) as pipe: + async with self.pool.pipeline() as pipe: pipe.zrange(abort_jobs_ss, start=0, end=-1) # type: ignore[unused-coroutine] pipe.zremrangebyscore( # type: ignore[unused-coroutine] abort_jobs_ss, min=timestamp_ms() + abort_job_max_age, max=float('inf') @@ -427,7 +427,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: await self.sem.acquire() job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id - async with self.pool.pipeline(transaction=True) as pipe: + async with self.pool.pipeline() as pipe: await pipe.watch(in_progress_key) ongoing_exists = await pipe.exists(in_progress_key) score = await pipe.zscore(self.queue_name, job_id) @@ -454,7 +454,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: async def run_job(self, job_id: str, score: int) -> None: # noqa: C901 start_ms = timestamp_ms() - async with self.pool.pipeline(transaction=True) as pipe: + async with self.pool.pipeline() as pipe: pipe.get(job_key_prefix + job_id) # type: ignore[unused-coroutine] pipe.incr(retry_key_prefix + job_id) # type: ignore[unused-coroutine] pipe.expire(retry_key_prefix + job_id, 88400) # type: ignore[unused-coroutine] @@ -663,7 +663,7 @@ async def finish_job( incr_score: Optional[int], keep_in_progress: Optional[float], ) -> None: - async with self.pool.pipeline(transaction=True) as tr: + async with self.pool.pipeline() as tr: delete_keys = [] in_progress_key = in_progress_key_prefix + job_id if keep_in_progress is None: @@ -685,7 +685,7 @@ async def finish_job( await tr.execute() async def finish_failed_job(self, job_id: str, result_data: Optional[bytes]) -> None: - async with self.pool.pipeline(transaction=True) as tr: + async with self.pool.pipeline() as tr: tr.delete( # type: ignore[unused-coroutine] retry_key_prefix + job_id, in_progress_key_prefix + job_id, @@ -843,7 +843,7 @@ async def close(self) -> None: await self.pool.delete(self.health_check_key) if self.on_shutdown: await self.on_shutdown(self.ctx) - await self.pool.close(close_connection_pool=True) + await self.pool.close() self._pool = None def __repr__(self) -> str: @@ -884,7 +884,7 @@ async def async_check_health( else: logger.info('Health check successful: %s', data) r = 0 - await redis.close(close_connection_pool=True) + await redis.close() return r diff --git a/pyproject.toml b/pyproject.toml index 7d88ada4..faad1826 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ Changelog = 'https://github.com/samuelcolvin/arq/releases' testpaths = 'tests' filterwarnings = ['error'] asyncio_mode = 'auto' -timeout = 10 + [tool.coverage.run] source = ['arq'] diff --git a/test.py b/test.py new file mode 100644 index 00000000..8fe226e5 --- /dev/null +++ b/test.py @@ -0,0 +1,87 @@ +from redis.cluster import RedisCluster, ClusterNode +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +import arq +import redis.connection as conn +import asyncio +from arq.worker import Retry, Worker, func + + +async def test_async_redis_client(): + print("Testing Async Redis Client") + arc = AsyncRedisCluster( + host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", + port="6379", + decode_responses=True, + ) + print("Got here") + arc = await arc.initialize() + print("Got here 2") + print(arc.get_nodes()) + + + +def arq_from_settings() -> arq.connections.RedisSettings: + """Return arq RedisSettings from a settings section""" + return arq.connections.RedisSettings( + 
host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", + port="6379", + conn_timeout=5 + + ) + + +_arq_pool: arq.ArqRedis | None = None +worker_: Worker = None + +async def open_arq_pool() -> arq.ArqRedis: + """Opens a shared ArqRedis pool for this process""" + global _arq_pool + if not _arq_pool: + _arq_pool = await arq.create_pool(arq_from_settings()) + await _arq_pool.__aenter__() + return _arq_pool + + +async def close_arq_pool() -> None: + """Closes the shared ArqRedis pool for this process""" + if _arq_pool: + await _arq_pool.__aexit__(None, None, None) + + +async def arq_pool() -> arq.ArqRedis: + if not _arq_pool: + raise Exception("The global pool was not opened for this process") + return _arq_pool + + +async def get_queued_jobs_ids(arq_pool: arq.ArqRedis, queue_name: str) -> set[str]: + return {job_id.decode() for job_id in await arq_pool.zrange(queue_name, 0, -1)} + + +async def create_worker(arq_redis:arq.ArqRedis, functions=[], burst=True, poll_delay=0, max_jobs=10, **kwargs): + global worker_ + worker_ = Worker( + functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs + ) + return worker_ + + +async def qj(): + """Schedule an arq task to remove the access grant from the database at the time of expiration.""" + await open_arq_pool() + arq = await arq_pool() + async def foobar(ctx): + return 42 + + j = await arq.enqueue_job('foobar') + worker: Worker = await create_worker(arq,functions=[func(foobar, name='foobar')]) + await worker.main() + r = await j.result(poll_delay=0) + print(r) + + +if __name__ == "__main__": + + asyncio.run(test_async_redis_client()) + asyncio.run(qj()) + \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 755aeec6..c821997a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import msgpack import pytest -from redislite import Redis + from arq.connections import ArqRedis, create_pool from arq.worker import Worker @@ -18,38 +18,32 @@ def _fix_loop(event_loop): @pytest.fixture async def arq_redis(loop): - redis_ = ArqRedis( - host='localhost', - port=6379, - encoding='utf-8', + redis_ = ArqRedis( + host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", + port="6379", + decode_responses=True, ) - await redis_.flushall() - yield redis_ - await redis_.close(close_connection_pool=True) + await redis_.close() -@pytest.fixture -async def unix_socket_path(loop, tmp_path): - rdb = Redis(str(tmp_path / 'redis_test.db')) - yield rdb.socket_file - rdb.close() @pytest.fixture async def arq_redis_msgpack(loop): - redis_ = ArqRedis( - host='localhost', - port=6379, + redis_ = await ArqRedis( + host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", + port="6379", + decode_responses=True, encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), ) - await redis_.flushall() + yield redis_ - await redis_.close(close_connection_pool=True) + await redis_.close() @pytest.fixture @@ -80,7 +74,7 @@ async def create_pool_(settings, *args, **kwargs): yield create_pool_ - await asyncio.gather(*[p.close(close_connection_pool=True) for p in pools]) + await asyncio.gather(*[await p.close() for p in pools]) @pytest.fixture(name='cancel_remaining_task') diff --git a/tests/test_main.py b/tests/test_main.py index 7c3a9835..af4e5ac8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -29,256 +29,256 @@ async def foobar(ctx): assert r == 42 # 1 -async def 
test_enqueue_job_different_queues(arq_redis: ArqRedis, worker): - async def foobar(ctx): - return 42 - - j1 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue1') - j2 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue2') - worker1: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue1') - worker2: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue2') - - await worker1.main() - await worker2.main() - r1 = await j1.result(poll_delay=0) - r2 = await j2.result(poll_delay=0) - assert r1 == 42 # 1 - assert r2 == 42 # 2 +# async def test_enqueue_job_different_queues(arq_redis: ArqRedis, worker): +# async def foobar(ctx): +# return 42 + +# j1 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue1') +# j2 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue2') +# worker1: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue1') +# worker2: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue2') +# await worker1.main() +# await worker2.main() +# r1 = await j1.result(poll_delay=0) +# r2 = await j2.result(poll_delay=0) +# assert r1 == 42 # 1 +# assert r2 == 42 # 2 -async def test_enqueue_job_nested(arq_redis: ArqRedis, worker): - async def foobar(ctx): - return 42 - async def parent_job(ctx): - inner_job = await ctx['redis'].enqueue_job('foobar') - return inner_job.job_id +# async def test_enqueue_job_nested(arq_redis: ArqRedis, worker): +# async def foobar(ctx): +# return 42 + +# async def parent_job(ctx): +# inner_job = await ctx['redis'].enqueue_job('foobar') +# return inner_job.job_id - job = await arq_redis.enqueue_job('parent_job') - worker: Worker = worker(functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')]) - - await worker.main() - result = await job.result(poll_delay=0) - assert result is not None - inner_job = Job(result, arq_redis) - inner_result = await inner_job.result(poll_delay=0) - assert inner_result == 42 +# job = await arq_redis.enqueue_job('parent_job') +# worker: Worker = worker(functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')]) + +# await worker.main() +# result = await job.result(poll_delay=0) +# assert result is not None +# inner_job = Job(result, arq_redis) +# inner_result = await inner_job.result(poll_delay=0) +# assert inner_result == 42 -async def test_enqueue_job_nested_custom_serializer(arq_redis_msgpack: ArqRedis, worker): - async def foobar(ctx): - return 42 - - async def parent_job(ctx): - inner_job = await ctx['redis'].enqueue_job('foobar') - return inner_job.job_id - - job = await arq_redis_msgpack.enqueue_job('parent_job') - - worker: Worker = worker( - functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], - arq_redis=None, - job_serializer=msgpack.packb, - job_deserializer=functools.partial(msgpack.unpackb, raw=False), - ) - - await worker.main() - result = await job.result(poll_delay=0) - assert result is not None - inner_job = Job(result, arq_redis_msgpack, _deserializer=functools.partial(msgpack.unpackb, raw=False)) - inner_result = await inner_job.result(poll_delay=0) - assert inner_result == 42 +# async def test_enqueue_job_nested_custom_serializer(arq_redis_msgpack: ArqRedis, worker): +# async def foobar(ctx): +# return 42 + +# async def parent_job(ctx): +# inner_job = await ctx['redis'].enqueue_job('foobar') +# return inner_job.job_id + +# job = await arq_redis_msgpack.enqueue_job('parent_job') + +# worker: Worker = worker( +# 
functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], +# arq_redis=None, +# job_serializer=msgpack.packb, +# job_deserializer=functools.partial(msgpack.unpackb, raw=False), +# ) + +# await worker.main() +# result = await job.result(poll_delay=0) +# assert result is not None +# inner_job = Job(result, arq_redis_msgpack, _deserializer=functools.partial(msgpack.unpackb, raw=False)) +# inner_result = await inner_job.result(poll_delay=0) +# assert inner_result == 42 + + +# async def test_enqueue_job_custom_queue(arq_redis: ArqRedis, worker): +# async def foobar(ctx): +# return 42 + +# async def parent_job(ctx): +# inner_job = await ctx['redis'].enqueue_job('foobar') +# return inner_job.job_id + +# job = await arq_redis.enqueue_job('parent_job', _queue_name='spanner') + +# worker: Worker = worker( +# functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], +# arq_redis=None, +# queue_name='spanner', +# ) + +# await worker.main() +# inner_job_id = await job.result(poll_delay=0) +# assert inner_job_id is not None +# inner_job = Job(inner_job_id, arq_redis, _queue_name='spanner') +# inner_result = await inner_job.result(poll_delay=0) +# assert inner_result == 42 + + +# async def test_job_error(arq_redis: ArqRedis, worker): +# async def foobar(ctx): +# raise RuntimeError('foobar error') + +# j = await arq_redis.enqueue_job('foobar') +# worker: Worker = worker(functions=[func(foobar, name='foobar')]) +# await worker.main() +# with pytest.raises(RuntimeError, match='foobar error'): +# await j.result(poll_delay=0) + + +# async def test_job_info(arq_redis: ArqRedis): +# t_before = time() +# j = await arq_redis.enqueue_job('foobar', 123, a=456) +# info = await j.info() +# assert info.enqueue_time == IsNow(tz='utc') +# assert info.job_try is None +# assert info.function == 'foobar' +# assert info.args == (123,) +# assert info.kwargs == {'a': 456} +# assert abs(t_before * 1000 - info.score) < 1000 -async def test_enqueue_job_custom_queue(arq_redis: ArqRedis, worker): - async def foobar(ctx): - return 42 - - async def parent_job(ctx): - inner_job = await ctx['redis'].enqueue_job('foobar') - return inner_job.job_id - - job = await arq_redis.enqueue_job('parent_job', _queue_name='spanner') - - worker: Worker = worker( - functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], - arq_redis=None, - queue_name='spanner', - ) - - await worker.main() - inner_job_id = await job.result(poll_delay=0) - assert inner_job_id is not None - inner_job = Job(inner_job_id, arq_redis, _queue_name='spanner') - inner_result = await inner_job.result(poll_delay=0) - assert inner_result == 42 - - -async def test_job_error(arq_redis: ArqRedis, worker): - async def foobar(ctx): - raise RuntimeError('foobar error') - - j = await arq_redis.enqueue_job('foobar') - worker: Worker = worker(functions=[func(foobar, name='foobar')]) - await worker.main() - with pytest.raises(RuntimeError, match='foobar error'): - await j.result(poll_delay=0) +# async def test_repeat_job(arq_redis: ArqRedis): +# j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id') +# assert isinstance(j1, Job) +# j2 = await arq_redis.enqueue_job('foobar', _job_id='job_id') +# assert j2 is None -async def test_job_info(arq_redis: ArqRedis): - t_before = time() - j = await arq_redis.enqueue_job('foobar', 123, a=456) - info = await j.info() - assert info.enqueue_time == IsNow(tz='utc') - assert info.job_try is None - assert info.function == 'foobar' - assert info.args == (123,) - assert info.kwargs == 
{'a': 456} - assert abs(t_before * 1000 - info.score) < 1000 - - -async def test_repeat_job(arq_redis: ArqRedis): - j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id') - assert isinstance(j1, Job) - j2 = await arq_redis.enqueue_job('foobar', _job_id='job_id') - assert j2 is None - - -async def test_defer_until(arq_redis: ArqRedis): - j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_until=datetime(2032, 1, 1, tzinfo=timezone.utc)) - assert isinstance(j1, Job) - score = await arq_redis.zscore(default_queue_name, 'job_id') - assert score == 1_956_528_000_000 - - -async def test_defer_by(arq_redis: ArqRedis): - j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_by=20) - assert isinstance(j1, Job) - score = await arq_redis.zscore(default_queue_name, 'job_id') - ts = timestamp_ms() - assert score > ts + 19000 - assert ts + 21000 > score - - -async def test_mung(arq_redis: ArqRedis, worker): - """ - check a job can't be enqueued multiple times with the same id - """ - counter = Counter() - - async def count(ctx, v): - counter[v] += 1 - - tasks = [] - for i in range(50): - tasks += [ - arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), - arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), - ] - shuffle(tasks) - await asyncio.gather(*tasks) - - worker: Worker = worker(functions=[func(count, name='count')]) - await worker.main() - assert counter.most_common(1)[0][1] == 1 # no job go enqueued twice - - -async def test_custom_try(arq_redis: ArqRedis, worker): - async def foobar(ctx): - return ctx['job_try'] - - j1 = await arq_redis.enqueue_job('foobar') - w: Worker = worker(functions=[func(foobar, name='foobar')]) - await w.main() - r = await j1.result(poll_delay=0) - assert r == 1 - - j2 = await arq_redis.enqueue_job('foobar', _job_try=3) - await w.main() - r = await j2.result(poll_delay=0) - assert r == 3 - - -async def test_custom_try2(arq_redis: ArqRedis, worker): - async def foobar(ctx): - if ctx['job_try'] == 3: - raise Retry() - return ctx['job_try'] - - j1 = await arq_redis.enqueue_job('foobar', _job_try=3) - w: Worker = worker(functions=[func(foobar, name='foobar')]) - await w.main() - r = await j1.result(poll_delay=0) - assert r == 4 - - -async def test_cant_pickle_arg(arq_redis: ArqRedis): - class Foobar: - def __getstate__(self): - raise TypeError("this doesn't pickle") - - with pytest.raises(SerializationError, match='unable to serialize job "foobar"'): - await arq_redis.enqueue_job('foobar', Foobar()) - - -async def test_cant_pickle_result(arq_redis: ArqRedis, worker): - class Foobar: - def __getstate__(self): - raise TypeError("this doesn't pickle") - - async def foobar(ctx): - return Foobar() - - j1 = await arq_redis.enqueue_job('foobar') - w: Worker = worker(functions=[func(foobar, name='foobar')]) - await w.main() - with pytest.raises(SerializationError, match='unable to serialize result'): - await j1.result(poll_delay=0) - - -async def test_get_jobs(arq_redis: ArqRedis): - await arq_redis.enqueue_job('foobar', a=1, b=2, c=3) - await asyncio.sleep(0.01) - await arq_redis.enqueue_job('second', 4, b=5, c=6) - await asyncio.sleep(0.01) - await arq_redis.enqueue_job('third', 7, b=8) - jobs = await arq_redis.queued_jobs() - assert [dataclasses.asdict(j) for j in jobs] == [ - { - 'function': 'foobar', - 'args': (), - 'kwargs': {'a': 1, 'b': 2, 'c': 3}, - 'job_try': None, - 'enqueue_time': IsNow(tz='utc'), - 'score': IsInt(), - }, - { - 'function': 'second', - 'args': (4,), - 'kwargs': {'b': 5, 'c': 6}, - 'job_try': None, - 'enqueue_time': 
IsNow(tz='utc'), - 'score': IsInt(), - }, - { - 'function': 'third', - 'args': (7,), - 'kwargs': {'b': 8}, - 'job_try': None, - 'enqueue_time': IsNow(tz='utc'), - 'score': IsInt(), - }, - ] - assert jobs[0].score < jobs[1].score < jobs[2].score - assert isinstance(jobs[0], JobDef) - assert isinstance(jobs[1], JobDef) - assert isinstance(jobs[2], JobDef) - - -async def test_enqueue_multiple(arq_redis: ArqRedis, caplog): - caplog.set_level(logging.DEBUG) - results = await asyncio.gather(*[arq_redis.enqueue_job('foobar', i, _job_id='testing') for i in range(10)]) - assert sum(r is not None for r in results) == 1 - assert sum(r is None for r in results) == 9 - assert 'WatchVariableError' not in caplog.text +# async def test_defer_until(arq_redis: ArqRedis): +# j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_until=datetime(2032, 1, 1, tzinfo=timezone.utc)) +# assert isinstance(j1, Job) +# score = await arq_redis.zscore(default_queue_name, 'job_id') +# assert score == 1_956_528_000_000 + + +# async def test_defer_by(arq_redis: ArqRedis): +# j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_by=20) +# assert isinstance(j1, Job) +# score = await arq_redis.zscore(default_queue_name, 'job_id') +# ts = timestamp_ms() +# assert score > ts + 19000 +# assert ts + 21000 > score + + +# async def test_mung(arq_redis: ArqRedis, worker): +# """ +# check a job can't be enqueued multiple times with the same id +# """ +# counter = Counter() + +# async def count(ctx, v): +# counter[v] += 1 + +# tasks = [] +# for i in range(50): +# tasks += [ +# arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), +# arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), +# ] +# shuffle(tasks) +# await asyncio.gather(*tasks) + +# worker: Worker = worker(functions=[func(count, name='count')]) +# await worker.main() +# assert counter.most_common(1)[0][1] == 1 # no job go enqueued twice + + +# async def test_custom_try(arq_redis: ArqRedis, worker): +# async def foobar(ctx): +# return ctx['job_try'] + +# j1 = await arq_redis.enqueue_job('foobar') +# w: Worker = worker(functions=[func(foobar, name='foobar')]) +# await w.main() +# r = await j1.result(poll_delay=0) +# assert r == 1 + +# j2 = await arq_redis.enqueue_job('foobar', _job_try=3) +# await w.main() +# r = await j2.result(poll_delay=0) +# assert r == 3 + + +# async def test_custom_try2(arq_redis: ArqRedis, worker): +# async def foobar(ctx): +# if ctx['job_try'] == 3: +# raise Retry() +# return ctx['job_try'] + +# j1 = await arq_redis.enqueue_job('foobar', _job_try=3) +# w: Worker = worker(functions=[func(foobar, name='foobar')]) +# await w.main() +# r = await j1.result(poll_delay=0) +# assert r == 4 + + +# async def test_cant_pickle_arg(arq_redis: ArqRedis): +# class Foobar: +# def __getstate__(self): +# raise TypeError("this doesn't pickle") + +# with pytest.raises(SerializationError, match='unable to serialize job "foobar"'): +# await arq_redis.enqueue_job('foobar', Foobar()) + + +# async def test_cant_pickle_result(arq_redis: ArqRedis, worker): +# class Foobar: +# def __getstate__(self): +# raise TypeError("this doesn't pickle") + +# async def foobar(ctx): +# return Foobar() + +# j1 = await arq_redis.enqueue_job('foobar') +# w: Worker = worker(functions=[func(foobar, name='foobar')]) +# await w.main() +# with pytest.raises(SerializationError, match='unable to serialize result'): +# await j1.result(poll_delay=0) + + +# async def test_get_jobs(arq_redis: ArqRedis): +# await arq_redis.enqueue_job('foobar', a=1, b=2, c=3) +# await 
asyncio.sleep(0.01) +# await arq_redis.enqueue_job('second', 4, b=5, c=6) +# await asyncio.sleep(0.01) +# await arq_redis.enqueue_job('third', 7, b=8) +# jobs = await arq_redis.queued_jobs() +# assert [dataclasses.asdict(j) for j in jobs] == [ +# { +# 'function': 'foobar', +# 'args': (), +# 'kwargs': {'a': 1, 'b': 2, 'c': 3}, +# 'job_try': None, +# 'enqueue_time': IsNow(tz='utc'), +# 'score': IsInt(), +# }, +# { +# 'function': 'second', +# 'args': (4,), +# 'kwargs': {'b': 5, 'c': 6}, +# 'job_try': None, +# 'enqueue_time': IsNow(tz='utc'), +# 'score': IsInt(), +# }, +# { +# 'function': 'third', +# 'args': (7,), +# 'kwargs': {'b': 8}, +# 'job_try': None, +# 'enqueue_time': IsNow(tz='utc'), +# 'score': IsInt(), +# }, +# ] +# assert jobs[0].score < jobs[1].score < jobs[2].score +# assert isinstance(jobs[0], JobDef) +# assert isinstance(jobs[1], JobDef) +# assert isinstance(jobs[2], JobDef) + + +# async def test_enqueue_multiple(arq_redis: ArqRedis, caplog): +# caplog.set_level(logging.DEBUG) +# results = await asyncio.gather(*[arq_redis.enqueue_job('foobar', i, _job_id='testing') for i in range(10)]) +# assert sum(r is not None for r in results) == 1 +# assert sum(r is None for r in results) == 9 +# assert 'WatchVariableError' not in caplog.text From 621a5bc88e917b51acb9d9e726b57e22f00e05c0 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Fri, 27 Oct 2023 09:39:26 -0500 Subject: [PATCH 04/22] jobs not starting --- arq/connections.py | 26 ++++++++++++++------------ arq/jobs.py | 5 +++-- arq/worker.py | 26 ++++++++++++++++++++------ test.py | 35 +++++++++++++++-------------------- tests/conftest.py | 8 ++++---- 5 files changed, 56 insertions(+), 44 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index 36303301..8cfb2c9b 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -17,7 +17,7 @@ from .utils import timestamp_ms, to_ms, to_unix_ms logger = logging.getLogger('arq.connections') - +logging.basicConfig(level=logging.DEBUG) @dataclass class RedisSettings: @@ -27,9 +27,9 @@ class RedisSettings: Used by :func:`arq.connections.create_pool` and :class:`arq.worker.Worker`. 
""" - host: Union[str, List[Tuple[str, int]]] = 'localhost' + host: Union[str, List[Tuple[str, int]]] = 'test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com' port: int = 6379 - + username: Optional[str] = None password: Optional[str] = None ssl: bool = False @@ -50,8 +50,8 @@ class RedisSettings: def from_dsn(cls, dsn: str) -> 'RedisSettings': conf = urlparse(dsn) assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme' - - + + return RedisSettings( host=conf.hostname or 'localhost', port=conf.port or 6379, @@ -135,10 +135,9 @@ async def enqueue_job( defer_by_ms = to_ms(_defer_by) expires_ms = to_ms(_expires) - + async with self.pipeline() as pipe: - logger.debug("insides pipeline---------------------------") - + logger.debug("insides pipeline Enq Job---------------------------") if await pipe.exists(job_key, result_key_prefix + job_id): await pipe.reset() return None @@ -154,15 +153,18 @@ async def enqueue_job( expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) - + pipe.psetex(job_key, expires_ms, job) # type: ignore[no-untyped-call] pipe.zadd(_queue_name, {job_id: score}) # type: ignore[unused-coroutine] try: + logger.debug("Executing Enq Job---------------------------") await pipe.execute() except WatchError: # job got enqueued since we checked 'job_exists' return None - return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer) + the_job = Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer) + logger.debug(the_job) + return the_job async def _get_job_result(self, key: bytes) -> JobResult: job_id = key[len(result_key_prefix) :].decode() @@ -254,8 +256,8 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: pool.job_deserializer = job_deserializer pool.default_queue_name = default_queue_name pool.expires_extra_ms = expires_extra_ms - - + + except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e: if retry < settings.conn_retries: diff --git a/arq/jobs.py b/arq/jobs.py index cce65a9d..adec8b74 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -13,7 +13,7 @@ from .utils import ms_to_datetime, poll, timestamp_ms logger = logging.getLogger('arq.jobs') - +logging.basicConfig(level=logging.DEBUG) Serializer = Callable[[Dict[str, Any]], bytes] Deserializer = Callable[[bytes], Dict[str, Any]] @@ -106,9 +106,10 @@ async def result( tr.get(result_key_prefix + self.job_id) # type: ignore[unused-coroutine] tr.zscore(self._queue_name, self.job_id) # type: ignore[unused-coroutine] v, s = await tr.execute() - + if v: info = deserialize_result(v, deserializer=self._deserializer) + print(info) if info.success: return info.result elif isinstance(info.result, (Exception, asyncio.CancelledError)): diff --git a/arq/worker.py b/arq/worker.py index a62009ad..cbeb3ca3 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -44,6 +44,7 @@ from .typing import SecondsTimedelta, StartupShutdown, WorkerCoroutine, WorkerSettingsType # noqa F401 logger = logging.getLogger('arq.worker') +logging.basicConfig(level=logging.DEBUG) no_result = object() @@ -358,6 +359,7 @@ async def main(self) -> None: await asyncio.gather(*self.tasks.values()) return None queued_jobs = await self.pool.zcard(self.queue_name) + if queued_jobs == 0: await asyncio.gather(*self.tasks.values()) return None @@ -379,7 +381,7 @@ async def _poll_iteration(self) -> None: job_ids = await self.pool.zrangebyscore( self.queue_name, 
min=float('-inf'), start=self._queue_read_offset, num=count, max=now ) - + await self.start_jobs(job_ids) if self.allow_abort_jobs: @@ -427,21 +429,31 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: await self.sem.acquire() job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id + scores = await self.pool.zscore(self.queue_name, job_id) + print(scores) async with self.pool.pipeline() as pipe: - await pipe.watch(in_progress_key) - ongoing_exists = await pipe.exists(in_progress_key) + + score = await pipe.zscore(self.queue_name, job_id) - if ongoing_exists or not score: + + + if not score: + + + print(f"score is {bool(score)} queue {self.queue_name} didn't have job {job_id}") # job already started elsewhere, or already finished and removed from queue self.sem.release() - logger.debug('job %s already running elsewhere', job_id) + # logger.debug('job %s already running elsewhere', job_id) continue - pipe.multi() + pipe.psetex( # type: ignore[no-untyped-call] in_progress_key, int(self.in_progress_timeout_s * 1000), b'1' ) + try: + + print("we triedeeeeeeeeeeeeeeee") await pipe.execute() except (ResponseError, WatchError): # job already started elsewhere since we got 'existing' @@ -461,8 +473,10 @@ async def run_job(self, job_id: str, score: int) -> None: # noqa: C901 if self.allow_abort_jobs: pipe.zrem(abort_jobs_ss, job_id) # type: ignore[unused-coroutine] v, job_try, _, abort_job = await pipe.execute() + print(f"v in worker l-471 {v}") else: v, job_try, _ = await pipe.execute() + print(f"v in worker l-473 {v}") abort_job = False function_name, enqueue_time_ms = '', 0 diff --git a/test.py b/test.py index 8fe226e5..c7bba342 100644 --- a/test.py +++ b/test.py @@ -1,32 +1,21 @@ from redis.cluster import RedisCluster, ClusterNode from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster -import arq +import arq import redis.connection as conn import asyncio from arq.worker import Retry, Worker, func -async def test_async_redis_client(): - print("Testing Async Redis Client") - arc = AsyncRedisCluster( - host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", - port="6379", - decode_responses=True, - ) - print("Got here") - arc = await arc.initialize() - print("Got here 2") - print(arc.get_nodes()) def arq_from_settings() -> arq.connections.RedisSettings: """Return arq RedisSettings from a settings section""" return arq.connections.RedisSettings( - host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", + host="test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com", port="6379", conn_timeout=5 - + ) @@ -58,30 +47,36 @@ async def get_queued_jobs_ids(arq_pool: arq.ArqRedis, queue_name: str) -> set[st return {job_id.decode() for job_id in await arq_pool.zrange(queue_name, 0, -1)} +def print_job(): + print("job started") + async def create_worker(arq_redis:arq.ArqRedis, functions=[], burst=True, poll_delay=0, max_jobs=10, **kwargs): global worker_ worker_ = Worker( - functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs + functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs,on_job_start=print_job, **kwargs ) return worker_ + async def qj(): """Schedule an arq task to remove the access grant from the database at the time of expiration.""" await open_arq_pool() arq = await arq_pool() + async def foobar(ctx): return 42 j = await arq.enqueue_job('foobar') - worker: Worker = await create_worker(arq,functions=[func(foobar, name='foobar')]) + + 
worker: Worker = await create_worker(arq,functions=[func(foobar, name='foobar')],) await worker.main() r = await j.result(poll_delay=0) print(r) -if __name__ == "__main__": - - asyncio.run(test_async_redis_client()) +if __name__ == "__main__": + + asyncio.run(qj()) - \ No newline at end of file + diff --git a/tests/conftest.py b/tests/conftest.py index c821997a..79989a6a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,9 @@ def _fix_loop(event_loop): @pytest.fixture async def arq_redis(loop): redis_ = ArqRedis( - host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", + host="test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com", port="6379", - decode_responses=True, + ) yield redis_ @@ -34,9 +34,9 @@ async def arq_redis(loop): @pytest.fixture async def arq_redis_msgpack(loop): redis_ = await ArqRedis( - host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", + host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", port="6379", - decode_responses=True, + encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), From ebc1e709198f18cc065ace8c76aabe258e92ba94 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Mon, 6 Nov 2023 12:21:04 -0600 Subject: [PATCH 05/22] mvp --- arq/connections.py | 114 +++++++++-- arq/jobs.py | 10 +- arq/worker.py | 34 ++-- test.py | 8 +- tests/conftest.py | 15 +- tests/test_main.py | 484 ++++++++++++++++++++++----------------------- 6 files changed, 364 insertions(+), 301 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index 8cfb2c9b..d9b35b10 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -4,14 +4,16 @@ from dataclasses import dataclass from datetime import datetime, timedelta from operator import attrgetter -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar, Union from urllib.parse import parse_qs, urlparse from uuid import uuid4 -from redis.asyncio import ConnectionPool +from redis.asyncio import ConnectionPool, Redis +from redis.asyncio.cluster import ClusterPipeline, PipelineCommand, RedisCluster from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError -from redis.asyncio.cluster import RedisCluster +from redis.typing import EncodableT, KeyT + from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job from .utils import timestamp_ms, to_ms, to_unix_ms @@ -19,6 +21,10 @@ logger = logging.getLogger('arq.connections') logging.basicConfig(level=logging.DEBUG) + +_KeyT = TypeVar('_KeyT', bound=KeyT) + + @dataclass class RedisSettings: """ @@ -29,7 +35,8 @@ class RedisSettings: host: Union[str, List[Tuple[str, int]]] = 'test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com' port: int = 6379 - + unix_socket_path: Optional[str] = None + database: int = 0 username: Optional[str] = None password: Optional[str] = None ssl: bool = False @@ -42,7 +49,7 @@ class RedisSettings: conn_timeout: int = 1 conn_retries: int = 5 conn_retry_delay: int = 1 - cluster_mode: bool = False + cluster_mode: bool = True sentinel: bool = False sentinel_master: str = 'mymaster' @@ -51,7 +58,6 @@ def from_dsn(cls, dsn: str) -> 'RedisSettings': conf = urlparse(dsn) assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme' - return RedisSettings( host=conf.hostname or 
'localhost', port=conf.port or 6379, @@ -64,9 +70,9 @@ def __repr__(self) -> str: if TYPE_CHECKING: - BaseRedis = RedisCluster[bytes] + BaseRedis = Redis[bytes] else: - BaseRedis = RedisCluster + BaseRedis = Redis class ArqRedis(BaseRedis): @@ -135,9 +141,8 @@ async def enqueue_job( defer_by_ms = to_ms(_defer_by) expires_ms = to_ms(_expires) - - async with self.pipeline() as pipe: - logger.debug("insides pipeline Enq Job---------------------------") + async with self.pipeline(transaction=True) as pipe: + await pipe.watch(job_key) if await pipe.exists(job_key, result_key_prefix + job_id): await pipe.reset() return None @@ -153,11 +158,10 @@ async def enqueue_job( expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) - + pipe.multi() pipe.psetex(job_key, expires_ms, job) # type: ignore[no-untyped-call] pipe.zadd(_queue_name, {job_id: score}) # type: ignore[unused-coroutine] try: - logger.debug("Executing Enq Job---------------------------") await pipe.execute() except WatchError: # job got enqueued since we checked 'job_exists' @@ -201,6 +205,64 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef] return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs]) +class ArqRedisCluster(RedisCluster): + def __init__( + self, + job_serializer: Optional[Serializer] = None, + job_deserializer: Optional[Deserializer] = None, + default_queue_name: str = default_queue_name, + expires_extra_ms: int = expires_extra_ms, + **kwargs: Any, + ) -> None: + self.job_serializer = job_serializer + self.job_deserializer = job_deserializer + self.default_queue_name = default_queue_name + self.expires_extra_ms = expires_extra_ms + super().__init__(**kwargs) + + enqueue_job = ArqRedis.enqueue_job + _get_job_result = ArqRedis._get_job_result + all_job_results = ArqRedis.all_job_results + _get_job_def = ArqRedis._get_job_def + queued_jobs = ArqRedis.queued_jobs + + def pipeline(self, transaction: Any | None = None, shard_hint: Any | None = None) -> ClusterPipeline: + + return ArqRedisClusterPipeline(self) + + +class ArqRedisClusterPipeline(ClusterPipeline): + def __init__(self, client: RedisCluster) -> None: + self.watching = False + super().__init__(client) + + async def watch(self, *names: KeyT) -> None: + self.watching = True + + def multi(self) -> None: + self.watching = False + + def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> 'ClusterPipeline': + """ + Append a raw command to the pipeline. 
+ + :param args: + | Raw command args + :param kwargs: + + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + """ + cmd = PipelineCommand(len(self._command_stack), *args, **kwargs) + if self.watching: + cmd.result = self._client.execute_command(*cmd.args, **cmd.kwargs) + + return cmd.result + self._command_stack.append(cmd) + return self + + async def create_pool( settings_: RedisSettings = None, *, @@ -232,11 +294,28 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: ) return client.master_for(settings.sentinel_master, redis_class=ArqRedis) + if settings.cluster_mode: + pool_factory = functools.partial( + ArqRedisCluster, + host=settings.host, + port=settings.port, + socket_connect_timeout=settings.conn_timeout, + ssl=settings.ssl, + ssl_keyfile=settings.ssl_keyfile, + ssl_certfile=settings.ssl_certfile, + ssl_cert_reqs=settings.ssl_cert_reqs, + ssl_ca_certs=settings.ssl_ca_certs, + ssl_ca_data=settings.ssl_ca_data, + ssl_check_hostname=settings.ssl_check_hostname, + ) else: pool_factory = functools.partial( ArqRedis, + db=settings.database, + username=settings.username, host=settings.host, port=settings.port, + unix_socket_path=settings.unix_socket_path, socket_connect_timeout=settings.conn_timeout, ssl=settings.ssl, ssl_keyfile=settings.ssl_keyfile, @@ -249,16 +328,12 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: while True: try: - pool = await pool_factory( - password=settings.password, encoding='utf8' - ) + pool = await pool_factory(password=settings.password, encoding='utf8') pool.job_serializer = job_serializer pool.job_deserializer = job_deserializer pool.default_queue_name = default_queue_name pool.expires_extra_ms = expires_extra_ms - - except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e: if retry < settings.conn_retries: logger.warning( @@ -279,6 +354,7 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: return pool +# TODO async def log_redis_info(redis: 'RedisCluster[bytes]', log_func: Callable[[str], Any]) -> None: # async with redis.pipeline() as pipe: # pipe.info(section='Server') @@ -290,7 +366,7 @@ async def log_redis_info(redis: 'RedisCluster[bytes]', log_func: Callable[[str], redis_version = "info_server.get('redis_version', '?')" mem_usage = "info_memory.get('used_memory_human', '?')" - clients_connected =" info_clients.get('connected_clients', '?')" + clients_connected = " info_clients.get('connected_clients', '?')" log_func( f'redis_version={redis_version} ' diff --git a/arq/jobs.py b/arq/jobs.py index adec8b74..ce715c83 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -7,8 +7,8 @@ from enum import Enum from typing import Any, Callable, Dict, Optional, Tuple - from redis.asyncio.cluster import RedisCluster + from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix from .utils import ms_to_datetime, poll, timestamp_ms @@ -102,11 +102,11 @@ async def result( poll_delay = pole_delay async for delay in poll(poll_delay): - async with self._redis.pipeline() as tr: + async with self._redis.pipeline(transaction=True) as tr: tr.get(result_key_prefix + self.job_id) # type: ignore[unused-coroutine] tr.zscore(self._queue_name, self.job_id) # type: ignore[unused-coroutine] v, s = await tr.execute() - + if v: info = deserialize_result(v, deserializer=self._deserializer) print(info) @@ -154,7 +154,7 @@ async def status(self) -> 
JobStatus: """ Status of the job. """ - async with self._redis.pipeline() as tr: + async with self._redis.pipeline(transaction=True) as tr: tr.exists(result_key_prefix + self.job_id) # type: ignore[unused-coroutine] tr.exists(in_progress_key_prefix + self.job_id) # type: ignore[unused-coroutine] tr.zscore(self._queue_name, self.job_id) # type: ignore[unused-coroutine] @@ -180,7 +180,7 @@ async def abort(self, *, timeout: Optional[float] = None, poll_delay: float = 0. """ job_info = await self.info() if job_info and job_info.score and job_info.score > timestamp_ms(): - async with self._redis.pipeline() as tr: + async with self._redis.pipeline(transaction=True) as tr: tr.zrem(self._queue_name, self.job_id) # type: ignore[unused-coroutine] tr.zadd(self._queue_name, {self.job_id: 1}) # type: ignore[unused-coroutine] await tr.execute() diff --git a/arq/worker.py b/arq/worker.py index cbeb3ca3..f8a35862 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -359,7 +359,7 @@ async def main(self) -> None: await asyncio.gather(*self.tasks.values()) return None queued_jobs = await self.pool.zcard(self.queue_name) - + if queued_jobs == 0: await asyncio.gather(*self.tasks.values()) return None @@ -381,7 +381,7 @@ async def _poll_iteration(self) -> None: job_ids = await self.pool.zrangebyscore( self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now ) - + await self.start_jobs(job_ids) if self.allow_abort_jobs: @@ -399,7 +399,7 @@ async def _cancel_aborted_jobs(self) -> None: """ Go through job_ids in the abort_jobs_ss sorted set and cancel those tasks. """ - async with self.pool.pipeline() as pipe: + async with self.pool.pipeline(transaction=True) as pipe: pipe.zrange(abort_jobs_ss, start=0, end=-1) # type: ignore[unused-coroutine] pipe.zremrangebyscore( # type: ignore[unused-coroutine] abort_jobs_ss, min=timestamp_ms() + abort_job_max_age, max=float('inf') @@ -429,31 +429,25 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: await self.sem.acquire() job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id - scores = await self.pool.zscore(self.queue_name, job_id) - print(scores) - async with self.pool.pipeline() as pipe: - - + + async with self.pool.pipeline(transaction=True) as pipe: + await pipe.watch(in_progress_key) + ongoing_exists = await pipe.exists(in_progress_key) score = await pipe.zscore(self.queue_name, job_id) - - - if not score: - - - print(f"score is {bool(score)} queue {self.queue_name} didn't have job {job_id}") + + if ongoing_exists or not score: # job already started elsewhere, or already finished and removed from queue self.sem.release() # logger.debug('job %s already running elsewhere', job_id) continue - + pipe.multi() pipe.psetex( # type: ignore[no-untyped-call] in_progress_key, int(self.in_progress_timeout_s * 1000), b'1' ) try: - print("we triedeeeeeeeeeeeeeeee") await pipe.execute() except (ResponseError, WatchError): # job already started elsewhere since we got 'existing' @@ -466,17 +460,15 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: async def run_job(self, job_id: str, score: int) -> None: # noqa: C901 start_ms = timestamp_ms() - async with self.pool.pipeline() as pipe: + async with self.pool.pipeline(transaction=True) as pipe: pipe.get(job_key_prefix + job_id) # type: ignore[unused-coroutine] pipe.incr(retry_key_prefix + job_id) # type: ignore[unused-coroutine] pipe.expire(retry_key_prefix + job_id, 88400) # type: ignore[unused-coroutine] if self.allow_abort_jobs: pipe.zrem(abort_jobs_ss, job_id) # 
type: ignore[unused-coroutine] v, job_try, _, abort_job = await pipe.execute() - print(f"v in worker l-471 {v}") else: v, job_try, _ = await pipe.execute() - print(f"v in worker l-473 {v}") abort_job = False function_name, enqueue_time_ms = '', 0 @@ -677,7 +669,7 @@ async def finish_job( incr_score: Optional[int], keep_in_progress: Optional[float], ) -> None: - async with self.pool.pipeline() as tr: + async with self.pool.pipeline(transaction=True) as tr: delete_keys = [] in_progress_key = in_progress_key_prefix + job_id if keep_in_progress is None: @@ -699,7 +691,7 @@ async def finish_job( await tr.execute() async def finish_failed_job(self, job_id: str, result_data: Optional[bytes]) -> None: - async with self.pool.pipeline() as tr: + async with self.pool.pipeline(transaction=True) as tr: tr.delete( # type: ignore[unused-coroutine] retry_key_prefix + job_id, in_progress_key_prefix + job_id, diff --git a/test.py b/test.py index c7bba342..d404610a 100644 --- a/test.py +++ b/test.py @@ -14,7 +14,8 @@ def arq_from_settings() -> arq.connections.RedisSettings: return arq.connections.RedisSettings( host="test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com", port="6379", - conn_timeout=5 + conn_timeout=5, + cluster_mode=True ) @@ -53,7 +54,7 @@ def print_job(): async def create_worker(arq_redis:arq.ArqRedis, functions=[], burst=True, poll_delay=0, max_jobs=10, **kwargs): global worker_ worker_ = Worker( - functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs,on_job_start=print_job, **kwargs + functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs ) return worker_ @@ -68,7 +69,7 @@ async def foobar(ctx): return 42 j = await arq.enqueue_job('foobar') - + worker: Worker = await create_worker(arq,functions=[func(foobar, name='foobar')],) await worker.main() r = await j.result(poll_delay=0) @@ -79,4 +80,3 @@ async def foobar(ctx): asyncio.run(qj()) - diff --git a/tests/conftest.py b/tests/conftest.py index 79989a6a..c889a1ad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,6 @@ import msgpack import pytest - from arq.connections import ArqRedis, create_pool from arq.worker import Worker @@ -18,10 +17,9 @@ def _fix_loop(event_loop): @pytest.fixture async def arq_redis(loop): - redis_ = ArqRedis( - host="test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com", - port="6379", - + redis_ = ArqRedis( + host='localhost', + port='6379', ) yield redis_ @@ -29,14 +27,11 @@ async def arq_redis(loop): await redis_.close() - - @pytest.fixture async def arq_redis_msgpack(loop): redis_ = await ArqRedis( - host="tf-rep-group-1.48tzwx.clustercfg.use2.cache.amazonaws.com", - port="6379", - + host='localhost', + port='6379', encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), diff --git a/tests/test_main.py b/tests/test_main.py index af4e5ac8..7c3a9835 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -29,256 +29,256 @@ async def foobar(ctx): assert r == 42 # 1 -# async def test_enqueue_job_different_queues(arq_redis: ArqRedis, worker): -# async def foobar(ctx): -# return 42 - -# j1 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue1') -# j2 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue2') -# worker1: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue1') -# worker2: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue2') +async def 
test_enqueue_job_different_queues(arq_redis: ArqRedis, worker): + async def foobar(ctx): + return 42 + + j1 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue1') + j2 = await arq_redis.enqueue_job('foobar', _queue_name='arq:queue2') + worker1: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue1') + worker2: Worker = worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue2') + + await worker1.main() + await worker2.main() + r1 = await j1.result(poll_delay=0) + r2 = await j2.result(poll_delay=0) + assert r1 == 42 # 1 + assert r2 == 42 # 2 -# await worker1.main() -# await worker2.main() -# r1 = await j1.result(poll_delay=0) -# r2 = await j2.result(poll_delay=0) -# assert r1 == 42 # 1 -# assert r2 == 42 # 2 +async def test_enqueue_job_nested(arq_redis: ArqRedis, worker): + async def foobar(ctx): + return 42 -# async def test_enqueue_job_nested(arq_redis: ArqRedis, worker): -# async def foobar(ctx): -# return 42 - -# async def parent_job(ctx): -# inner_job = await ctx['redis'].enqueue_job('foobar') -# return inner_job.job_id + async def parent_job(ctx): + inner_job = await ctx['redis'].enqueue_job('foobar') + return inner_job.job_id -# job = await arq_redis.enqueue_job('parent_job') -# worker: Worker = worker(functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')]) - -# await worker.main() -# result = await job.result(poll_delay=0) -# assert result is not None -# inner_job = Job(result, arq_redis) -# inner_result = await inner_job.result(poll_delay=0) -# assert inner_result == 42 + job = await arq_redis.enqueue_job('parent_job') + worker: Worker = worker(functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')]) + + await worker.main() + result = await job.result(poll_delay=0) + assert result is not None + inner_job = Job(result, arq_redis) + inner_result = await inner_job.result(poll_delay=0) + assert inner_result == 42 -# async def test_enqueue_job_nested_custom_serializer(arq_redis_msgpack: ArqRedis, worker): -# async def foobar(ctx): -# return 42 - -# async def parent_job(ctx): -# inner_job = await ctx['redis'].enqueue_job('foobar') -# return inner_job.job_id - -# job = await arq_redis_msgpack.enqueue_job('parent_job') - -# worker: Worker = worker( -# functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], -# arq_redis=None, -# job_serializer=msgpack.packb, -# job_deserializer=functools.partial(msgpack.unpackb, raw=False), -# ) - -# await worker.main() -# result = await job.result(poll_delay=0) -# assert result is not None -# inner_job = Job(result, arq_redis_msgpack, _deserializer=functools.partial(msgpack.unpackb, raw=False)) -# inner_result = await inner_job.result(poll_delay=0) -# assert inner_result == 42 - - -# async def test_enqueue_job_custom_queue(arq_redis: ArqRedis, worker): -# async def foobar(ctx): -# return 42 - -# async def parent_job(ctx): -# inner_job = await ctx['redis'].enqueue_job('foobar') -# return inner_job.job_id - -# job = await arq_redis.enqueue_job('parent_job', _queue_name='spanner') - -# worker: Worker = worker( -# functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], -# arq_redis=None, -# queue_name='spanner', -# ) - -# await worker.main() -# inner_job_id = await job.result(poll_delay=0) -# assert inner_job_id is not None -# inner_job = Job(inner_job_id, arq_redis, _queue_name='spanner') -# inner_result = await inner_job.result(poll_delay=0) -# assert inner_result == 42 - - -# async def test_job_error(arq_redis: ArqRedis, 
worker): -# async def foobar(ctx): -# raise RuntimeError('foobar error') - -# j = await arq_redis.enqueue_job('foobar') -# worker: Worker = worker(functions=[func(foobar, name='foobar')]) -# await worker.main() +async def test_enqueue_job_nested_custom_serializer(arq_redis_msgpack: ArqRedis, worker): + async def foobar(ctx): + return 42 + + async def parent_job(ctx): + inner_job = await ctx['redis'].enqueue_job('foobar') + return inner_job.job_id + + job = await arq_redis_msgpack.enqueue_job('parent_job') + + worker: Worker = worker( + functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], + arq_redis=None, + job_serializer=msgpack.packb, + job_deserializer=functools.partial(msgpack.unpackb, raw=False), + ) + + await worker.main() + result = await job.result(poll_delay=0) + assert result is not None + inner_job = Job(result, arq_redis_msgpack, _deserializer=functools.partial(msgpack.unpackb, raw=False)) + inner_result = await inner_job.result(poll_delay=0) + assert inner_result == 42 -# with pytest.raises(RuntimeError, match='foobar error'): -# await j.result(poll_delay=0) - - -# async def test_job_info(arq_redis: ArqRedis): -# t_before = time() -# j = await arq_redis.enqueue_job('foobar', 123, a=456) -# info = await j.info() -# assert info.enqueue_time == IsNow(tz='utc') -# assert info.job_try is None -# assert info.function == 'foobar' -# assert info.args == (123,) -# assert info.kwargs == {'a': 456} -# assert abs(t_before * 1000 - info.score) < 1000 +async def test_enqueue_job_custom_queue(arq_redis: ArqRedis, worker): + async def foobar(ctx): + return 42 + + async def parent_job(ctx): + inner_job = await ctx['redis'].enqueue_job('foobar') + return inner_job.job_id + + job = await arq_redis.enqueue_job('parent_job', _queue_name='spanner') + + worker: Worker = worker( + functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], + arq_redis=None, + queue_name='spanner', + ) + + await worker.main() + inner_job_id = await job.result(poll_delay=0) + assert inner_job_id is not None + inner_job = Job(inner_job_id, arq_redis, _queue_name='spanner') + inner_result = await inner_job.result(poll_delay=0) + assert inner_result == 42 + + +async def test_job_error(arq_redis: ArqRedis, worker): + async def foobar(ctx): + raise RuntimeError('foobar error') + + j = await arq_redis.enqueue_job('foobar') + worker: Worker = worker(functions=[func(foobar, name='foobar')]) + await worker.main() -# async def test_repeat_job(arq_redis: ArqRedis): -# j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id') -# assert isinstance(j1, Job) -# j2 = await arq_redis.enqueue_job('foobar', _job_id='job_id') -# assert j2 is None + with pytest.raises(RuntimeError, match='foobar error'): + await j.result(poll_delay=0) -# async def test_defer_until(arq_redis: ArqRedis): -# j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_until=datetime(2032, 1, 1, tzinfo=timezone.utc)) -# assert isinstance(j1, Job) -# score = await arq_redis.zscore(default_queue_name, 'job_id') -# assert score == 1_956_528_000_000 - - -# async def test_defer_by(arq_redis: ArqRedis): -# j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_by=20) -# assert isinstance(j1, Job) -# score = await arq_redis.zscore(default_queue_name, 'job_id') -# ts = timestamp_ms() -# assert score > ts + 19000 -# assert ts + 21000 > score - - -# async def test_mung(arq_redis: ArqRedis, worker): -# """ -# check a job can't be enqueued multiple times with the same id -# """ -# counter = Counter() - -# 
async def count(ctx, v): -# counter[v] += 1 - -# tasks = [] -# for i in range(50): -# tasks += [ -# arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), -# arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), -# ] -# shuffle(tasks) -# await asyncio.gather(*tasks) - -# worker: Worker = worker(functions=[func(count, name='count')]) -# await worker.main() -# assert counter.most_common(1)[0][1] == 1 # no job go enqueued twice - - -# async def test_custom_try(arq_redis: ArqRedis, worker): -# async def foobar(ctx): -# return ctx['job_try'] - -# j1 = await arq_redis.enqueue_job('foobar') -# w: Worker = worker(functions=[func(foobar, name='foobar')]) -# await w.main() -# r = await j1.result(poll_delay=0) -# assert r == 1 - -# j2 = await arq_redis.enqueue_job('foobar', _job_try=3) -# await w.main() -# r = await j2.result(poll_delay=0) -# assert r == 3 - - -# async def test_custom_try2(arq_redis: ArqRedis, worker): -# async def foobar(ctx): -# if ctx['job_try'] == 3: -# raise Retry() -# return ctx['job_try'] - -# j1 = await arq_redis.enqueue_job('foobar', _job_try=3) -# w: Worker = worker(functions=[func(foobar, name='foobar')]) -# await w.main() -# r = await j1.result(poll_delay=0) -# assert r == 4 - - -# async def test_cant_pickle_arg(arq_redis: ArqRedis): -# class Foobar: -# def __getstate__(self): -# raise TypeError("this doesn't pickle") - -# with pytest.raises(SerializationError, match='unable to serialize job "foobar"'): -# await arq_redis.enqueue_job('foobar', Foobar()) - - -# async def test_cant_pickle_result(arq_redis: ArqRedis, worker): -# class Foobar: -# def __getstate__(self): -# raise TypeError("this doesn't pickle") - -# async def foobar(ctx): -# return Foobar() - -# j1 = await arq_redis.enqueue_job('foobar') -# w: Worker = worker(functions=[func(foobar, name='foobar')]) -# await w.main() -# with pytest.raises(SerializationError, match='unable to serialize result'): -# await j1.result(poll_delay=0) - - -# async def test_get_jobs(arq_redis: ArqRedis): -# await arq_redis.enqueue_job('foobar', a=1, b=2, c=3) -# await asyncio.sleep(0.01) -# await arq_redis.enqueue_job('second', 4, b=5, c=6) -# await asyncio.sleep(0.01) -# await arq_redis.enqueue_job('third', 7, b=8) -# jobs = await arq_redis.queued_jobs() -# assert [dataclasses.asdict(j) for j in jobs] == [ -# { -# 'function': 'foobar', -# 'args': (), -# 'kwargs': {'a': 1, 'b': 2, 'c': 3}, -# 'job_try': None, -# 'enqueue_time': IsNow(tz='utc'), -# 'score': IsInt(), -# }, -# { -# 'function': 'second', -# 'args': (4,), -# 'kwargs': {'b': 5, 'c': 6}, -# 'job_try': None, -# 'enqueue_time': IsNow(tz='utc'), -# 'score': IsInt(), -# }, -# { -# 'function': 'third', -# 'args': (7,), -# 'kwargs': {'b': 8}, -# 'job_try': None, -# 'enqueue_time': IsNow(tz='utc'), -# 'score': IsInt(), -# }, -# ] -# assert jobs[0].score < jobs[1].score < jobs[2].score -# assert isinstance(jobs[0], JobDef) -# assert isinstance(jobs[1], JobDef) -# assert isinstance(jobs[2], JobDef) - - -# async def test_enqueue_multiple(arq_redis: ArqRedis, caplog): -# caplog.set_level(logging.DEBUG) -# results = await asyncio.gather(*[arq_redis.enqueue_job('foobar', i, _job_id='testing') for i in range(10)]) -# assert sum(r is not None for r in results) == 1 -# assert sum(r is None for r in results) == 9 -# assert 'WatchVariableError' not in caplog.text +async def test_job_info(arq_redis: ArqRedis): + t_before = time() + j = await arq_redis.enqueue_job('foobar', 123, a=456) + info = await j.info() + assert info.enqueue_time == IsNow(tz='utc') + assert info.job_try is None + assert 
info.function == 'foobar' + assert info.args == (123,) + assert info.kwargs == {'a': 456} + assert abs(t_before * 1000 - info.score) < 1000 + + +async def test_repeat_job(arq_redis: ArqRedis): + j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id') + assert isinstance(j1, Job) + j2 = await arq_redis.enqueue_job('foobar', _job_id='job_id') + assert j2 is None + + +async def test_defer_until(arq_redis: ArqRedis): + j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_until=datetime(2032, 1, 1, tzinfo=timezone.utc)) + assert isinstance(j1, Job) + score = await arq_redis.zscore(default_queue_name, 'job_id') + assert score == 1_956_528_000_000 + + +async def test_defer_by(arq_redis: ArqRedis): + j1 = await arq_redis.enqueue_job('foobar', _job_id='job_id', _defer_by=20) + assert isinstance(j1, Job) + score = await arq_redis.zscore(default_queue_name, 'job_id') + ts = timestamp_ms() + assert score > ts + 19000 + assert ts + 21000 > score + + +async def test_mung(arq_redis: ArqRedis, worker): + """ + check a job can't be enqueued multiple times with the same id + """ + counter = Counter() + + async def count(ctx, v): + counter[v] += 1 + + tasks = [] + for i in range(50): + tasks += [ + arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), + arq_redis.enqueue_job('count', i, _job_id=f'v-{i}'), + ] + shuffle(tasks) + await asyncio.gather(*tasks) + + worker: Worker = worker(functions=[func(count, name='count')]) + await worker.main() + assert counter.most_common(1)[0][1] == 1 # no job go enqueued twice + + +async def test_custom_try(arq_redis: ArqRedis, worker): + async def foobar(ctx): + return ctx['job_try'] + + j1 = await arq_redis.enqueue_job('foobar') + w: Worker = worker(functions=[func(foobar, name='foobar')]) + await w.main() + r = await j1.result(poll_delay=0) + assert r == 1 + + j2 = await arq_redis.enqueue_job('foobar', _job_try=3) + await w.main() + r = await j2.result(poll_delay=0) + assert r == 3 + + +async def test_custom_try2(arq_redis: ArqRedis, worker): + async def foobar(ctx): + if ctx['job_try'] == 3: + raise Retry() + return ctx['job_try'] + + j1 = await arq_redis.enqueue_job('foobar', _job_try=3) + w: Worker = worker(functions=[func(foobar, name='foobar')]) + await w.main() + r = await j1.result(poll_delay=0) + assert r == 4 + + +async def test_cant_pickle_arg(arq_redis: ArqRedis): + class Foobar: + def __getstate__(self): + raise TypeError("this doesn't pickle") + + with pytest.raises(SerializationError, match='unable to serialize job "foobar"'): + await arq_redis.enqueue_job('foobar', Foobar()) + + +async def test_cant_pickle_result(arq_redis: ArqRedis, worker): + class Foobar: + def __getstate__(self): + raise TypeError("this doesn't pickle") + + async def foobar(ctx): + return Foobar() + + j1 = await arq_redis.enqueue_job('foobar') + w: Worker = worker(functions=[func(foobar, name='foobar')]) + await w.main() + with pytest.raises(SerializationError, match='unable to serialize result'): + await j1.result(poll_delay=0) + + +async def test_get_jobs(arq_redis: ArqRedis): + await arq_redis.enqueue_job('foobar', a=1, b=2, c=3) + await asyncio.sleep(0.01) + await arq_redis.enqueue_job('second', 4, b=5, c=6) + await asyncio.sleep(0.01) + await arq_redis.enqueue_job('third', 7, b=8) + jobs = await arq_redis.queued_jobs() + assert [dataclasses.asdict(j) for j in jobs] == [ + { + 'function': 'foobar', + 'args': (), + 'kwargs': {'a': 1, 'b': 2, 'c': 3}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + { + 'function': 'second', + 
'args': (4,), + 'kwargs': {'b': 5, 'c': 6}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + { + 'function': 'third', + 'args': (7,), + 'kwargs': {'b': 8}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + ] + assert jobs[0].score < jobs[1].score < jobs[2].score + assert isinstance(jobs[0], JobDef) + assert isinstance(jobs[1], JobDef) + assert isinstance(jobs[2], JobDef) + + +async def test_enqueue_multiple(arq_redis: ArqRedis, caplog): + caplog.set_level(logging.DEBUG) + results = await asyncio.gather(*[arq_redis.enqueue_job('foobar', i, _job_id='testing') for i in range(10)]) + assert sum(r is not None for r in results) == 1 + assert sum(r is None for r in results) == 9 + assert 'WatchVariableError' not in caplog.text From d0c89f9447f49510ef3508b9b773869b1436fbc2 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Fri, 10 Nov 2023 12:46:28 -0600 Subject: [PATCH 06/22] main cluster tests --- arq/__init__.py | 3 +- arq/connections.py | 63 +++++----- arq/worker.py | 9 +- tests/conftest.py | 49 ++++++-- tests/test_cluster.py | 285 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 366 insertions(+), 43 deletions(-) create mode 100644 tests/test_cluster.py diff --git a/arq/__init__.py b/arq/__init__.py index d32648cd..4a82b3cb 100644 --- a/arq/__init__.py +++ b/arq/__init__.py @@ -1,4 +1,4 @@ -from .connections import ArqRedis, create_pool +from .connections import ArqRedis, ArqRedisCluster, create_pool from .cron import cron from .version import VERSION from .worker import Retry, Worker, check_health, func, run_worker @@ -7,6 +7,7 @@ __all__ = ( 'ArqRedis', + 'ArqRedisCluster', 'create_pool', 'cron', 'VERSION', diff --git a/arq/connections.py b/arq/connections.py index d9b35b10..8f7fb22f 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -57,12 +57,20 @@ class RedisSettings: def from_dsn(cls, dsn: str) -> 'RedisSettings': conf = urlparse(dsn) assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme' - + query_db = parse_qs(conf.query).get('db') + if query_db: + # e.g. redis://localhost:6379?db=1 + database = int(query_db[0]) + else: + database = int(conf.path.lstrip('/')) if conf.path else 0 return RedisSettings( host=conf.hostname or 'localhost', port=conf.port or 6379, ssl=conf.scheme == 'rediss', + username=conf.username, password=conf.password, + database=database, + unix_socket_path=conf.path if conf.scheme == 'unix' else None, ) def __repr__(self) -> str: @@ -227,7 +235,6 @@ def __init__( queued_jobs = ArqRedis.queued_jobs def pipeline(self, transaction: Any | None = None, shard_hint: Any | None = None) -> ClusterPipeline: - return ArqRedisClusterPipeline(self) @@ -243,25 +250,24 @@ def multi(self) -> None: self.watching = False def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> 'ClusterPipeline': - """ - Append a raw command to the pipeline. 
- - :param args: - | Raw command args - :param kwargs: - - - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` - or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - - Rest of the kwargs are passed to the Redis connection - """ cmd = PipelineCommand(len(self._command_stack), *args, **kwargs) if self.watching: - cmd.result = self._client.execute_command(*cmd.args, **cmd.kwargs) - - return cmd.result + return self.immediate_execute_command(cmd) self._command_stack.append(cmd) return self + async def immediate_execute_command(self, cmd: PipelineCommand): + try: + return await self._client.execute_command(*cmd.args, **cmd.kwargs) + except Exception as e: + cmd.result = e + + def _split_command_across_slots(self, command: str, *keys: KeyT) -> 'ClusterPipeline': + for slot_keys in self._client._partition_keys_by_slot(keys).values(): + if self.watching: + return self.execute_command(command, *slot_keys) + return self + async def create_pool( settings_: RedisSettings = None, @@ -294,7 +300,7 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: ) return client.master_for(settings.sentinel_master, redis_class=ArqRedis) - if settings.cluster_mode: + elif settings.cluster_mode: pool_factory = functools.partial( ArqRedisCluster, host=settings.host, @@ -355,18 +361,17 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: # TODO -async def log_redis_info(redis: 'RedisCluster[bytes]', log_func: Callable[[str], Any]) -> None: - # async with redis.pipeline() as pipe: - # pipe.info(section='Server') - # # type: ignore[unused-coroutine] - # pipe.info(section='Memory') # type: ignore[unused-coroutine] - # pipe.info(section='Clients') # type: ignore[unused-coroutine] - # pipe.dbsize() # type: ignore[unused-coroutine] - # info_server, info_memory, info_clients, key_count = await pipe.execute() - - redis_version = "info_server.get('redis_version', '?')" - mem_usage = "info_memory.get('used_memory_human', '?')" - clients_connected = " info_clients.get('connected_clients', '?')" +async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any]) -> None: + async with redis.pipeline() as pipe: + pipe.info(section='Server') # type: ignore[unused-coroutine] + pipe.info(section='Memory') # type: ignore[unused-coroutine] + pipe.info(section='Clients') # type: ignore[unused-coroutine] + pipe.dbsize() # type: ignore[unused-coroutine] + info_server, info_memory, info_clients, key_count = await pipe.execute() + + redis_version = info_server.get('redis_version', '?') + mem_usage = info_memory.get('used_memory_human', '?') + clients_connected = info_clients.get('connected_clients', '?') log_func( f'redis_version={redis_version} ' diff --git a/arq/worker.py b/arq/worker.py index f8a35862..82232785 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -15,7 +15,7 @@ from arq.cron import CronJob from arq.jobs import Deserializer, JobResult, SerializationError, Serializer, deserialize_job_raw, serialize_result -from .connections import ArqRedis, RedisSettings, create_pool, log_redis_info +from .connections import ArqRedis, ArqRedisCluster, RedisSettings, create_pool, log_redis_info from .constants import ( abort_job_max_age, abort_jobs_ss, @@ -346,7 +346,8 @@ async def main(self) -> None: ) logger.info('Starting worker for %d functions: %s', len(self.functions), ', '.join(self.functions)) - await log_redis_info(self.pool, logger.info) + if not isinstance(self._pool, ArqRedisCluster): + await log_redis_info(self.pool, logger.info) self.ctx['redis'] = self.pool if self.on_startup: 
await self.on_startup(self.ctx) @@ -429,12 +430,10 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: await self.sem.acquire() job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id - async with self.pool.pipeline(transaction=True) as pipe: await pipe.watch(in_progress_key) ongoing_exists = await pipe.exists(in_progress_key) score = await pipe.zscore(self.queue_name, job_id) - if ongoing_exists or not score: # job already started elsewhere, or already finished and removed from queue self.sem.release() @@ -445,9 +444,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: pipe.psetex( # type: ignore[no-untyped-call] in_progress_key, int(self.in_progress_timeout_s * 1000), b'1' ) - try: - await pipe.execute() except (ResponseError, WatchError): # job already started elsewhere since we got 'existing' diff --git a/tests/conftest.py b/tests/conftest.py index c889a1ad..88d44a4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import msgpack import pytest -from arq.connections import ArqRedis, create_pool +from arq.connections import ArqRedis, RedisSettings, create_pool from arq.worker import Worker @@ -19,26 +19,39 @@ def _fix_loop(event_loop): async def arq_redis(loop): redis_ = ArqRedis( host='localhost', - port='6379', + port=6379, + encoding='utf-8', ) + await redis_.flushall() + yield redis_ + await redis_.close(close_connection_pool=True) + + +@pytest.fixture +async def arq_redis_cluster(loop): + settings = RedisSettings(host='localhost', port='6379', conn_timeout=5, cluster_mode=True) + redis_ = await create_pool(settings) + await redis_.flushall() + + yield redis_ await redis_.close() @pytest.fixture async def arq_redis_msgpack(loop): - redis_ = await ArqRedis( + redis_ = ArqRedis( host='localhost', - port='6379', + port=6379, encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), ) - + await redis_.flushall() yield redis_ - await redis_.close() + await redis_.close(close_connection_pool=True) @pytest.fixture @@ -58,6 +71,28 @@ def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_re await worker_.close() +@pytest.fixture +async def cluster_worker(arq_redis_cluster): + worker_: Worker = None + + def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_redis_cluster, **kwargs): + nonlocal worker_ + worker_ = Worker( + functions=functions, + redis_pool=arq_redis_cluster, + burst=burst, + poll_delay=poll_delay, + max_jobs=max_jobs, + **kwargs, + ) + return worker_ + + yield create + + if worker_: + await worker_.close() + + @pytest.fixture(name='create_pool') async def fix_create_pool(loop): pools = [] @@ -69,7 +104,7 @@ async def create_pool_(settings, *args, **kwargs): yield create_pool_ - await asyncio.gather(*[await p.close() for p in pools]) + await asyncio.gather(*[await p.close(close_connection_pool=True) for p in pools]) @pytest.fixture(name='cancel_remaining_task') diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 00000000..96e8d9a1 --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,285 @@ +import asyncio +import dataclasses +import logging +from collections import Counter +from datetime import datetime, timezone +from random import shuffle +from time import time + +import pytest +from dirty_equals import IsInt, IsNow + +from arq import ArqRedisCluster +from arq.constants import default_queue_name +from arq.jobs import Job, JobDef, SerializationError +from arq.utils import 
timestamp_ms +from arq.worker import Retry, Worker, func + + +async def test_enqueue_job(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + j = await arq_redis_cluster.enqueue_job('foobar') + worker: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await worker.main() + r = await j.result(poll_delay=0) + assert r == 42 # 1 + + +async def test_enqueue_job_different_queues(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + j1 = await arq_redis_cluster.enqueue_job('foobar', _queue_name='arq:queue1') + j2 = await arq_redis_cluster.enqueue_job('foobar', _queue_name='arq:queue2') + worker1: Worker = cluster_worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue1') + worker2: Worker = cluster_worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue2') + + await worker1.main() + await worker2.main() + r1 = await j1.result(poll_delay=0) + r2 = await j2.result(poll_delay=0) + assert r1 == 42 # 1 + assert r2 == 42 # 2 + + +async def test_enqueue_job_nested(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + async def parent_job(ctx): + inner_job = await ctx['redis'].enqueue_job('foobar') + return inner_job.job_id + + job = await arq_redis_cluster.enqueue_job('parent_job') + worker: Worker = cluster_worker(functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')]) + + await worker.main() + result = await job.result(poll_delay=0) + assert result is not None + inner_job = Job(result, arq_redis_cluster) + inner_result = await inner_job.result(poll_delay=0) + assert inner_result == 42 + + +# async def test_enqueue_job_nested_custom_serializer(arq_redis_msgpack: ArqRedisCluster, cluster_worker): +# async def foobar(ctx): +# return 42 + +# async def parent_job(ctx): +# inner_job = await ctx['redis'].enqueue_job('foobar') +# return inner_job.job_id + +# job = await arq_redis_msgpack.enqueue_job('parent_job') + +# worker: Worker = cluster_worker( +# functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], +# arq_redis=None, +# job_serializer=msgpack.packb, +# job_deserializer=functools.partial(msgpack.unpackb, raw=False), +# ) + +# await worker.main() +# result = await job.result(poll_delay=0) +# assert result is not None +# inner_job = Job(result, arq_redis_msgpack, _deserializer=functools.partial(msgpack.unpackb, raw=False)) +# inner_result = await inner_job.result(poll_delay=0) +# assert inner_result == 42 + + +async def test_enqueue_job_custom_queue(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + async def parent_job(ctx): + inner_job = await ctx['redis'].enqueue_job('foobar') + return inner_job.job_id + + job = await arq_redis_cluster.enqueue_job('parent_job', _queue_name='spanner') + + worker: Worker = cluster_worker( + functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], + arq_redis=None, + queue_name='spanner', + ) + + await worker.main() + inner_job_id = await job.result(poll_delay=0) + assert inner_job_id is not None + inner_job = Job(inner_job_id, arq_redis_cluster, _queue_name='spanner') + inner_result = await inner_job.result(poll_delay=0) + assert inner_result == 42 + + +async def test_job_error(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + raise RuntimeError('foobar error') + + j = await arq_redis_cluster.enqueue_job('foobar') + worker: Worker = 
cluster_worker(functions=[func(foobar, name='foobar')]) + await worker.main() + + with pytest.raises(RuntimeError, match='foobar error'): + await j.result(poll_delay=0) + + +async def test_job_info(arq_redis_cluster: ArqRedisCluster): + t_before = time() + j = await arq_redis_cluster.enqueue_job('foobar', 123, a=456) + info = await j.info() + assert info.enqueue_time == IsNow(tz='utc') + assert info.job_try is None + assert info.function == 'foobar' + assert info.args == (123,) + assert info.kwargs == {'a': 456} + assert abs(t_before * 1000 - info.score) < 1000 + + +async def test_repeat_job(arq_redis_cluster: ArqRedisCluster): + j1 = await arq_redis_cluster.enqueue_job('foobar', _job_id='job_id') + assert isinstance(j1, Job) + j2 = await arq_redis_cluster.enqueue_job('foobar', _job_id='job_id') + assert j2 is None + + +async def test_defer_until(arq_redis_cluster: ArqRedisCluster): + j1 = await arq_redis_cluster.enqueue_job( + 'foobar', _job_id='job_id', _defer_until=datetime(2032, 1, 1, tzinfo=timezone.utc) + ) + assert type(j1) == Job + assert isinstance(j1, Job) + score = await arq_redis_cluster.zscore(default_queue_name, 'job_id') + assert score == 1_956_528_000_000 + + +async def test_defer_by(arq_redis_cluster: ArqRedisCluster): + j1 = await arq_redis_cluster.enqueue_job('foobar', _job_id='job_id', _defer_by=20) + assert isinstance(j1, Job) + score = await arq_redis_cluster.zscore(default_queue_name, 'job_id') + ts = timestamp_ms() + assert score > ts + 19000 + assert ts + 21000 > score + + +async def test_mung(arq_redis_cluster: ArqRedisCluster, cluster_worker): + """ + check a job can't be enqueued multiple times with the same id + """ + counter = Counter() + + async def count(ctx, v): + counter[v] += 1 + + tasks = [] + for i in range(50): + tasks += [ + arq_redis_cluster.enqueue_job('count', i, _job_id=f'v-{i}'), + arq_redis_cluster.enqueue_job('count', i, _job_id=f'v-{i}'), + ] + shuffle(tasks) + await asyncio.gather(*tasks) + + worker: Worker = cluster_worker(functions=[func(count, name='count')]) + await worker.main() + assert counter.most_common(1)[0][1] == 1 # no job go enqueued twice + + +async def test_custom_try(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return ctx['job_try'] + + j1 = await arq_redis_cluster.enqueue_job('foobar') + w: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await w.main() + r = await j1.result(poll_delay=0) + assert r == 1 + + j2 = await arq_redis_cluster.enqueue_job('foobar', _job_try=3) + await w.main() + r = await j2.result(poll_delay=0) + assert r == 3 + + +async def test_custom_try2(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + if ctx['job_try'] == 3: + raise Retry() + return ctx['job_try'] + + j1 = await arq_redis_cluster.enqueue_job('foobar', _job_try=3) + w: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await w.main() + r = await j1.result(poll_delay=0) + assert r == 4 + + +async def test_cant_pickle_arg(arq_redis_cluster: ArqRedisCluster): + class Foobar: + def __getstate__(self): + raise TypeError("this doesn't pickle") + + with pytest.raises(SerializationError, match='unable to serialize job "foobar"'): + await arq_redis_cluster.enqueue_job('foobar', Foobar()) + + +async def test_cant_pickle_result(arq_redis_cluster: ArqRedisCluster, cluster_worker): + class Foobar: + def __getstate__(self): + raise TypeError("this doesn't pickle") + + async def foobar(ctx): + return Foobar() + + j1 = await 
arq_redis_cluster.enqueue_job('foobar') + w: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await w.main() + with pytest.raises(SerializationError, match='unable to serialize result'): + await j1.result(poll_delay=0) + + +async def test_get_jobs(arq_redis_cluster: ArqRedisCluster): + await arq_redis_cluster.enqueue_job('foobar', a=1, b=2, c=3) + await asyncio.sleep(0.01) + await arq_redis_cluster.enqueue_job('second', 4, b=5, c=6) + await asyncio.sleep(0.01) + await arq_redis_cluster.enqueue_job('third', 7, b=8) + jobs = await arq_redis_cluster.queued_jobs() + assert [dataclasses.asdict(j) for j in jobs] == [ + { + 'function': 'foobar', + 'args': (), + 'kwargs': {'a': 1, 'b': 2, 'c': 3}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + { + 'function': 'second', + 'args': (4,), + 'kwargs': {'b': 5, 'c': 6}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + { + 'function': 'third', + 'args': (7,), + 'kwargs': {'b': 8}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + ] + assert jobs[0].score < jobs[1].score < jobs[2].score + assert isinstance(jobs[0], JobDef) + assert isinstance(jobs[1], JobDef) + assert isinstance(jobs[2], JobDef) + + +async def test_enqueue_multiple(arq_redis_cluster: ArqRedisCluster, caplog): + caplog.set_level(logging.DEBUG) + results = await asyncio.gather(*[arq_redis_cluster.enqueue_job('foobar', i, _job_id='testing') for i in range(10)]) + assert sum(r is not None for r in results) == 1 + assert sum(r is None for r in results) == 9 + assert 'WatchVariableError' not in caplog.text From 0a6b3b04eb8ae23da9d4ea21ae9a02ecc2815aa5 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Tue, 14 Nov 2023 09:06:42 -0600 Subject: [PATCH 07/22] main cluster tests --- arq/connections.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arq/connections.py b/arq/connections.py index 8f7fb22f..7d430e73 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -244,6 +244,7 @@ def __init__(self, client: RedisCluster) -> None: super().__init__(client) async def watch(self, *names: KeyT) -> None: + await self.immediate_execute_command('WATCH', *names) self.watching = True def multi(self) -> None: @@ -256,7 +257,7 @@ def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> 'Clu self._command_stack.append(cmd) return self - async def immediate_execute_command(self, cmd: PipelineCommand): + async def immediate_execute_command(self, cmd: PipelineCommand)-> Any: try: return await self._client.execute_command(*cmd.args, **cmd.kwargs) except Exception as e: From ba4c8b8edd9cb76db62e08b7cd9d18ea826afd80 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Tue, 14 Nov 2023 09:40:36 -0600 Subject: [PATCH 08/22] fixed conftest --- arq/connections.py | 2 +- tests/conftest.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index 7d430e73..db16a7c2 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -257,7 +257,7 @@ def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> 'Clu self._command_stack.append(cmd) return self - async def immediate_execute_command(self, cmd: PipelineCommand)-> Any: + async def immediate_execute_command(self, cmd: PipelineCommand) -> Any: try: return await self._client.execute_command(*cmd.args, **cmd.kwargs) except Exception as e: diff --git a/tests/conftest.py b/tests/conftest.py index 88d44a4c..4ba40918 100644 --- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import msgpack import pytest +from redislite import Redis from arq.connections import ArqRedis, RedisSettings, create_pool from arq.worker import Worker @@ -31,13 +32,10 @@ async def arq_redis(loop): @pytest.fixture -async def arq_redis_cluster(loop): - settings = RedisSettings(host='localhost', port='6379', conn_timeout=5, cluster_mode=True) - redis_ = await create_pool(settings) - await redis_.flushall() - - yield redis_ - await redis_.close() +async def unix_socket_path(loop, tmp_path): + rdb = Redis(str(tmp_path / 'redis_test.db')) + yield rdb.socket_file + rdb.close() @pytest.fixture @@ -54,6 +52,16 @@ async def arq_redis_msgpack(loop): await redis_.close(close_connection_pool=True) +@pytest.fixture +async def arq_redis_cluster(loop): + settings = RedisSettings(host='localhost', port='6379', conn_timeout=5, cluster_mode=True) + redis_ = await create_pool(settings) + await redis_.flushall() + + yield redis_ + await redis_.close() + + @pytest.fixture async def worker(arq_redis): worker_: Worker = None @@ -104,7 +112,7 @@ async def create_pool_(settings, *args, **kwargs): yield create_pool_ - await asyncio.gather(*[await p.close(close_connection_pool=True) for p in pools]) + await asyncio.gather(*[p.close(close_connection_pool=True) for p in pools]) @pytest.fixture(name='cancel_remaining_task') From 88e06ee5bd2a75638c9aa7411053374302b59fe8 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Tue, 14 Nov 2023 09:59:52 -0600 Subject: [PATCH 09/22] removed job changes --- arq/connections.py | 6 +++--- arq/jobs.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index db16a7c2..5d20c689 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -9,7 +9,7 @@ from uuid import uuid4 from redis.asyncio import ConnectionPool, Redis -from redis.asyncio.cluster import ClusterPipeline, PipelineCommand, RedisCluster +from redis.asyncio.cluster import ClusterPipeline, PipelineCommand, RedisCluster # type: ignore from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError from redis.typing import EncodableT, KeyT @@ -213,7 +213,7 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef] return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs]) -class ArqRedisCluster(RedisCluster): +class ArqRedisCluster(RedisCluster): # type: ignore def __init__( self, job_serializer: Optional[Serializer] = None, @@ -238,7 +238,7 @@ def pipeline(self, transaction: Any | None = None, shard_hint: Any | None = None return ArqRedisClusterPipeline(self) -class ArqRedisClusterPipeline(ClusterPipeline): +class ArqRedisClusterPipeline(ClusterPipeline): # type: ignore def __init__(self, client: RedisCluster) -> None: self.watching = False super().__init__(client) diff --git a/arq/jobs.py b/arq/jobs.py index ce715c83..8028cbe7 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -7,13 +7,13 @@ from enum import Enum from typing import Any, Callable, Dict, Optional, Tuple -from redis.asyncio.cluster import RedisCluster +from redis.asyncio import Redis from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix from .utils import ms_to_datetime, poll, timestamp_ms logger = logging.getLogger('arq.jobs') -logging.basicConfig(level=logging.DEBUG) + Serializer = Callable[[Dict[str, Any]], bytes] Deserializer = Callable[[bytes], Dict[str, Any]] @@ -73,7 +73,7 @@ class 
Job: def __init__( self, job_id: str, - redis: 'RedisCluster[bytes]', + redis: 'Redis[bytes]', _queue_name: str = default_queue_name, _deserializer: Optional[Deserializer] = None, ): @@ -109,7 +109,6 @@ async def result( if v: info = deserialize_result(v, deserializer=self._deserializer) - print(info) if info.success: return info.result elif isinstance(info.result, (Exception, asyncio.CancelledError)): From bc546f67c378301c29dea6788a6d8410e9db2553 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Tue, 14 Nov 2023 12:48:26 -0600 Subject: [PATCH 10/22] adds cluster to CI --- .github/workflows/ci.yml | 10 ++++++++++ tests/conftest.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3d79f197..bcfd9018 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,6 +78,16 @@ jobs: steps: - uses: actions/checkout@v2 + - name: Test redis cluster + uses: vishnudxb/redis-cluster@1.0.9 + with: + master1-port: 5000 + master2-port: 5001 + master3-port: 5002 + slave1-port: 5003 + slave2-port: 5004 + slave3-port: 5005 + sleep-duration: 5 - name: set up python uses: actions/setup-python@v4 diff --git a/tests/conftest.py b/tests/conftest.py index 4ba40918..e0200192 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,7 +54,7 @@ async def arq_redis_msgpack(loop): @pytest.fixture async def arq_redis_cluster(loop): - settings = RedisSettings(host='localhost', port='6379', conn_timeout=5, cluster_mode=True) + settings = RedisSettings(host='localhost', port='5000', conn_timeout=5, cluster_mode=True) redis_ = await create_pool(settings) await redis_.flushall() From 204086dc1901db6ca329bf02ad7d1fed8317e2c8 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Thu, 16 Nov 2023 10:00:49 -0600 Subject: [PATCH 11/22] removed watch call --- arq/connections.py | 1 - tests/conftest.py | 9 ++++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index 5d20c689..44a05acb 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -244,7 +244,6 @@ def __init__(self, client: RedisCluster) -> None: super().__init__(client) async def watch(self, *names: KeyT) -> None: - await self.immediate_execute_command('WATCH', *names) self.watching = True def multi(self) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index e0200192..653822bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,7 +54,14 @@ async def arq_redis_msgpack(loop): @pytest.fixture async def arq_redis_cluster(loop): - settings = RedisSettings(host='localhost', port='5000', conn_timeout=5, cluster_mode=True) + settings = RedisSettings( + host='http://clustercfg.test-cluster-ssl.aqtke6.use2.cache.amazonaws.com', + port='6379', + conn_timeout=5, + cluster_mode=True, + ssl=True, + ssl_cert_reqs=None, + ) redis_ = await create_pool(settings) await redis_.flushall() From 751ad4401ef6607da5d0cb51df1d52ce04e78743 Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:14:16 -0500 Subject: [PATCH 12/22] Switching back to localhost 5000 for tests on redis cluster --- tests/conftest.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 653822bf..1539e855 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,12 +55,10 @@ async def arq_redis_msgpack(loop): @pytest.fixture async def arq_redis_cluster(loop): settings = RedisSettings( - host='http://clustercfg.test-cluster-ssl.aqtke6.use2.cache.amazonaws.com', - 
port='6379', - conn_timeout=5, - cluster_mode=True, - ssl=True, - ssl_cert_reqs=None, + host=' localhost', + port='5000', + conn_timeout=60, + cluster_mode=True ) redis_ = await create_pool(settings) await redis_.flushall() From 5bc8a8efd22ca8fc92660661ecd6717f829b83dd Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:18:26 -0500 Subject: [PATCH 13/22] Adding Redis Cluster Health Check in gha --- .github/workflows/ci.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bcfd9018..c3187b4c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -88,6 +88,12 @@ jobs: slave2-port: 5004 slave3-port: 5005 sleep-duration: 5 + - name: Redis Cluster Health Check + run: | + sudo apt-get install -y redis-tools + docker ps -a + redis-cli -h 127.0.0.1 -p 5000 ping + redis-cli -h 127.0.0.1 -p 5000 cluster nodes - name: set up python uses: actions/setup-python@v4 From e49c7e84328a5276fc56b0e08e748860d6d09b5f Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:21:38 -0500 Subject: [PATCH 14/22] Testing removing the single node redis service for tests --- .github/workflows/ci.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c3187b4c..ed3b4dae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,12 +69,12 @@ jobs: runs-on: ${{ format('{0}-latest', matrix.os) }} - services: - redis: - image: redis:${{ matrix.redis }} - ports: - - 6379:6379 - options: --entrypoint redis-server + # services: + # redis: + # image: redis:${{ matrix.redis }} + # ports: + # - 6379:6379 + # options: --entrypoint redis-server steps: - uses: actions/checkout@v2 From 31b2361a30cea3293b0263d0fe4e0574b258a88d Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:23:58 -0500 Subject: [PATCH 15/22] Trying single node on 6380 --- .github/workflows/ci.yml | 12 ++++++------ tests/conftest.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed3b4dae..df80f2b2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,12 +69,12 @@ jobs: runs-on: ${{ format('{0}-latest', matrix.os) }} - # services: - # redis: - # image: redis:${{ matrix.redis }} - # ports: - # - 6379:6379 - # options: --entrypoint redis-server + services: + redis: + image: redis:${{ matrix.redis }} + ports: + - 6380:6379 + options: --entrypoint redis-server steps: - uses: actions/checkout@v2 diff --git a/tests/conftest.py b/tests/conftest.py index 1539e855..03f5e172 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ def _fix_loop(event_loop): async def arq_redis(loop): redis_ = ArqRedis( host='localhost', - port=6379, + port=6380, encoding='utf-8', ) From c3dc1723d99e1f8f892a6797f073630067300bb2 Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:28:27 -0500 Subject: [PATCH 16/22] Ridiculous port testing --- .github/workflows/ci.yml | 2 +- tests/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index df80f2b2..b9da83e7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -73,7 +73,7 @@ jobs: redis: image: redis:${{ matrix.redis }} ports: - - 6380:6379 + - 9000:6379 options: --entrypoint redis-server steps: diff --git a/tests/conftest.py b/tests/conftest.py index 03f5e172..0c21d361 100644 --- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ def _fix_loop(event_loop): async def arq_redis(loop): redis_ = ArqRedis( host='localhost', - port=6380, + port=9000, encoding='utf-8', ) From 37929554a35b18ebf5f98f828f1365c15508c739 Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:31:19 -0500 Subject: [PATCH 17/22] attempting to test everything with a cluster --- .github/workflows/ci.yml | 12 ++++++------ tests/conftest.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b9da83e7..ed3b4dae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,12 +69,12 @@ jobs: runs-on: ${{ format('{0}-latest', matrix.os) }} - services: - redis: - image: redis:${{ matrix.redis }} - ports: - - 9000:6379 - options: --entrypoint redis-server + # services: + # redis: + # image: redis:${{ matrix.redis }} + # ports: + # - 6379:6379 + # options: --entrypoint redis-server steps: - uses: actions/checkout@v2 diff --git a/tests/conftest.py b/tests/conftest.py index 0c21d361..c4e6ba25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ def _fix_loop(event_loop): async def arq_redis(loop): redis_ = ArqRedis( host='localhost', - port=9000, + port=6379, encoding='utf-8', ) @@ -42,7 +42,7 @@ async def unix_socket_path(loop, tmp_path): async def arq_redis_msgpack(loop): redis_ = ArqRedis( host='localhost', - port=6379, + port=5000, encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), From dd8b63a84b8488081686d8749be39090b4e147f1 Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:45:14 -0500 Subject: [PATCH 18/22] Cluster info print out in gha --- tests/conftest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c4e6ba25..7e83845f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ def _fix_loop(event_loop): @pytest.fixture async def arq_redis(loop): redis_ = ArqRedis( - host='localhost', + host='127.0.0.1', port=6379, encoding='utf-8', ) @@ -41,8 +41,8 @@ async def unix_socket_path(loop, tmp_path): @pytest.fixture async def arq_redis_msgpack(loop): redis_ = ArqRedis( - host='localhost', - port=5000, + host='127.0.0.1', + port=6379, encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), @@ -55,7 +55,7 @@ async def arq_redis_msgpack(loop): @pytest.fixture async def arq_redis_cluster(loop): settings = RedisSettings( - host=' localhost', + host=' 127.0.0.1', port='5000', conn_timeout=60, cluster_mode=True From be0eb7591281fa8a59d28a359c8d30a8ce488726 Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:47:40 -0500 Subject: [PATCH 19/22] Actually testing cluster info --- .github/workflows/ci.yml | 1 + tests/conftest.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed3b4dae..6de59b1a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -94,6 +94,7 @@ jobs: docker ps -a redis-cli -h 127.0.0.1 -p 5000 ping redis-cli -h 127.0.0.1 -p 5000 cluster nodes + redis-cli -h 127.0.0.1 -p 5000 cluster info - name: set up python uses: actions/setup-python@v4 diff --git a/tests/conftest.py b/tests/conftest.py index 7e83845f..949b02af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,7 +55,7 @@ async def arq_redis_msgpack(loop): @pytest.fixture async 
def arq_redis_cluster(loop): settings = RedisSettings( - host=' 127.0.0.1', + host='127.0.0.1', port='5000', conn_timeout=60, cluster_mode=True From 482767c36ae6a9b7a0360bcd413d2f026d2856e1 Mon Sep 17 00:00:00 2001 From: Rob Freedy Date: Thu, 16 Nov 2023 11:50:20 -0500 Subject: [PATCH 20/22] Fixing timeout on cluster test --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 949b02af..06bccadc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,7 +57,7 @@ async def arq_redis_cluster(loop): settings = RedisSettings( host='127.0.0.1', port='5000', - conn_timeout=60, + conn_timeout=5, cluster_mode=True ) redis_ = await create_pool(settings) From beffb977e27002795b6b14976164313c9cf4a037 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Thu, 16 Nov 2023 14:41:15 -0600 Subject: [PATCH 21/22] swapped port between redis instances --- .github/workflows/ci.yml | 14 +++++++------- tests/conftest.py | 12 ++++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6de59b1a..0c92cd97 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,12 +69,12 @@ jobs: runs-on: ${{ format('{0}-latest', matrix.os) }} - # services: - # redis: - # image: redis:${{ matrix.redis }} - # ports: - # - 6379:6379 - # options: --entrypoint redis-server + services: + redis: + image: redis:${{ matrix.redis }} + ports: + - 7000:7000 + options: --entrypoint redis-server steps: - uses: actions/checkout@v2 @@ -87,7 +87,7 @@ jobs: slave1-port: 5003 slave2-port: 5004 slave3-port: 5005 - sleep-duration: 5 + sleep-duration: 10 - name: Redis Cluster Health Check run: | sudo apt-get install -y redis-tools diff --git a/tests/conftest.py b/tests/conftest.py index 06bccadc..05913c06 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,8 +19,8 @@ def _fix_loop(event_loop): @pytest.fixture async def arq_redis(loop): redis_ = ArqRedis( - host='127.0.0.1', - port=6379, + host='localhost', + port=7000, encoding='utf-8', ) @@ -41,8 +41,8 @@ async def unix_socket_path(loop, tmp_path): @pytest.fixture async def arq_redis_msgpack(loop): redis_ = ArqRedis( - host='127.0.0.1', - port=6379, + host='localhost', + port=7000, encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), @@ -55,8 +55,8 @@ async def arq_redis_msgpack(loop): @pytest.fixture async def arq_redis_cluster(loop): settings = RedisSettings( - host='127.0.0.1', - port='5000', + host='localhost', + port=6379, conn_timeout=5, cluster_mode=True ) From cb50e42b47fa5bbd6962c724368ca2a8f7080e15 Mon Sep 17 00:00:00 2001 From: jeanluciano Date: Fri, 17 Nov 2023 13:41:06 -0600 Subject: [PATCH 22/22] updated doc strings --- arq/connections.py | 16 +++++++++++++++- tests/conftest.py | 7 +------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index 44a05acb..0d8eb85a 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -214,6 +214,19 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef] class ArqRedisCluster(RedisCluster): # type: ignore + """ + Thin subclass of ``from redis.asyncio.cluster.RedisCluster`` which patches methods of RedisClusterPipeline + to support redis cluster`. + + :param redis_settings: an instance of ``arq.connections.RedisSettings``. 
+ :param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps + :param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads + :param default_queue_name: the default queue name to use, defaults to ``arq.queue``. + :param expires_extra_ms: the default length of time from when a job is expected to start + after which the job expires, defaults to 1 day in ms. + :param kwargs: keyword arguments directly passed to ``from redis.asyncio.cluster.RedisCluster``. + """ + def __init__( self, job_serializer: Optional[Serializer] = None, @@ -281,7 +294,8 @@ async def create_pool( """ Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails. - Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing. + Returns a :class:`arq.connections.ArqRedis` instance or :class: `arq.connections.ArqRedisCluster` depending on + whether `cluster_mode` flag is enabled in `RedisSettings`, thus allowing job enqueuing. """ settings: RedisSettings = RedisSettings() if settings_ is None else settings_ diff --git a/tests/conftest.py b/tests/conftest.py index 05913c06..a0d86e23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,12 +54,7 @@ async def arq_redis_msgpack(loop): @pytest.fixture async def arq_redis_cluster(loop): - settings = RedisSettings( - host='localhost', - port=6379, - conn_timeout=5, - cluster_mode=True - ) + settings = RedisSettings(host='localhost', port=6379, conn_timeout=5, cluster_mode=True) redis_ = await create_pool(settings) await redis_.flushall()
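A minimal usage sketch of the cluster support introduced by the patches above, assuming they apply as shown: RedisSettings(cluster_mode=True) makes create_pool return an ArqRedisCluster, which can enqueue jobs and back a Worker the same way ArqRedis does. The say_hello function and the localhost address are placeholders, not part of the patch series.

    import asyncio

    from arq import Worker, create_pool, func
    from arq.connections import RedisSettings


    async def say_hello(ctx, name):
        # example job function; any async callable taking ctx works
        return f'hello {name}'


    async def main():
        # cluster_mode=True is the flag added in these patches; host/port
        # are placeholders and should point at any node of the cluster
        settings = RedisSettings(host='localhost', port=6379, cluster_mode=True)
        pool = await create_pool(settings)

        job = await pool.enqueue_job('say_hello', 'world')

        worker = Worker(
            functions=[func(say_hello, name='say_hello')],
            redis_pool=pool,
            burst=True,
        )
        await worker.main()
        print(await job.result(poll_delay=0))


    if __name__ == '__main__':
        asyncio.run(main())

The worker side is unchanged from non-cluster usage; the difference is only which client class create_pool builds, so existing job functions and Worker configuration should carry over as-is.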