Cluster mode support #420

Draft: wants to merge 22 commits into base: main
19 changes: 18 additions & 1 deletion .github/workflows/ci.yml
@@ -73,11 +73,28 @@ jobs:
redis:
image: redis:${{ matrix.redis }}
ports:
- 6379:6379
- 7000:7000
options: --entrypoint redis-server

steps:
- uses: actions/checkout@v2
- name: Test redis cluster
Review comment (Member): I think we should add cluster tests separately from the main tests as redis-cluster-test.

uses: vishnudxb/[email protected]
with:
master1-port: 5000
master2-port: 5001
master3-port: 5002
slave1-port: 5003
slave2-port: 5004
slave3-port: 5005
sleep-duration: 10
- name: Redis Cluster Health Check
run: |
sudo apt-get install -y redis-tools
docker ps -a
redis-cli -h 127.0.0.1 -p 5000 ping
redis-cli -h 127.0.0.1 -p 5000 cluster nodes
redis-cli -h 127.0.0.1 -p 5000 cluster info

- name: set up python
uses: actions/setup-python@v4
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
/env*/
/venv*/
/.idea
__pycache__/
*.py[cod]
3 changes: 2 additions & 1 deletion arq/__init__.py
@@ -1,4 +1,4 @@
from .connections import ArqRedis, create_pool
from .connections import ArqRedis, ArqRedisCluster, create_pool
from .cron import cron
from .version import VERSION
from .worker import Retry, Worker, check_health, func, run_worker
@@ -7,6 +7,7 @@

__all__ = (
'ArqRedis',
'ArqRedisCluster',
'create_pool',
'cron',
'VERSION',
114 changes: 103 additions & 11 deletions arq/connections.py
@@ -4,19 +4,25 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar, Union
from urllib.parse import parse_qs, urlparse
from uuid import uuid4

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.cluster import ClusterPipeline, PipelineCommand, RedisCluster # type: ignore
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError
from redis.typing import EncodableT, KeyT

from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
from .utils import timestamp_ms, to_ms, to_unix_ms

logger = logging.getLogger('arq.connections')
logging.basicConfig(level=logging.DEBUG)
Review comment (Member), suggested change: remove the added logging.basicConfig(level=logging.DEBUG) line.

_KeyT = TypeVar('_KeyT', bound=KeyT)


@dataclass
@@ -27,7 +33,7 @@ class RedisSettings:
Used by :func:`arq.connections.create_pool` and :class:`arq.worker.Worker`.
"""

host: Union[str, List[Tuple[str, int]]] = 'localhost'
host: Union[str, List[Tuple[str, int]]] = 'test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com'
Review comment (Member), suggested change: revert the default host to 'localhost' instead of the hard-coded test cluster endpoint.

port: int = 6379
unix_socket_path: Optional[str] = None
database: int = 0
@@ -43,7 +49,7 @@ class RedisSettings:
conn_timeout: int = 1
conn_retries: int = 5
conn_retry_delay: int = 1

cluster_mode: bool = True
sentinel: bool = False
sentinel_master: str = 'mymaster'

@@ -168,7 +174,9 @@ async def enqueue_job(
except WatchError:
# job got enqueued since we checked 'job_exists'
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
the_job = Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
logger.debug(the_job)
return the_job
Review comment (Member) on lines +177 to +179: revert.


async def _get_job_result(self, key: bytes) -> JobResult:
job_id = key[len(result_key_prefix) :].decode()
@@ -205,6 +213,75 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])


class ArqRedisCluster(RedisCluster): # type: ignore
"""
Thin subclass of ``from redis.asyncio.cluster.RedisCluster`` which patches methods of RedisClusterPipeline
to support redis cluster`.
Review comment (Member), suggested change: drop the stray trailing backtick so the line reads "to support redis cluster."


:param redis_settings: an instance of ``arq.connections.RedisSettings``.
:param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
:param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
:param default_queue_name: the default queue name to use, defaults to ``arq.queue``.
:param expires_extra_ms: the default length of time from when a job is expected to start
after which the job expires, defaults to 1 day in ms.
:param kwargs: keyword arguments directly passed to ``from redis.asyncio.cluster.RedisCluster``.
"""

def __init__(
self,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
default_queue_name: str = default_queue_name,
expires_extra_ms: int = expires_extra_ms,
**kwargs: Any,
) -> None:
self.job_serializer = job_serializer
self.job_deserializer = job_deserializer
self.default_queue_name = default_queue_name
self.expires_extra_ms = expires_extra_ms
super().__init__(**kwargs)

enqueue_job = ArqRedis.enqueue_job
_get_job_result = ArqRedis._get_job_result
all_job_results = ArqRedis.all_job_results
_get_job_def = ArqRedis._get_job_def
queued_jobs = ArqRedis.queued_jobs

def pipeline(self, transaction: Any | None = None, shard_hint: Any | None = None) -> ClusterPipeline:
return ArqRedisClusterPipeline(self)


class ArqRedisClusterPipeline(ClusterPipeline): # type: ignore
def __init__(self, client: RedisCluster) -> None:
self.watching = False
super().__init__(client)

async def watch(self, *names: KeyT) -> None:
self.watching = True

def multi(self) -> None:
self.watching = False

def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> 'ClusterPipeline':
cmd = PipelineCommand(len(self._command_stack), *args, **kwargs)
if self.watching:
return self.immediate_execute_command(cmd)
self._command_stack.append(cmd)
return self

async def immediate_execute_command(self, cmd: PipelineCommand) -> Any:
try:
return await self._client.execute_command(*cmd.args, **cmd.kwargs)
except Exception as e:
cmd.result = e

def _split_command_across_slots(self, command: str, *keys: KeyT) -> 'ClusterPipeline':
for slot_keys in self._client._partition_keys_by_slot(keys).values():
if self.watching:
return self.execute_command(command, *slot_keys)
return self
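
Editorial note: the ArqRedisClusterPipeline above emulates WATCH/MULTI rather than providing real optimistic locking, since a cluster-wide WATCH is not available. Below is a minimal sketch of the call pattern it has to support, loosely modelled on ArqRedis.enqueue_job; the enqueue_once helper, key names and expiry value are illustrative assumptions, not part of the PR.

from arq.connections import ArqRedisCluster

async def enqueue_once(pool: ArqRedisCluster, job_id: str, payload: bytes) -> bool:
    job_key = 'arq:job:' + job_id              # mirrors job_key_prefix
    async with pool.pipeline(transaction=True) as pipe:
        await pipe.watch(job_key)              # ArqRedisClusterPipeline only flips self.watching
        if await pipe.exists(job_key):         # executed immediately while "watching"
            return False                       # job already queued
        pipe.multi()                           # switch back to buffering commands
        pipe.psetex(job_key, 86_400_000, payload)
        pipe.zadd('arq:queue', {job_id: 0})
        await pipe.execute()                   # buffered commands are sent here
    return True

Because watch() only sets a flag, the WatchError branch in enqueue_job can no longer fire, so two concurrent enqueues of the same job id could in principle race; that may be worth flagging in review.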


async def create_pool(
settings_: RedisSettings = None,
*,
@@ -217,7 +294,8 @@ async def create_pool(
"""
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.

Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing.
Returns a :class:`arq.connections.ArqRedis` instance or :class: `arq.connections.ArqRedisCluster` depending on
whether `cluster_mode` flag is enabled in `RedisSettings`, thus allowing job enqueuing.
"""
settings: RedisSettings = RedisSettings() if settings_ is None else settings_

@@ -236,9 +314,25 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
)
return client.master_for(settings.sentinel_master, redis_class=ArqRedis)

elif settings.cluster_mode:
pool_factory = functools.partial(
ArqRedisCluster,
host=settings.host,
port=settings.port,
socket_connect_timeout=settings.conn_timeout,
ssl=settings.ssl,
ssl_keyfile=settings.ssl_keyfile,
ssl_certfile=settings.ssl_certfile,
ssl_cert_reqs=settings.ssl_cert_reqs,
ssl_ca_certs=settings.ssl_ca_certs,
ssl_ca_data=settings.ssl_ca_data,
ssl_check_hostname=settings.ssl_check_hostname,
)
else:
pool_factory = functools.partial(
ArqRedis,
db=settings.database,
username=settings.username,
host=settings.host,
port=settings.port,
unix_socket_path=settings.unix_socket_path,
@@ -254,14 +348,11 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:

while True:
try:
pool = pool_factory(
db=settings.database, username=settings.username, password=settings.password, encoding='utf8'
)
pool = await pool_factory(password=settings.password, encoding='utf8')
pool.job_serializer = job_serializer
pool.job_deserializer = job_deserializer
pool.default_queue_name = default_queue_name
pool.expires_extra_ms = expires_extra_ms
await pool.ping()

except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e:
if retry < settings.conn_retries:
@@ -283,8 +374,9 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
return pool


# TODO
async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any]) -> None:
async with redis.pipeline(transaction=False) as pipe:
async with redis.pipeline() as pipe:
pipe.info(section='Server') # type: ignore[unused-coroutine]
pipe.info(section='Memory') # type: ignore[unused-coroutine]
pipe.info(section='Clients') # type: ignore[unused-coroutine]
Expand All @@ -299,5 +391,5 @@ async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any])
f'redis_version={redis_version} '
f'mem_usage={mem_usage} '
f'clients_connected={clients_connected} '
f'db_keys={key_count}'
f'db_keys={88}'
Review comment (Member): revert.

)
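
For reference, a rough end-to-end sketch of how the cluster mode added in this file would be used from application code, assuming the cluster_mode flag and the create_pool branch above land as shown; the host, port and task name are placeholders.

import asyncio

from arq.connections import ArqRedisCluster, RedisSettings, create_pool

async def main() -> None:
    settings = RedisSettings(host='localhost', port=7000, cluster_mode=True)
    pool = await create_pool(settings)
    assert isinstance(pool, ArqRedisCluster)   # cluster_mode=True selects the cluster client
    # 'download_content' is a placeholder task name registered with a worker elsewhere
    job = await pool.enqueue_job('download_content', 'https://example.com')
    print(job.job_id if job else 'a job with this id is already queued')
    await pool.close()

if __name__ == '__main__':
    asyncio.run(main())

The worker side is unchanged apart from the log_redis_info skip in arq/worker.py below, presumably because a single INFO call does not describe the whole cluster.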
13 changes: 8 additions & 5 deletions arq/worker.py
@@ -15,7 +15,7 @@
from arq.cron import CronJob
from arq.jobs import Deserializer, JobResult, SerializationError, Serializer, deserialize_job_raw, serialize_result

from .connections import ArqRedis, RedisSettings, create_pool, log_redis_info
from .connections import ArqRedis, ArqRedisCluster, RedisSettings, create_pool, log_redis_info
from .constants import (
abort_job_max_age,
abort_jobs_ss,
@@ -44,6 +44,7 @@
from .typing import SecondsTimedelta, StartupShutdown, WorkerCoroutine, WorkerSettingsType # noqa F401

logger = logging.getLogger('arq.worker')
logging.basicConfig(level=logging.DEBUG)
no_result = object()


@@ -345,7 +346,8 @@ async def main(self) -> None:
)

logger.info('Starting worker for %d functions: %s', len(self.functions), ', '.join(self.functions))
await log_redis_info(self.pool, logger.info)
if not isinstance(self._pool, ArqRedisCluster):
await log_redis_info(self.pool, logger.info)
self.ctx['redis'] = self.pool
if self.on_startup:
await self.on_startup(self.ctx)
@@ -358,6 +360,7 @@
await asyncio.gather(*self.tasks.values())
return None
queued_jobs = await self.pool.zcard(self.queue_name)

if queued_jobs == 0:
await asyncio.gather(*self.tasks.values())
return None
@@ -434,7 +437,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
if ongoing_exists or not score:
# job already started elsewhere, or already finished and removed from queue
self.sem.release()
logger.debug('job %s already running elsewhere', job_id)
# logger.debug('job %s already running elsewhere', job_id)
continue

pipe.multi()
@@ -843,7 +846,7 @@ async def close(self) -> None:
await self.pool.delete(self.health_check_key)
if self.on_shutdown:
await self.on_shutdown(self.ctx)
await self.pool.close(close_connection_pool=True)
await self.pool.close()
self._pool = None

def __repr__(self) -> str:
@@ -884,7 +887,7 @@ async def async_check_health(
else:
logger.info('Health check successful: %s', data)
r = 0
await redis.close(close_connection_pool=True)
await redis.close()
return r


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -59,7 +59,7 @@ Changelog = 'https://github.com/samuelcolvin/arq/releases'
testpaths = 'tests'
filterwarnings = ['error']
asyncio_mode = 'auto'
timeout = 10

Review comment (Member): revert.


[tool.coverage.run]
source = ['arq']
Expand Down