diff --git a/CHANGELOG.md b/CHANGELOG.md index bdc66c8..27a04d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,11 @@ # Changelog +## v1.5.0 11/15/24 +- Use executemany instead of execute when appropriate in PostgreSQLClient +- Add capability to retry connecting to a database to the MySQL, PostgreSQL, and Redshift clients +- Automatically close database connection upon error in the MySQL, PostgreSQL, and Redshift clients +- Delete old PostgreSQLPoolClient, which was not production ready +- Rename the PostgreSQLClient db_name constructor argument and attribute to database + ## v1.4.0 9/23/24 - Added SFTP client diff --git a/README.md b/README.md index f695d8b..deb9698 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,6 @@ This package contains common Python utility classes and functions. * Downloading files from a remote SSH SFTP server * Connecting to and querying a MySQL database * Connecting to and querying a PostgreSQL database -* Connecting to and querying a PostgreSQL database using a connection pool * Connecting to and querying Redshift * Making requests to the Oauth2 authenticated APIs such as NYPL Platform API and Sierra @@ -37,7 +36,7 @@ kinesis_client = KinesisClient(...) # Do not use any version below 1.0.0 # All available optional dependencies can be found in pyproject.toml. # See the "Managing dependencies" section below for more details. -nypl-py-utils[kinesis-client,config-helper]==1.4.0 +nypl-py-utils[kinesis-client,config-helper]==1.5.0 ``` ## Developing locally @@ -63,7 +62,7 @@ The optional dependency sets also give the developer the option to manually list ### Using PostgreSQLClient in an AWS Lambda Because `psycopg` requires a statically linked version of the `libpq` library, the `PostgreSQLClient` cannot be installed as-is in an AWS Lambda function. 
Instead, it must be packaged as follows: ```bash -pip install --target ./package nypl-py-utils[postgresql-client]==1.4.0 +pip install --target ./package nypl-py-utils[postgresql-client]==1.5.0 pip install \ --platform manylinux2014_x86_64 \ diff --git a/pyproject.toml b/pyproject.toml index ed151cd..fbd2558 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "nypl_py_utils" -version = "1.4.0" +version = "1.5.0" authors = [ { name="Aaron Friedman", email="aaronfriedman@nypl.org" }, ] @@ -45,9 +45,6 @@ oauth2-api-client = [ postgresql-client = [ "psycopg[binary]>=3.1.6" ] -postgresql-pool-client = [ - "psycopg[binary,pool]>=3.1.6" -] redshift-client = [ "botocore>=1.29.5", "redshift-connector>=2.0.909" @@ -74,7 +71,7 @@ research-catalog-identifier-helper = [ "requests>=2.28.1" ] development = [ - "nypl_py_utils[avro-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,postgresql-pool-client,redshift-client,s3-client,secrets-manager-client,sftp-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]", + "nypl_py_utils[avro-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,redshift-client,s3-client,secrets-manager-client,sftp-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]", "flake8>=6.0.0", "freezegun>=1.2.2", "mock>=4.0.3", diff --git a/src/nypl_py_utils/classes/mysql_client.py b/src/nypl_py_utils/classes/mysql_client.py index 94bb3c7..a755d5b 100644 --- a/src/nypl_py_utils/classes/mysql_client.py +++ b/src/nypl_py_utils/classes/mysql_client.py @@ -1,4 +1,5 @@ import mysql.connector +import time from nypl_py_utils.functions.log_helper import create_log @@ -15,35 +16,50 @@ def __init__(self, host, port, database, user, password): self.user = user self.password = password - def connect(self, **kwargs): + def connect(self, retry_count=0, backoff_factor=5, **kwargs): """ Connects to a MySQL database 
using the given credentials. - Keyword args can be passed into the connection to set certain options. - All possible arguments can be found here: - https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html. - - Common arguments include: - autocommit: bool - Whether to automatically commit each query rather than running - them as part of a transaction. By default False. + Parameters + ---------- + retry_count: int, optional + The number of times to retry connecting before throwing an error. + By default no retry occurs. + backoff_factor: int, optional + The backoff factor when retrying. The amount of time to wait before + retrying is backoff_factor ** number_of_retries_made. + kwargs: + All possible arguments can be found here: + https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html """ self.logger.info('Connecting to {} database'.format(self.database)) - try: - self.conn = mysql.connector.connect( - host=self.host, - port=self.port, - database=self.database, - user=self.user, - password=self.password, - **kwargs) - except mysql.connector.Error as e: - self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) - raise MySQLClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) from None + attempt_count = 0 + while attempt_count <= retry_count: + try: + try: + self.conn = mysql.connector.connect( + host=self.host, + port=self.port, + database=self.database, + user=self.user, + password=self.password, + **kwargs) + except (mysql.connector.Error): + if attempt_count < retry_count: + self.logger.info('Failed to connect -- retrying') + time.sleep(backoff_factor ** attempt_count) + attempt_count += 1 + else: + raise + else: + break + except Exception as e: + self.logger.error( + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) + raise MySQLClientError( + 'Error connecting to {name} database: 
{error}'.format( + name=self.database, error=e)) from None def execute_query(self, query, query_params=None, **kwargs): """ @@ -83,6 +99,8 @@ def execute_query(self, query, query_params=None, **kwargs): return cursor.fetchall() except Exception as e: self.conn.rollback() + cursor.close() + self.close_connection() self.logger.error( ('Error executing {name} database query \'{query}\': {error}') .format(name=self.database, query=query, error=e)) diff --git a/src/nypl_py_utils/classes/postgresql_client.py b/src/nypl_py_utils/classes/postgresql_client.py index 05c7a97..82c6b9e 100644 --- a/src/nypl_py_utils/classes/postgresql_client.py +++ b/src/nypl_py_utils/classes/postgresql_client.py @@ -1,4 +1,5 @@ import psycopg +import time from nypl_py_utils.functions.log_helper import create_log @@ -6,43 +7,54 @@ class PostgreSQLClient: """Client for managing individual connections to a PostgreSQL database""" - def __init__(self, host, port, db_name, user, password): + def __init__(self, host, port, database, user, password): self.logger = create_log('postgresql_client') self.conn = None self.conn_info = ('postgresql://{user}:{password}@{host}:{port}/' - '{db_name}').format(user=user, password=password, - host=host, port=port, - db_name=db_name) + '{database}').format(user=user, password=password, + host=host, port=port, + database=database) + self.database = database - self.db_name = db_name - - def connect(self, **kwargs): + def connect(self, retry_count=0, backoff_factor=5, **kwargs): """ Connects to a PostgreSQL database using the given credentials. - Keyword args can be passed into the connection to set certain options. - All possible arguments can be found here: - https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect. - - Common arguments include: - autocommit: bool - Whether to automatically commit each query rather than running - them as part of a transaction. By default False. 
- row_factory: RowFactory - A psycopg RowFactory that determines how the data will be - returned. Defaults to tuple_row, which returns the rows as a - list of tuples. + Parameters + ---------- + retry_count: int, optional + The number of times to retry connecting before throwing an error. + By default no retry occurs. + backoff_factor: int, optional + The backoff factor when retrying. The amount of time to wait before + retrying is backoff_factor ** number_of_retries_made. + kwargs: + All possible arguments (such as the row_factory) can be found here: + https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect """ - self.logger.info('Connecting to {} database'.format(self.db_name)) - try: - self.conn = psycopg.connect(self.conn_info, **kwargs) - except psycopg.Error as e: - self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) - raise PostgreSQLClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) from None + self.logger.info('Connecting to {} database'.format(self.database)) + attempt_count = 0 + while attempt_count <= retry_count: + try: + try: + self.conn = psycopg.connect(self.conn_info, **kwargs) + except (psycopg.OperationalError, + psycopg.errors.ConnectionTimeout): + if attempt_count < retry_count: + self.logger.info('Failed to connect -- retrying') + time.sleep(backoff_factor ** attempt_count) + attempt_count += 1 + else: + raise + else: + break + except Exception as e: + self.logger.error( + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) + raise PostgreSQLClientError( + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) from None def execute_query(self, query, query_params=None, **kwargs): """ @@ -53,7 +65,11 @@ def execute_query(self, query, query_params=None, **kwargs): query: str The query to execute query_params: sequence, optional - The values to be 
used in a parameterized query + The values to be used in a parameterized query. The values can be + for a single insert query -- e.g. execute_query( + "INSERT INTO x VALUES (%s, %s)", (1, "a")) + or for multiple -- e.g. execute_query( + "INSERT INTO x VALUES (%s, %s)", [(1, "a"), (2, "b")]) kwargs: All possible arguments can be found here: https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.execute @@ -65,30 +81,38 @@ def execute_query(self, query, query_params=None, **kwargs): based on the connection's row_factory if there's something to return (even if the result set is empty). """ - self.logger.info('Querying {} database'.format(self.db_name)) + self.logger.info('Querying {} database'.format(self.database)) self.logger.debug('Executing query {}'.format(query)) try: cursor = self.conn.cursor() - cursor.execute(query, query_params, **kwargs) + if query_params is not None and all( + isinstance(param, tuple) or isinstance(param, list) + for param in query_params + ): + cursor.executemany(query, query_params, **kwargs) + else: + cursor.execute(query, query_params, **kwargs) self.conn.commit() return None if cursor.description is None else cursor.fetchall() except Exception as e: self.conn.rollback() + cursor.close() + self.close_connection() self.logger.error( ('Error executing {name} database query \'{query}\': ' '{error}').format( - name=self.db_name, query=query, error=e)) + name=self.database, query=query, error=e)) raise PostgreSQLClientError( ('Error executing {name} database query \'{query}\': ' '{error}').format( - name=self.db_name, query=query, error=e)) from None + name=self.database, query=query, error=e)) from None finally: cursor.close() def close_connection(self): """Closes the database connection""" self.logger.debug('Closing {} database connection'.format( - self.db_name)) + self.database)) self.conn.close() diff --git a/src/nypl_py_utils/classes/postgresql_pool_client.py b/src/nypl_py_utils/classes/postgresql_pool_client.py 
deleted file mode 100644 index beaf589..0000000 --- a/src/nypl_py_utils/classes/postgresql_pool_client.py +++ /dev/null @@ -1,137 +0,0 @@ -import psycopg - -from nypl_py_utils.functions.log_helper import create_log -from psycopg.rows import tuple_row -from psycopg_pool import ConnectionPool - - -class PostgreSQLPoolClient: - """Client for managing a connection pool to a PostgreSQL database""" - - def __init__(self, host, port, db_name, user, password, conn_timeout=300.0, - **kwargs): - """ - Creates (but does not open) a connection pool. - - Parameters - ---------- - host, port, db_name, user, password: str - Required connection information - kwargs: dict, optional - Keyword args to be passed into the ConnectionPool. All possible - arguments can be found here: - https://www.psycopg.org/psycopg3/docs/api/pool.html#psycopg_pool.ConnectionPool. - - Common arguments include: - min_size/max_size: The minimum and maximum size of the pool, by - default 0 and 1 - max_idle: Half the number of seconds a connection can stay idle - before being automatically closed, by default 90.0, which - corresponds to 3 minutes of idle time. Note that if - min_size is greater than 0, this won't apply to the first - min_size connections, which will stay open until manually - closed. 
- """ - self.logger = create_log('postgresql_client') - self.conn_info = ('postgresql://{user}:{password}@{host}:{port}/' - '{db_name}').format(user=user, password=password, - host=host, port=port, - db_name=db_name) - - self.db_name = db_name - self.kwargs = kwargs - self.kwargs['min_size'] = kwargs.get('min_size', 0) - self.kwargs['max_size'] = kwargs.get('max_size', 1) - self.kwargs['max_idle'] = kwargs.get('max_idle', 90.0) - - if self.kwargs['max_idle'] > 150.0: - self.logger.error(( - 'max_idle is too high -- values over 150 seconds are unsafe ' - 'and may lead to connection leakages in ECS')) - raise PostgreSQLPoolClientError(( - 'max_idle is too high -- values over 150 seconds are unsafe ' - 'and may lead to connection leakages in ECS')) from None - - self.pool = ConnectionPool(self.conn_info, open=False, **self.kwargs) - - def connect(self, timeout=300.0): - """ - Opens the connection pool and connects to the given PostgreSQL database - min_size number of times - - Parameters - ---------- - conn_timeout: float, optional - The number of seconds to try connecting before throwing an error. - Defaults to 300 seconds. - """ - self.logger.info('Connecting to {} database'.format(self.db_name)) - try: - if self.pool is None: - self.pool = ConnectionPool( - self.conn_info, open=False, **self.kwargs) - self.pool.open(wait=True, timeout=timeout) - except psycopg.Error as e: - self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) - raise PostgreSQLPoolClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) from None - - def execute_query(self, query, query_params=None, row_factory=tuple_row, - **kwargs): - """ - Requests a connection from the pool and uses it to execute an arbitrary - query. After the query is complete, either commits it or rolls it back, - and then returns the connection to the pool. 
- - Parameters - ---------- - query: str - The query to execute - query_params: sequence, optional - The values to be used in a parameterized query - row_factory: RowFactory, optional - A psycopg RowFactory that determines how the data will be returned. - Defaults to tuple_row, which returns the rows as a list of tuples. - kwargs: - All possible arguments can be found here: - https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.execute - - Returns - ------- - None or sequence - None if the cursor has nothing to return. Some type of sequence - based on the row_factory input if there's something to return - (even if the result set is empty). - """ - self.logger.info('Querying {} database'.format(self.db_name)) - self.logger.debug('Executing query {}'.format(query)) - with self.pool.connection() as conn: - try: - conn.row_factory = row_factory - cursor = conn.execute(query, query_params, **kwargs) - return (None if cursor.description is None - else cursor.fetchall()) - except Exception as e: - self.logger.error( - ('Error executing {name} database query \'{query}\': ' - '{error}').format( - name=self.db_name, query=query, error=e)) - raise PostgreSQLPoolClientError( - ('Error executing {name} database query \'{query}\': ' - '{error}').format( - name=self.db_name, query=query, error=e)) from None - - def close_pool(self): - """Closes the connection pool""" - self.logger.debug('Closing {} database connection pool'.format( - self.db_name)) - self.pool.close() - self.pool = None - - -class PostgreSQLPoolClientError(Exception): - def __init__(self, message=None): - self.message = message diff --git a/src/nypl_py_utils/classes/redshift_client.py b/src/nypl_py_utils/classes/redshift_client.py index 17c4558..9f594ef 100644 --- a/src/nypl_py_utils/classes/redshift_client.py +++ b/src/nypl_py_utils/classes/redshift_client.py @@ -1,6 +1,6 @@ import redshift_connector +import time -from botocore.exceptions import ClientError from 
nypl_py_utils.functions.log_helper import create_log @@ -15,23 +15,46 @@ def __init__(self, host, database, user, password): self.user = user self.password = password - def connect(self): - """Connects to a Redshift database using the given credentials""" + def connect(self, retry_count=0, backoff_factor=5): + """ + Connects to a Redshift database using the given credentials. + + Parameters + ---------- + retry_count: int, optional + The number of times to retry connecting before throwing an error. + By default no retry occurs. + backoff_factor: int, optional + The backoff factor when retrying. The amount of time to wait before + retrying is backoff_factor ** number_of_retries_made. + """ self.logger.info('Connecting to {} database'.format(self.database)) - try: - self.conn = redshift_connector.connect( - host=self.host, - database=self.database, - user=self.user, - password=self.password, - sslmode='verify-full') - except ClientError as e: - self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) - raise RedshiftClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) from None + attempt_count = 0 + while attempt_count <= retry_count: + try: + try: + self.conn = redshift_connector.connect( + host=self.host, + database=self.database, + user=self.user, + password=self.password, + sslmode='verify-full') + except (redshift_connector.InterfaceError): + if attempt_count < retry_count: + self.logger.info('Failed to connect -- retrying') + time.sleep(backoff_factor ** attempt_count) + attempt_count += 1 + else: + raise + else: + break + except Exception as e: + self.logger.error( + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) + raise RedshiftClientError( + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) from None def execute_query(self, query, dataframe=False): """ @@ -62,6 +85,8 @@ def 
execute_query(self, query, dataframe=False): return cursor.fetchall() except Exception as e: self.conn.rollback() + cursor.close() + self.close_connection() self.logger.error( ('Error executing {name} database query \'{query}\': {error}') .format(name=self.database, query=query, error=e)) @@ -83,10 +108,9 @@ def execute_transaction(self, queries): A list of tuples containing a query and the values to be used if the query is parameterized (or None if it's not). The values can be for a single insert query -- e.g. execute_transaction( - "INSERT INTO x VALUES (%s, %s)", (1, "a")) + [("INSERT INTO x VALUES (%s, %s)", (1, "a"))]) or for multiple -- e.g execute_transaction( - "INSERT INTO x VALUES (%s, %s)", [(1, "a"), (2, "b")]) - + [("INSERT INTO x VALUES (%s, %s)", [(1, "a"), (2, "b")])]) """ self.logger.info('Executing transaction against {} database'.format( self.database)) @@ -106,6 +130,8 @@ def execute_transaction(self, queries): self.conn.commit() except Exception as e: self.conn.rollback() + cursor.close() + self.close_connection() self.logger.error( ('Error executing {name} database transaction: {error}') .format(name=self.database, error=e)) diff --git a/tests/test_mysql_client.py b/tests/test_mysql_client.py index a1f8a87..39508f2 100644 --- a/tests/test_mysql_client.py +++ b/tests/test_mysql_client.py @@ -1,3 +1,4 @@ +import mysql.connector import pytest from nypl_py_utils.classes.mysql_client import MySQLClient, MySQLClientError @@ -22,6 +23,21 @@ def test_connect(self, mock_mysql_conn, test_instance): user='test_user', password='test_password') + def test_connect_retry_success(self, mock_mysql_conn, test_instance, + mocker): + mock_mysql_conn.side_effect = [mysql.connector.Error, + mocker.MagicMock()] + test_instance.connect(retry_count=2, backoff_factor=0) + assert mock_mysql_conn.call_count == 2 + + def test_connect_retry_fail(self, mock_mysql_conn, test_instance): + mock_mysql_conn.side_effect = mysql.connector.Error + + with 
pytest.raises(MySQLClientError): + test_instance.connect(retry_count=2, backoff_factor=0) + + assert mock_mysql_conn.call_count == 3 + def test_execute_read_query(self, mock_mysql_conn, test_instance, mocker): test_instance.connect() @@ -75,7 +91,8 @@ def test_execute_query_with_exception( test_instance.execute_query('test query') test_instance.conn.rollback.assert_called_once() - mock_cursor.close.assert_called_once() + mock_cursor.close.assert_called() + test_instance.conn.close.assert_called_once() def test_close_connection(self, mock_mysql_conn, test_instance): test_instance.connect() diff --git a/tests/test_postgresql_client.py b/tests/test_postgresql_client.py index 99e5042..af93625 100644 --- a/tests/test_postgresql_client.py +++ b/tests/test_postgresql_client.py @@ -2,6 +2,7 @@ from nypl_py_utils.classes.postgresql_client import ( PostgreSQLClient, PostgreSQLClientError) +from psycopg import OperationalError class TestPostgreSQLClient: @@ -12,14 +13,27 @@ def mock_pg_conn(self, mocker): @pytest.fixture def test_instance(self): - return PostgreSQLClient('test_host', 'test_port', 'test_db_name', + return PostgreSQLClient('test_host', 'test_port', 'test_database', 'test_user', 'test_password') def test_connect(self, mock_pg_conn, test_instance): test_instance.connect() mock_pg_conn.assert_called_once_with( 'postgresql://test_user:test_password@test_host:test_port/' + - 'test_db_name') + 'test_database') + + def test_connect_retry_success(self, mock_pg_conn, test_instance, mocker): + mock_pg_conn.side_effect = [OperationalError(), mocker.MagicMock()] + test_instance.connect(retry_count=2, backoff_factor=0) + assert mock_pg_conn.call_count == 2 + + def test_connect_retry_fail(self, mock_pg_conn, test_instance): + mock_pg_conn.side_effect = OperationalError() + + with pytest.raises(PostgreSQLClientError): + test_instance.connect(retry_count=2, backoff_factor=0) + + assert mock_pg_conn.call_count == 3 def test_execute_read_query(self, mock_pg_conn, test_instance, 
mocker): test_instance.connect() @@ -63,6 +77,22 @@ def test_execute_write_query_with_params(self, mock_pg_conn, test_instance, test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() + def test_execute_write_query_with_many_params( + self, mock_pg_conn, test_instance, mocker): + test_instance.connect() + + mock_cursor = mocker.MagicMock() + mock_cursor.description = None + test_instance.conn.cursor.return_value = mock_cursor + + assert test_instance.execute_query( + 'test query %s %s', query_params=[('a', 1), ('b', None), (None, 2)] + ) is None + mock_cursor.executemany.assert_called_once_with( + 'test query %s %s', [('a', 1), ('b', None), (None, 2)]) + test_instance.conn.commit.assert_called_once() + mock_cursor.close.assert_called_once() + def test_execute_query_with_exception( self, mock_pg_conn, test_instance, mocker): test_instance.connect() @@ -75,7 +105,8 @@ def test_execute_query_with_exception( test_instance.execute_query('test query') test_instance.conn.rollback.assert_called_once() - mock_cursor.close.assert_called_once() + mock_cursor.close.assert_called() + test_instance.conn.close.assert_called_once() def test_close_connection(self, mock_pg_conn, test_instance): test_instance.connect() diff --git a/tests/test_postgresql_pool_client.py b/tests/test_postgresql_pool_client.py deleted file mode 100644 index 82f22b6..0000000 --- a/tests/test_postgresql_pool_client.py +++ /dev/null @@ -1,121 +0,0 @@ -import pytest - -from nypl_py_utils.classes.postgresql_pool_client import ( - PostgreSQLPoolClient, PostgreSQLPoolClientError) -from psycopg import Error - - -class TestPostgreSQLPoolClient: - - @pytest.fixture - def test_instance(self, mocker): - mocker.patch('psycopg_pool.ConnectionPool.open') - mocker.patch('psycopg_pool.ConnectionPool.close') - return PostgreSQLPoolClient('test_host', 'test_port', 'test_db_name', - 'test_user', 'test_password') - - def test_init(self, test_instance): - assert test_instance.pool.conninfo == ( - 
'postgresql://test_user:test_password@test_host:test_port/' + - 'test_db_name') - assert test_instance.pool._opened is False - assert test_instance.pool.min_size == 0 - assert test_instance.pool.max_size == 1 - - def test_init_with_long_max_idle(self): - with pytest.raises(PostgreSQLPoolClientError): - PostgreSQLPoolClient( - 'test_host', 'test_port', 'test_db_name', 'test_user', - 'test_password', max_idle=300.0) - - def test_connect(self, test_instance): - test_instance.connect() - test_instance.pool.open.assert_called_once_with(wait=True, - timeout=300.0) - - def test_connect_with_exception(self, mocker): - mocker.patch('psycopg_pool.ConnectionPool.open', - side_effect=Error()) - - test_instance = PostgreSQLPoolClient( - 'test_host', 'test_port', 'test_db_name', 'test_user', - 'test_password') - - with pytest.raises(PostgreSQLPoolClientError): - test_instance.connect(timeout=1.0) - - def test_execute_read_query(self, test_instance, mocker): - test_instance.connect() - - mock_cursor = mocker.MagicMock() - mock_cursor.description = [('description', None, None)] - mock_cursor.fetchall.return_value = [(1, 2, 3), ('a', 'b', 'c')] - mock_conn = mocker.MagicMock() - mock_conn.execute.return_value = mock_cursor - mock_conn_context = mocker.MagicMock() - mock_conn_context.__enter__.return_value = mock_conn - mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) - - assert test_instance.execute_query( - 'test query') == [(1, 2, 3), ('a', 'b', 'c')] - mock_conn.execute.assert_called_once_with('test query', None) - mock_cursor.fetchall.assert_called_once() - - def test_execute_write_query(self, test_instance, mocker): - test_instance.connect() - - mock_cursor = mocker.MagicMock() - mock_cursor.description = None - mock_conn = mocker.MagicMock() - mock_conn.execute.return_value = mock_cursor - mock_conn_context = mocker.MagicMock() - mock_conn_context.__enter__.return_value = mock_conn - 
mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) - - assert test_instance.execute_query('test query') is None - mock_conn.execute.assert_called_once_with('test query', None) - - def test_execute_write_query_with_params(self, test_instance, mocker): - test_instance.connect() - - mock_cursor = mocker.MagicMock() - mock_cursor.description = None - mock_conn = mocker.MagicMock() - mock_conn.execute.return_value = mock_cursor - mock_conn_context = mocker.MagicMock() - mock_conn_context.__enter__.return_value = mock_conn - mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) - - assert test_instance.execute_query( - 'test query %s %s', query_params=('a', 1)) is None - mock_conn.execute.assert_called_once_with('test query %s %s', - ('a', 1)) - - def test_execute_query_with_exception(self, test_instance, mocker): - test_instance.connect() - - mock_conn = mocker.MagicMock() - mock_conn.execute.side_effect = Exception() - mock_conn_context = mocker.MagicMock() - mock_conn_context.__enter__.return_value = mock_conn - mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) - - with pytest.raises(PostgreSQLPoolClientError): - test_instance.execute_query('test query') - - def test_close_pool(self, test_instance): - test_instance.connect() - test_instance.close_pool() - assert test_instance.pool is None - - def test_reopen_pool(self, test_instance, mocker): - test_instance.connect() - test_instance.close_pool() - test_instance.connect() - test_instance.pool.open.assert_has_calls([ - mocker.call(wait=True, timeout=300), - mocker.call(wait=True, timeout=300)]) diff --git a/tests/test_redshift_client.py b/tests/test_redshift_client.py index 7d6219d..e33024e 100644 --- a/tests/test_redshift_client.py +++ b/tests/test_redshift_client.py @@ -2,6 +2,7 @@ from nypl_py_utils.classes.redshift_client import ( RedshiftClient, RedshiftClientError) +from redshift_connector import 
InterfaceError class TestRedshiftClient: @@ -23,6 +24,20 @@ def test_connect(self, mock_redshift_conn, test_instance): password='test_password', sslmode='verify-full') + def test_connect_retry_success(self, mock_redshift_conn, test_instance, + mocker): + mock_redshift_conn.side_effect = [InterfaceError(), mocker.MagicMock()] + test_instance.connect(retry_count=2, backoff_factor=0) + assert mock_redshift_conn.call_count == 2 + + def test_connect_retry_fail(self, mock_redshift_conn, test_instance): + mock_redshift_conn.side_effect = InterfaceError() + + with pytest.raises(RedshiftClientError): + test_instance.connect(retry_count=2, backoff_factor=0) + + assert mock_redshift_conn.call_count == 3 + def test_execute_query(self, mock_redshift_conn, test_instance, mocker): test_instance.connect() @@ -60,7 +75,8 @@ def test_execute_query_with_exception( test_instance.execute_query('test query') test_instance.conn.rollback.assert_called_once() - mock_cursor.close.assert_called_once() + mock_cursor.close.assert_called() + test_instance.conn.close.assert_called_once() def test_execute_transaction(self, mock_redshift_conn, test_instance, mocker): @@ -119,7 +135,8 @@ def test_execute_transaction_with_exception( mocker.call('query 2', None)]) test_instance.conn.commit.assert_not_called() test_instance.conn.rollback.assert_called_once() - mock_cursor.close.assert_called_once() + mock_cursor.close.assert_called() + test_instance.conn.close.assert_called_once() def test_close_connection(self, mock_redshift_conn, test_instance): test_instance.connect()