allow unix-socket connection in RedisSettings (socket_address) #271

Closed
wants to merge 10 commits
3 changes: 3 additions & 0 deletions .codecov.yml
@@ -1,6 +1,9 @@
coverage:
precision: 2
range: [95, 100]
status:
patch: false
project: false

comment:
layout: 'header, diff, flags, files, footer'
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -49,7 +49,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu]
python-version: ['3.6', '3.7', '3.8', '3.9']
python-version: ['3.7', '3.8', '3.9', '3.10']

env:
PYTHON: ${{ matrix.python-version }}
1 change: 1 addition & 0 deletions .gitignore
@@ -17,3 +17,4 @@ __pycache__/
.vscode/
.venv/
/.auto-format
/scratch/
26 changes: 26 additions & 0 deletions HISTORY.rst
@@ -3,6 +3,32 @@
History
-------

v0.23a1 (2022-03-09)
....................
* Fix jobs timeout by @kiriusm2 in #248
* Update ``index.rst`` by @Kludex in #266
* Improve some docs wording by @johtso in #285
* fix error when cron jobs were terminated by @tobymao in #273
* add ``on_job_start`` and ``on_job_end`` hooks by @tobymao in #274
* Update argument docstring definition by @sondrelg in #278
* fix tests and uprev test dependencies, #288
* Add link to WorkerSettings in documentation by @JonasKs in #279
* Allow setting ``job_id`` on cron jobs by @JonasKs in #293
* Fix docs typo by @johtso in #296
* support aioredis v2 by @Yolley in #259
* support python 3.10, #298

v0.22 (2021-09-02)
..................
* fix package importing in example, #261, thanks @cdpath
* restrict ``aioredis`` to ``<2.0.0`` (soon we'll support ``aioredis>=2.0.0``), #258, thanks @PaxPrz
* auto setting version on release, 759fe03

v0.21 (2021-07-06)
..................
* CI improvements #243
* fix ``log_redis_info`` #255

v0.20 (2021-04-26)
..................

4 changes: 2 additions & 2 deletions Makefile
@@ -1,6 +1,6 @@
.DEFAULT_GOAL := all
isort = isort arq tests
black = black -S -l 120 --target-version py37 arq tests
black = black arq tests

.PHONY: install
install:
@@ -15,7 +15,7 @@ format:

.PHONY: lint
lint:
flake8 arq/ tests/
flake8 --max-complexity 10 --max-line-length 120 --ignore E203,W503 arq/ tests/
$(isort) --check-only --df
$(black) --check

2 changes: 1 addition & 1 deletion arq/cli.py
@@ -43,7 +43,7 @@ def cli(*, worker_settings: str, burst: bool, check: bool, watch: str, verbose:
else:
kwargs = {} if burst is None else {'burst': burst}
if watch:
asyncio.get_event_loop().run_until_complete(watch_reload(watch, worker_settings_))
asyncio.run(watch_reload(watch, worker_settings_))
else:
run_worker(worker_settings_, **kwargs)

123 changes: 65 additions & 58 deletions arq/connections.py
@@ -6,12 +6,13 @@
from datetime import datetime, timedelta
from operator import attrgetter
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from urllib.parse import urlparse
from urllib.parse import parse_qs, urlparse
from uuid import uuid4

import aioredis
from aioredis import MultiExecError, Redis
from pydantic.validators import make_arbitrary_type_validator
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError

from .constants import default_queue_name, job_key_prefix, result_key_prefix
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
@@ -40,6 +41,7 @@ class RedisSettings:

host: Union[str, List[Tuple[str, int]]] = 'localhost'
port: int = 6379
unix_socket_path: Optional[str] = None
database: int = 0
password: Optional[str] = None
ssl: Union[bool, None, SSLContext] = None
@@ -53,13 +55,19 @@
@classmethod
def from_dsn(cls, dsn: str) -> 'RedisSettings':
conf = urlparse(dsn)
assert conf.scheme in {'redis', 'rediss'}, 'invalid DSN scheme'
assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme'
if conf.query and 'db' in parse_qs(conf.query):
# e.g. redis://localhost:6379?db=1
database = int(parse_qs(conf.query)['db'][0])
else:
database = int(conf.path.lstrip('/')) if conf.path else 0
return RedisSettings(
host=conf.hostname or 'localhost',
port=conf.port or 6379,
ssl=conf.scheme == 'rediss',
password=conf.password,
database=int((conf.path or '0').strip('/')),
database=database,
unix_socket_path=conf.path if conf.scheme == 'unix' else None,
)

def __repr__(self) -> str:
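For reference, a quick sketch of how the extended `from_dsn` behaves with the new `unix` scheme and the `?db=` query parameter (the socket path and db numbers below are illustrative):

```python
from arq.connections import RedisSettings

# db can now come from the query string rather than the path
tcp = RedisSettings.from_dsn('redis://localhost:6379?db=1')
assert tcp.database == 1

# the unix scheme maps the DSN path onto the new unix_socket_path field
# (the socket path here is made up for illustration)
sock = RedisSettings.from_dsn('unix:///run/redis/redis.sock?db=2')
assert sock.unix_socket_path == '/run/redis/redis.sock'
assert sock.database == 2
```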
@@ -70,20 +78,20 @@ def __repr__(self) -> str:
expires_extra_ms = 86_400_000


class ArqRedis(Redis): # type: ignore
class ArqRedis(Redis): # type: ignore[misc]
"""
Thin subclass of ``aioredis.Redis`` which adds :func:`arq.connections.enqueue_job`.
Thin subclass of ``redis.asyncio.Redis`` which adds :func:`arq.connections.enqueue_job`.

:param redis_settings: an instance of ``arq.connections.RedisSettings``.
:param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
:param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
:param default_queue_name: the default queue name to use, defaults to ``arq.queue``.
:param kwargs: keyword arguments directly passed to ``aioredis.Redis``.
:param kwargs: keyword arguments directly passed to ``redis.asyncio.Redis``.
"""

def __init__(
self,
pool_or_conn: Any,
pool_or_conn: Optional[ConnectionPool] = None,
job_serializer: Optional[Serializer] = None,
job_deserializer: Optional[Deserializer] = None,
default_queue_name: str = default_queue_name,
@@ -92,7 +100,9 @@ def __init__(
self.job_serializer = job_serializer
self.job_deserializer = job_deserializer
self.default_queue_name = default_queue_name
super().__init__(pool_or_conn, **kwargs)
if pool_or_conn:
kwargs['connection_pool'] = pool_or_conn
super().__init__(**kwargs)

async def enqueue_job(
self,
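With this constructor an existing redis-py connection pool can still be passed positionally; it ends up as `connection_pool`, and any other keyword arguments go straight to `redis.asyncio.Redis`. A minimal sketch (the URL is an assumption):

```python
from redis.asyncio import ConnectionPool
from arq.connections import ArqRedis

# build a pool explicitly and hand it to ArqRedis; the URL below is illustrative
pool = ConnectionPool.from_url('redis://localhost:6379/0')
redis = ArqRedis(pool, default_queue_name='arq:queue')
```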
@@ -129,14 +139,10 @@ async def enqueue_job(
defer_by_ms = to_ms(_defer_by)
expires_ms = to_ms(_expires)

with await self as conn:
pipe = conn.pipeline()
pipe.unwatch()
pipe.watch(job_key)
job_exists = pipe.exists(job_key)
job_result_exists = pipe.exists(result_key_prefix + job_id)
await pipe.execute()
if await job_exists or await job_result_exists:
async with self.pipeline(transaction=True) as pipe:
await pipe.watch(job_key)
if any(await asyncio.gather(pipe.exists(job_key), pipe.exists(result_key_prefix + job_id))):
await pipe.reset()
return None

enqueue_time_ms = timestamp_ms()
@@ -150,24 +156,22 @@
expires_ms = expires_ms or score - enqueue_time_ms + expires_extra_ms

job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
tr = conn.multi_exec()
tr.psetex(job_key, expires_ms, job)
tr.zadd(_queue_name, score, job_id)
pipe.multi()
pipe.psetex(job_key, expires_ms, job)
pipe.zadd(_queue_name, {job_id: score})
try:
await tr.execute()
except MultiExecError:
await pipe.execute()
except WatchError:
# job got enqueued since we checked 'job_exists'
# https://github.com/samuelcolvin/arq/issues/131, avoid warnings in log
await asyncio.gather(*tr._results, return_exceptions=True)
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)

async def _get_job_result(self, key: str) -> JobResult:
job_id = key[len(result_key_prefix) :]
async def _get_job_result(self, key: bytes) -> JobResult:
job_id = key[len(result_key_prefix) :].decode()
job = Job(job_id, self, _deserializer=self.job_deserializer)
r = await job.result_info()
if r is None:
raise KeyError(f'job "{key}" not found')
raise KeyError(f'job "{key.decode()}" not found')
r.job_id = job_id
return r

@@ -179,8 +183,8 @@ async def all_job_results(self) -> List[JobResult]:
results = await asyncio.gather(*[self._get_job_result(k) for k in keys])
return sorted(results, key=attrgetter('enqueue_time'))

async def _get_job_def(self, job_id: str, score: int) -> JobDef:
v = await self.get(job_key_prefix + job_id, encoding=None)
async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
v = await self.get(job_key_prefix + job_id.decode())
jd = deserialize_job(v, deserializer=self.job_deserializer)
jd.score = score
return jd
@@ -189,8 +193,8 @@ async def queued_jobs(self, *, queue_name: str = default_queue_name) -> List[Job
"""
Get information about queued jobs, mostly useful when testing.
"""
jobs = await self.zrange(queue_name, withscores=True)
return await asyncio.gather(*[self._get_job_def(job_id, score) for job_id, score in jobs])
jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])


async def create_pool(
@@ -204,8 +208,7 @@
"""
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.

Similar to ``aioredis.create_redis_pool`` except it returns a :class:`arq.connections.ArqRedis` instance,
thus allowing job enqueuing.
Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing.
"""
settings: RedisSettings = RedisSettings() if settings_ is None else settings_

), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"

if settings.sentinel:
addr: Any = settings.host

async def pool_factory(*args: Any, **kwargs: Any) -> Redis:
client = await aioredis.sentinel.create_sentinel_pool(*args, ssl=settings.ssl, **kwargs)
return client.master_for(settings.sentinel_master)
def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
client = Sentinel(*args, sentinels=settings.host, ssl=settings.ssl, **kwargs)
return client.master_for(settings.sentinel_master, redis_class=ArqRedis)

else:
pool_factory = functools.partial(
aioredis.create_pool, create_connection_timeout=settings.conn_timeout, ssl=settings.ssl
ArqRedis,
host=settings.host,
port=settings.port,
unix_socket_path=settings.unix_socket_path,
socket_connect_timeout=settings.conn_timeout,
ssl=settings.ssl,
)
addr = settings.host, settings.port

try:
pool = await pool_factory(addr, db=settings.database, password=settings.password, encoding='utf8')
pool = ArqRedis(
pool,
job_serializer=job_serializer,
job_deserializer=job_deserializer,
default_queue_name=default_queue_name,
)
pool = pool_factory(db=settings.database, password=settings.password, encoding='utf8')
pool.job_serializer = job_serializer
pool.job_deserializer = job_deserializer
pool.default_queue_name = default_queue_name
await pool.ping()

except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e:
if retry < settings.conn_retries:
logger.warning(
'redis connection error %s %s %s, %d retries remaining...',
addr,
'redis connection error %s:%s %s %s, %d retries remaining...',
settings.host,
settings.port,
e.__class__.__name__,
e,
settings.conn_retries - retry,
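Assuming a Redis server listening on a unix socket (the path below is an example), `create_pool` now forwards `unix_socket_path` to redis-py, so enqueuing over the socket looks like:

```python
import asyncio
from arq.connections import RedisSettings, create_pool

async def main() -> None:
    # the socket path and task name are illustrative
    settings = RedisSettings(unix_socket_path='/run/redis/redis.sock', database=0)
    redis = await create_pool(settings)
    await redis.enqueue_job('download_content', 'https://example.com')

asyncio.run(main())
```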


async def log_redis_info(redis: Redis, log_func: Callable[[str], Any]) -> None:
with await redis as r:
info_server, info_memory, info_clients, key_count = await asyncio.gather(
r.info(section='Server'), r.info(section='Memory'), r.info(section='Clients'), r.dbsize(),
)

redis_version = info_server.get('server', {}).get('redis_version', '?')
mem_usage = info_memory.get('memory', {}).get('used_memory_human', '?')
clients_connected = info_clients.get('clients', {}).get('connected_clients', '?')
async with redis.pipeline(transaction=True) as pipe:
pipe.info(section='Server')
pipe.info(section='Memory')
pipe.info(section='Clients')
pipe.dbsize()
info_server, info_memory, info_clients, key_count = await pipe.execute()

redis_version = info_server.get('redis_version', '?')
mem_usage = info_memory.get('used_memory_human', '?')
clients_connected = info_clients.get('connected_clients', '?')

log_func(
f'redis_version={redis_version} '
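Worth noting for reviewers: redis-py's `info()` returns a flat dict per section, unlike aioredis v1 which nested results under a section key, hence the dropped `['server']`/`['memory']`/`['clients']` lookups above. Roughly (connection details are illustrative):

```python
import asyncio
from redis.asyncio import Redis

async def show_version() -> None:
    redis = Redis(host='localhost', port=6379)
    info = await redis.info(section='Server')
    # flat dict: no nested 'server' key as with aioredis v1
    print(info.get('redis_version', '?'))
    await redis.close()

asyncio.run(show_version())
```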
6 changes: 5 additions & 1 deletion arq/cron.py
@@ -102,6 +102,7 @@ class CronJob:
microsecond: int
run_at_startup: bool
unique: bool
job_id: Optional[str]
timeout_s: Optional[float]
keep_result_s: Optional[float]
keep_result_forever: Optional[bool]
@@ -137,6 +138,7 @@ def cron(
microsecond: int = 123_456,
run_at_startup: bool = False,
unique: bool = True,
job_id: Optional[str] = None,
timeout: Optional[SecondsTimedelta] = None,
keep_result: Optional[float] = 0,
keep_result_forever: Optional[bool] = False,
:param microsecond: microsecond(s) to run the job on,
defaults to 123456 as the world is busier at the top of a second, 0 - 1e6
:param run_at_startup: whether to run as worker starts
:param unique: whether the job should be only be executed once at each time
:param unique: whether the job should only be executed once at each time (useful if you have multiple workers)
:param job_id: ID of the job, can be used to enforce job uniqueness, spanning multiple cron schedules
:param timeout: job timeout
:param keep_result: how long to keep the result for
:param keep_result_forever: whether to keep results forever
@@ -188,6 +191,7 @@
microsecond,
run_at_startup,
unique,
job_id,
timeout,
keep_result,
keep_result_forever,
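An illustrative use of the new `job_id` option (the function and schedule are made up); per the docstring above, sharing one id across several cron schedules enforces uniqueness between them:

```python
from arq import cron

async def nightly_cleanup(ctx):
    ...

class WorkerSettings:
    cron_jobs = [
        # job_id is the new option; the name below is an example
        cron(nightly_cleanup, hour=3, minute=0, job_id='nightly-cleanup'),
    ]
```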