diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3b7beab..8666584 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -5,7 +5,7 @@ on: jobs: lint: if: github.event_name == 'push' && !startsWith(github.event.ref, 'refs/tags') - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: @@ -41,7 +41,7 @@ jobs: - '3.9' - '3.10' os: - - ubuntu-latest + - ubuntu-22.04 services: redis-node-one: image: redis @@ -69,6 +69,18 @@ jobs: MONGODB_REPLICA_SET_MODE: primary MONGODB_REPLICA_SET_NAME: test_replica_set ALLOW_EMPTY_PASSWORD: 'yes' + postgres-node-one: + image: 'postgres:latest' + ports: + - '5432:5432' + env: + POSTGRES_PASSWORD: 'mysecretpassword' + postgres-node-two: + image: 'postgres:latest' + ports: + - '5433:5432' + env: + POSTGRES_PASSWORD: 'mysecretpassword' steps: - name: Checkout uses: actions/checkout@v3 diff --git a/pyproject.toml b/pyproject.toml index 5680751..0b8c29e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "sergeant" -version = "0.25.0" +version = "0.26.0" readme = "README.md" homepage = "https://github.com/Intsights/sergeant" repository = "https://github.com/Intsights/sergeant" @@ -39,6 +39,7 @@ psutil = "^5" pymongo = ">=3.0,<5.0" redis = "^4" typing_extensions = "^3.10" +psycopg = {version = "^3", extras = ["binary"]} [tool.poetry.dev-dependencies] pytest = "^7" diff --git a/sergeant/config.py b/sergeant/config.py index 2bf31c8..bc2e186 100644 --- a/sergeant/config.py +++ b/sergeant/config.py @@ -46,7 +46,7 @@ class Encoder: ) class Connector: params: typing.Dict[str, typing.Any] - type: typing_extensions.Literal['redis', 'mongo', 'local'] = 'local' + type: typing_extensions.Literal['redis', 'mongo', 'local', 'postgres'] = 'local' @dataclasses.dataclass( diff --git a/sergeant/connector/__init__.py b/sergeant/connector/__init__.py index 77ece41..9ef0a11 100644 --- 
a/sergeant/connector/__init__.py +++ b/sergeant/connector/__init__.py @@ -1,6 +1,7 @@ from . import _connector from . import local from . import mongo +from . import postgres from . import redis diff --git a/sergeant/connector/local.py b/sergeant/connector/local.py index 6181ec3..e49c73a 100644 --- a/sergeant/connector/local.py +++ b/sergeant/connector/local.py @@ -35,7 +35,7 @@ def acquire( expire_at = time.time() + ttl self.connection.execute( ''' - INSERT INTO locks (name, expireAt) + INSERT INTO locks (name, expire_at) VALUES(?, ?); ''', ( @@ -55,16 +55,16 @@ def acquire( def release( self, ) -> bool: - self.connection.execute( - ''' - DELETE FROM locks WHERE expireAt < ?; - ''', - ( - time.time(), - ), - ) - if self.acquired: + self.connection.execute( + ''' + DELETE FROM locks WHERE expire_at < ?; + ''', + ( + time.time(), + ), + ) + cursor = self.connection.execute( ''' DELETE FROM locks WHERE name = ?; @@ -85,7 +85,7 @@ def is_locked( cursor = self.connection.execute( ''' SELECT * FROM locks - WHERE name = ? AND expireAt > ?; + WHERE name = ? AND expire_at > ?; ''', ( self.name, @@ -107,8 +107,8 @@ def set_ttl( cursor = self.connection.execute( ''' UPDATE locks - SET expireAt = ? - WHERE name = ? AND expireAt > ?; + SET expire_at = ? + WHERE name = ? 
AND expire_at > ?; ''', ( now + ttl, @@ -125,7 +125,7 @@ def get_ttl( now = time.time() cursor = self.connection.execute( ''' - SELECT expireAt FROM locks + SELECT expire_at FROM locks WHERE name = ?; ''', ( @@ -173,9 +173,9 @@ def __init__( CREATE TABLE IF NOT EXISTS keys (name TEXT, value BLOB); CREATE UNIQUE INDEX IF NOT EXISTS key_by_name ON keys (name); - CREATE TABLE IF NOT EXISTS locks (name TEXT, expireAt REAL); + CREATE TABLE IF NOT EXISTS locks (name TEXT, expire_at REAL); CREATE UNIQUE INDEX IF NOT EXISTS lock_by_name ON locks (name); - CREATE INDEX IF NOT EXISTS lock_by_expireAt ON locks (expireAt); + CREATE INDEX IF NOT EXISTS lock_by_expire_at ON locks (expire_at); ''' ) diff --git a/sergeant/connector/mongo.py b/sergeant/connector/mongo.py index 2862ffc..8842ac3 100644 --- a/sergeant/connector/mongo.py +++ b/sergeant/connector/mongo.py @@ -412,15 +412,15 @@ def queue_length( ) -> int: queue_length = 0 - for i in range(self.number_of_connections): + for connection in self.connections: if include_delayed: - queue_length += self.next_connection.sergeant.task_queue.count_documents( + queue_length += connection.sergeant.task_queue.count_documents( filter={ 'queue_name': queue_name, }, ) else: - queue_length += self.next_connection.sergeant.task_queue.count_documents( + queue_length += connection.sergeant.task_queue.count_documents( filter={ 'queue_name': queue_name, 'priority': { diff --git a/sergeant/connector/postgres.py b/sergeant/connector/postgres.py new file mode 100644 index 0000000..70e5d1a --- /dev/null +++ b/sergeant/connector/postgres.py @@ -0,0 +1,496 @@ +import binascii +import math +import psycopg +import psycopg.rows +import random +import time +import typing + +from . 
import _connector


class Lock(
    _connector.Lock,
):
    """Named lock with a time-to-live, stored as a row of the ``locks`` table.

    ``expire_at`` holds a Unix timestamp. ``is_locked``, ``set_ttl`` and
    ``get_ttl`` all treat a row whose ``expire_at`` is in the past as absent.

    NOTE(review): rows carry no owner token, so ``release`` removes the row
    by name regardless of which ``Lock`` object inserted it.
    """

    def __init__(
        self,
        connection: psycopg.Connection,
        name: str,
    ) -> None:
        self.connection = connection
        self.name = name

        # True only while this object believes it holds the lock.
        self.acquired = False

    def acquire(
        self,
        timeout: typing.Optional[float] = None,
        check_interval: float = 1.0,
        ttl: int = 60,
    ) -> bool:
        """Try to take the lock, polling until it succeeds or ``timeout`` passes.

        Args:
            timeout: Seconds to keep trying; ``None`` means retry forever.
            check_interval: Seconds to sleep between attempts.
            ttl: Lifetime, in seconds, of the lock row once inserted.

        Returns:
            True when the lock row was inserted, False on timeout.
        """
        if timeout is not None:
            time_to_stop = time.time() + timeout
        else:
            time_to_stop = None

        while True:
            try:
                with self.connection.cursor() as cursor:
                    cursor.execute(
                        query='''
                            INSERT INTO locks (name, expire_at)
                            VALUES(%s, %s);
                        ''',
                        params=(
                            self.name,
                            time.time() + ttl,
                        ),
                    )
            except psycopg.errors.UniqueViolation:
                # The unique index on ``name`` rejected the insert: a row exists.
                pass
            else:
                self.acquired = True

                return True

            # The existing row may belong to a holder whose TTL already ran out.
            # The other methods treat such a row as absent, so reclaim it here
            # instead of waiting for some unrelated ``release`` to purge it.
            with self.connection.cursor() as cursor:
                cursor.execute(
                    query='''
                        DELETE FROM locks
                        WHERE name = %s AND expire_at <= %s;
                    ''',
                    params=(
                        self.name,
                        time.time(),
                    ),
                )
                reclaimed = cursor.rowcount == 1

            if reclaimed:
                # Retry the insert right away; another client may still win it.
                continue

            if time_to_stop is not None and time.time() > time_to_stop:
                return False

            time.sleep(check_interval)

    def release(
        self,
    ) -> bool:
        """Release the lock if this object acquired it.

        Also purges every expired row of the ``locks`` table on this server.

        Returns:
            True when exactly one row with this name was deleted, False when
            the lock was not acquired by this object or the row was gone.
        """
        if not self.acquired:
            return False

        with self.connection.cursor() as cursor:
            cursor.execute(
                query='''
                    DELETE FROM locks WHERE expire_at < %s;
                ''',
                params=(
                    time.time(),
                ),
            )

            cursor.execute(
                query='''
                    DELETE FROM locks WHERE name = %s;
                ''',
                params=(
                    self.name,
                ),
            )
            self.acquired = False

            return cursor.rowcount == 1

    def is_locked(
        self,
    ) -> bool:
        """Return True when an unexpired row with this name exists."""
        with self.connection.cursor() as cursor:
            cursor.execute(
                query='''
                    SELECT * FROM locks
                    WHERE name = %s AND expire_at > %s;
                ''',
                params=(
                    self.name,
                    time.time(),
                ),
            )

            return cursor.fetchone() is not None

    def set_ttl(
        self,
        ttl: int,
    ) -> bool:
        """Move the expiry of an unexpired lock to ``ttl`` seconds from now.

        Returns:
            True when an unexpired row with this name was updated.
        """
        now = time.time()

        with self.connection.cursor() as cursor:
            cursor.execute(
                query='''
                    UPDATE locks
                    SET expire_at = %s
                    WHERE name = %s AND expire_at > %s;
                ''',
                params=(
                    now + ttl,
                    self.name,
                    now,
                ),
            )

            return cursor.rowcount == 1

    def get_ttl(
        self,
    ) -> typing.Optional[int]:
        """Return the remaining lifetime in whole seconds, rounded up.

        Returns:
            None when no row exists or the row has already expired.
        """
        now = time.time()

        with self.connection.cursor() as cursor:
            cursor.execute(
                query='''
                    SELECT expire_at FROM locks
                    WHERE name = %s;
                ''',
                params=(
                    self.name,
                ),
            )

            result = cursor.fetchone()

        if result is None:
            return None

        # The column is ``decimal``; convert before doing float arithmetic.
        expire_at: float = float(result[0])
        if expire_at <= now:
            return None

        return math.ceil(expire_at - now)

    def __del__(
        self,
    ) -> None:
        self.release()


class Connector(
    _connector.Connector,
):
    """Connector that stores keys, locks and task queues on PostgreSQL servers.

    Keys and locks are pinned to one server chosen by CRC32 of their name.
    Queue pushes and pops rotate over the servers round-robin.
    """

    def __init__(
        self,
        connection_strings: typing.List[str],
    ) -> None:
        """Connect to every server and create the schema where missing.

        Args:
            connection_strings: One libpq connection string per server.

        Raises:
            ValueError: When ``connection_strings`` is empty.
        """
        # Assigned first so that ``__del__`` works even if a step below raises.
        self.connections: typing.List[psycopg.Connection] = []

        if not connection_strings:
            raise ValueError('connection_strings must contain at least one connection string')

        for connection_string in connection_strings:
            connection = psycopg.connect(
                conninfo=connection_string,
                autocommit=True,
            )
            with connection.cursor() as cursor:
                # NOTE(review): the tables below are created through this same
                # connection, i.e. in the database the connection string
                # selects, not in the ``sergeant`` database created here.
                # Confirm that this is intended.
                try:
                    cursor.execute(
                        query='''
                            CREATE DATABASE sergeant WITH ENCODING 'UTF8';
                        '''
                    )
                except psycopg.errors.DuplicateDatabase:
                    pass

                cursor.execute(
                    query='''
                        CREATE TABLE IF NOT EXISTS task_queue (id bigserial, queue_name text, priority decimal, value bytea);
                        CREATE INDEX IF NOT EXISTS queue_by_priority ON task_queue (queue_name, priority);
                        CREATE INDEX IF NOT EXISTS id ON task_queue (id);

                        CREATE TABLE IF NOT EXISTS keys (name text, value bytea);
                        CREATE UNIQUE INDEX IF NOT EXISTS key_by_name ON keys (name);

                        CREATE TABLE IF NOT EXISTS locks (name text, expire_at decimal);
                        CREATE UNIQUE INDEX IF NOT EXISTS lock_by_name ON locks (name);
                        CREATE INDEX IF NOT EXISTS lock_by_expire_at ON locks (expire_at);
                    '''
                )

            self.connections.append(connection)

        self.number_of_connections = len(self.connections)
        # Start the rotation at a random server.
        self.current_connection_index = random.randint(0, self.number_of_connections - 1)

    @property
    def next_connection(
        self,
    ) -> psycopg.Connection:
        """Return the current connection and advance the round-robin index."""
        current_connection = self.connections[self.current_connection_index]
        self.current_connection_index = (self.current_connection_index + 1) % self.number_of_connections

        return current_connection

    @staticmethod
    def _priority_value(
        priority: str,
        consumable_from: typing.Optional[float],
    ) -> float:
        """Map push arguments to the number stored in the ``priority`` column.

        ``consumable_from`` (a Unix timestamp) wins when given; otherwise
        'HIGH' maps to 0.0 and every other priority maps to 1.0. Pops only
        take rows whose priority is not greater than the current time.
        """
        if consumable_from is not None:
            return consumable_from

        if priority == 'HIGH':
            return 0.0

        return 1.0

    def key_set(
        self,
        key: str,
        value: bytes,
    ) -> bool:
        """Store ``value`` under ``key`` unless the key already exists.

        Returns:
            True when the key was inserted, False when it already existed.
        """
        key_server_location = binascii.crc32(key.encode()) % self.number_of_connections
        key_server_connection = self.connections[key_server_location]

        with key_server_connection.cursor() as cursor:
            try:
                cursor.execute(
                    query='''
                        INSERT INTO keys (name, value)
                        VALUES(%s, %s);
                    ''',
                    params=(
                        key,
                        value,
                    ),
                )

                return True
            except psycopg.errors.UniqueViolation:
                return False

    def key_get(
        self,
        key: str,
    ) -> typing.Optional[bytes]:
        """Return the value stored under ``key``, or None when it is absent."""
        key_server_location = binascii.crc32(key.encode()) % self.number_of_connections
        key_server_connection = self.connections[key_server_location]

        with key_server_connection.cursor() as cursor:
            cursor.execute(
                query='''
                    SELECT value FROM keys WHERE name = %s;
                ''',
                params=(
                    key,
                ),
            )

            result = cursor.fetchone()
            if result is None:
                return None

            return result[0]

    def key_delete(
        self,
        key: str,
    ) -> bool:
        """Delete ``key``; return True when exactly one row was removed."""
        key_server_location = binascii.crc32(key.encode()) % self.number_of_connections
        key_server_connection = self.connections[key_server_location]

        with key_server_connection.cursor() as cursor:
            cursor.execute(
                query='''
                    DELETE FROM keys WHERE name = %s;
                ''',
                params=(
                    key,
                ),
            )

            return cursor.rowcount == 1

    def queue_pop(
        self,
        queue_name: str,
    ) -> typing.Optional[bytes]:
        """Remove and return one consumable item, trying each server once.

        Returns:
            The item payload, or None when no server had a consumable item.
        """
        for _ in range(self.number_of_connections):
            # The row factory unwraps the single ``value`` column, so fetches
            # return the payload itself rather than a one-element tuple.
            with self.next_connection.cursor(
                row_factory=psycopg.rows.args_row(
                    func=lambda value: value,
                ),
            ) as cursor:
                cursor.execute(
                    query='''
                        DELETE FROM task_queue
                        WHERE id = any(
                            array(
                                SELECT id FROM task_queue
                                WHERE queue_name = %s AND priority <= %s
                                ORDER BY priority ASC
                                LIMIT 1
                            )
                        )
                        RETURNING value;
                    ''',
                    params=(
                        queue_name,
                        time.time(),
                    ),
                )

                document = cursor.fetchone()
                # Compare against None: an empty payload is falsy but was
                # still deleted from the table and must be handed back.
                if document is not None:
                    return document

        return None

    def queue_pop_bulk(
        self,
        queue_name: str,
        number_of_items: int,
    ) -> typing.List[bytes]:
        """Remove and return up to ``number_of_items`` consumable items.

        Each server is asked once, only for the amount still missing.
        """
        values: typing.List[bytes] = []

        for _ in range(self.number_of_connections):
            with self.next_connection.cursor(
                row_factory=psycopg.rows.args_row(
                    func=lambda value: value,
                ),
            ) as cursor:
                cursor.execute(
                    query='''
                        DELETE FROM task_queue
                        WHERE id = any(
                            array(
                                SELECT id FROM task_queue
                                WHERE queue_name = %s AND priority <= %s
                                ORDER BY priority ASC
                                LIMIT %s
                            )
                        )
                        RETURNING value;
                    ''',
                    params=(
                        queue_name,
                        time.time(),
                        number_of_items - len(values),
                    ),
                )

                values += cursor.fetchall()

            if len(values) >= number_of_items:
                break

        return values

    def queue_push(
        self,
        queue_name: str,
        item: bytes,
        priority: str = 'NORMAL',
        consumable_from: typing.Optional[float] = None,
    ) -> bool:
        """Append one item to ``queue_name`` on the next server in rotation.

        Returns:
            True when the row was inserted.
        """
        priority_value = self._priority_value(
            priority=priority,
            consumable_from=consumable_from,
        )

        with self.next_connection.cursor() as cursor:
            cursor.execute(
                query='''
                    INSERT INTO task_queue (queue_name, priority, value)
                    VALUES(%s, %s, %s);
                ''',
                params=(
                    queue_name,
                    priority_value,
                    item,
                ),
            )

            # A successful single-row INSERT reports one affected row.
            return cursor.rowcount == 1

    def queue_push_bulk(
        self,
        queue_name: str,
        items: typing.Iterable[bytes],
        priority: str = 'NORMAL',
        consumable_from: typing.Optional[float] = None,
    ) -> bool:
        """Append all ``items`` to ``queue_name`` with a single statement.

        Returns:
            True when at least one row was inserted; False for empty ``items``.
        """
        values = list(items)
        if not values:
            # Nothing to insert; avoid sending a statement at all.
            return False

        priority_value = self._priority_value(
            priority=priority,
            consumable_from=consumable_from,
        )

        with self.next_connection.cursor() as cursor:
            # All values travel as bound parameters, never as SQL text, so the
            # queue name cannot break or alter the statement. The payloads are
            # sent as one bytea[] parameter and expanded to rows by unnest(),
            # which keeps the parameter count constant for any batch size.
            cursor.execute(
                query='''
                    INSERT INTO task_queue (queue_name, priority, value)
                    SELECT %s::text, %s::decimal, unnest(%s::bytea[]);
                ''',
                params=(
                    queue_name,
                    priority_value,
                    values,
                ),
            )

            return cursor.rowcount > 0

    def queue_length(
        self,
        queue_name: str,
        include_delayed: bool,
    ) -> int:
        """Count the items of ``queue_name`` across all servers.

        Args:
            include_delayed: When False, only rows whose priority is not
                greater than the current time are counted.
        """
        queue_length = 0

        for connection in self.connections:
            with connection.cursor() as cursor:
                if include_delayed:
                    cursor.execute(
                        query='''
                            SELECT COUNT(*) FROM task_queue
                            WHERE queue_name = %s;
                        ''',
                        params=(
                            queue_name,
                        ),
                    )
                else:
                    cursor.execute(
                        query='''
                            SELECT COUNT(*) FROM task_queue
                            WHERE queue_name = %s AND priority <= %s;
                        ''',
                        params=(
                            queue_name,
                            time.time(),
                        ),
                    )

                result = cursor.fetchone()
                if result:
                    queue_length += result[0]

        return queue_length

    def queue_delete(
        self,
        queue_name: str,
    ) -> bool:
        """Delete every item of ``queue_name`` on all servers.

        Returns:
            True when at least one row was deleted.
        """
        deleted_count = 0

        for connection in self.connections:
            with connection.cursor() as cursor:
                cursor.execute(
                    query='''
                        DELETE FROM task_queue WHERE queue_name = %s;
                    ''',
                    params=(
                        queue_name,
                    ),
                )

                deleted_count += cursor.rowcount

        return deleted_count > 0

    def lock(
        self,
        name: str,
    ) -> Lock:
        """Return a ``Lock`` bound to the server chosen by CRC32 of ``name``."""
        lock_server_location = binascii.crc32(name.encode()) % self.number_of_connections
        lock_server_connection = self.connections[lock_server_location]

        return Lock(
            connection=lock_server_connection,
            name=name,
        )

    def __del__(
        self,
    ) -> None:
        for connection in self.connections:
            connection.close()
b/tests/broker/test_broker.py @@ -807,3 +807,86 @@ def setUpClass( encoder=encoder_obj, ) ) + + +class PostgresSingleServerBrokerTestCase( + BrokerTestCase, +): + __test__ = True + + @classmethod + def setUpClass( + cls, + ): + cls.test_brokers = [] + connector_obj = sergeant.connector.postgres.Connector( + connection_strings=[ + 'postgresql://postgres:mysecretpassword@127.0.0.1:5432/', + ] + ) + + cls.test_broker = sergeant.broker.Broker( + connector=connector_obj, + encoder=sergeant.encoder.encoder.Encoder( + compressor_name=None, + serializer_name='pickle', + ), + ) + + compressor_names = list(sergeant.encoder.encoder.Encoder.compressors.keys()) + compressor_names.append(None) + serializer_names = sergeant.encoder.encoder.Encoder.serializers.keys() + for compressor_name in compressor_names: + for serializer_name in serializer_names: + encoder_obj = sergeant.encoder.encoder.Encoder( + compressor_name=compressor_name, + serializer_name=serializer_name, + ) + cls.test_brokers.append( + sergeant.broker.Broker( + connector=connector_obj, + encoder=encoder_obj, + ) + ) + + +class PostgresMultipleServersBrokerTestCase( + BrokerTestCase, +): + __test__ = True + + @classmethod + def setUpClass( + cls, + ): + cls.test_brokers = [] + connector_obj = sergeant.connector.postgres.Connector( + connection_strings=[ + 'postgresql://postgres:mysecretpassword@127.0.0.1:5432/', + 'postgresql://postgres:mysecretpassword@127.0.0.1:5433/', + ] + ) + + cls.test_broker = sergeant.broker.Broker( + connector=connector_obj, + encoder=sergeant.encoder.encoder.Encoder( + compressor_name=None, + serializer_name='pickle', + ), + ) + + compressor_names = list(sergeant.encoder.encoder.Encoder.compressors.keys()) + compressor_names.append(None) + serializer_names = sergeant.encoder.encoder.Encoder.serializers.keys() + for compressor_name in compressor_names: + for serializer_name in serializer_names: + encoder_obj = sergeant.encoder.encoder.Encoder( + compressor_name=compressor_name, + 
serializer_name=serializer_name, + ) + cls.test_brokers.append( + sergeant.broker.Broker( + connector=connector_obj, + encoder=encoder_obj, + ) + ) diff --git a/tests/connector/test_connector.py b/tests/connector/test_connector.py index 98797db..30dd36f 100644 --- a/tests/connector/test_connector.py +++ b/tests/connector/test_connector.py @@ -728,7 +728,7 @@ def setUp( self.connector = sergeant.connector.mongo.Connector( nodes=[ { - 'host': 'localhost', + 'host': 'mongo-node-one', 'port': 27017, 'replica_set': 'test_replica_set', }, @@ -747,12 +747,12 @@ def setUp( self.connector = sergeant.connector.mongo.Connector( nodes=[ { - 'host': 'localhost', + 'host': 'mongo-node-one', 'port': 27017, 'replica_set': 'test_replica_set', }, { - 'host': 'localhost', + 'host': 'mongo-node-two', 'port': 27018, 'replica_set': 'test_replica_set', }, @@ -771,3 +771,34 @@ def setUp( self.connector = sergeant.connector.local.Connector( file_path='/tmp/test.sqlite3', ) + + +class PostgresSingleServerConnectorTestCase( + ConnectorTestCase, +): + __test__ = True + + def setUp( + self, + ): + self.connector = sergeant.connector.postgres.Connector( + connection_strings=[ + 'postgresql://postgres:mysecretpassword@127.0.0.1:5432/', + ] + ) + + +class PostgresMultipleServersConnectorTestCase( + ConnectorTestCase, +): + __test__ = True + + def setUp( + self, + ): + self.connector = sergeant.connector.postgres.Connector( + connection_strings=[ + 'postgresql://postgres:mysecretpassword@127.0.0.1:5432/', + 'postgresql://postgres:mysecretpassword@127.0.0.1:5433/', + ] + )