From ab740d7e60592a221a4a658c196d2543889aa0c3 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Wed, 26 Jun 2024 11:04:27 -0500 Subject: [PATCH 01/14] Avro interpreter initial update --- src/nypl_py_utils/classes/avro_encoder.py | 91 +++++++++++++---------- tests/test_avro_encoder.py | 14 ++-- 2 files changed, 59 insertions(+), 46 deletions(-) diff --git a/src/nypl_py_utils/classes/avro_encoder.py b/src/nypl_py_utils/classes/avro_encoder.py index 8ff5229..53dd260 100644 --- a/src/nypl_py_utils/classes/avro_encoder.py +++ b/src/nypl_py_utils/classes/avro_encoder.py @@ -7,17 +7,52 @@ from nypl_py_utils.functions.log_helper import create_log from requests.exceptions import JSONDecodeError, RequestException - -class AvroEncoder: +class AvroInterpreter: """ - Class for encoding records using an Avro schema. Takes as input the + Base class for Avro schema interaction. Takes as input the Platform API endpoint from which to fetch the schema in JSON format. """ - + def __init__(self, platform_schema_url): self.logger = create_log('avro_encoder') self.schema = avro.schema.parse( - self._get_json_schema(platform_schema_url)) + self.get_json_schema(platform_schema_url)) + + def get_json_schema(self, platform_schema_url): + """ + Fetches a JSON response from the input Platform API endpoint and + interprets it as an Avro schema. + """ + self.logger.info('Fetching Avro schema from {}'.format( + platform_schema_url)) + try: + response = requests.get(platform_schema_url) + response.raise_for_status() + except RequestException as e: + self.logger.error( + 'Failed to retrieve schema from {url}: {error}'.format( + url=platform_schema_url, error=e)) + raise AvroInterpreterError( + 'Failed to retrieve schema from {url}: {error}'.format( + url=platform_schema_url, error=e)) from None + + try: + json_response = response.json() + return json_response['data']['schema'] + except (JSONDecodeError, KeyError) as e: + self.logger.error( + 'Retrieved schema is malformed: {errorType} {errorMessage}' + .format(errorType=type(e), errorMessage=e)) + raise AvroInterpreterError( + 'Retrieved schema is malformed: {errorType} {errorMessage}' + .format(errorType=type(e), errorMessage=e)) from None + + +class AvroEncoder(AvroInterpreter): + """ + Class for encoding records using an Avro schema. Takes as input the + Platform API endpoint from which to fetch the schema in JSON format. + """ def encode_record(self, record): """ @@ -36,7 +71,7 @@ def encode_record(self, record): return output_stream.getvalue() except AvroException as e: self.logger.error('Failed to encode record: {}'.format(e)) - raise AvroEncoderError( + raise AvroInterpreterError( 'Failed to encode record: {}'.format(e)) from None def encode_batch(self, record_list): @@ -60,10 +95,17 @@ def encode_batch(self, record_list): output_stream.truncate(0) except AvroException as e: self.logger.error('Failed to encode record: {}'.format(e)) - raise AvroEncoderError( + raise AvroInterpreterError( 'Failed to encode record: {}'.format(e)) from None return encoded_records + +class AvroDecoder(AvroInterpreter): + """ + Class for decoding records using an Avro schema. Takes as input the + Platform API endpoint from which to fetch the schema in JSON format. 
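A minimal usage sketch of the new decoder, reusing the fixture URL and byte string from this repo's tests (the URL resolves only where the tests mock it):

    decoder = AvroDecoder('https://test_schema_url')
    decoder.decode_record(b'\xf6\x01\x02\x04aa')
    # -> {'patron_id': 123, 'library_branch': 'aa'}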
+ """ + def decode_record(self, record): """ Decodes a single record represented as a byte string using the given @@ -80,39 +122,10 @@ def decode_record(self, record): return datum_reader.read(decoder) except Exception as e: self.logger.error('Failed to decode record: {}'.format(e)) - raise AvroEncoderError( + raise AvroInterpreterError( 'Failed to decode record: {}'.format(e)) from None - def _get_json_schema(self, platform_schema_url): - """ - Fetches a JSON response from the input Platform API endpoint and - interprets it as an Avro schema. - """ - self.logger.info('Fetching Avro schema from {}'.format( - platform_schema_url)) - try: - response = requests.get(platform_schema_url) - response.raise_for_status() - except RequestException as e: - self.logger.error( - 'Failed to retrieve schema from {url}: {error}'.format( - url=platform_schema_url, error=e)) - raise AvroEncoderError( - 'Failed to retrieve schema from {url}: {error}'.format( - url=platform_schema_url, error=e)) from None - - try: - json_response = response.json() - return json_response['data']['schema'] - except (JSONDecodeError, KeyError) as e: - self.logger.error( - 'Retrieved schema is malformed: {errorType} {errorMessage}' - .format(errorType=type(e), errorMessage=e)) - raise AvroEncoderError( - 'Retrieved schema is malformed: {errorType} {errorMessage}' - .format(errorType=type(e), errorMessage=e)) from None - -class AvroEncoderError(Exception): +class AvroInterpreterError(Exception): def __init__(self, message=None): - self.message = message + self.message = message \ No newline at end of file diff --git a/tests/test_avro_encoder.py b/tests/test_avro_encoder.py index 079088a..a259656 100644 --- a/tests/test_avro_encoder.py +++ b/tests/test_avro_encoder.py @@ -1,7 +1,7 @@ import json import pytest -from nypl_py_utils.classes.avro_encoder import AvroEncoder, AvroEncoderError +from nypl_py_utils.classes.avro_encoder import AvroEncoder, AvroInterpreterError from requests.exceptions import ConnectTimeout _TEST_SCHEMA = {'data': {'schema': json.dumps({ @@ -33,19 +33,19 @@ def test_get_json_schema(self, test_instance): def test_request_error(self, requests_mock): requests_mock.get('https://test_schema_url', exc=ConnectTimeout) - with pytest.raises(AvroEncoderError): + with pytest.raises(AvroInterpreterError): AvroEncoder('https://test_schema_url') def test_bad_json_error(self, requests_mock): requests_mock.get( 'https://test_schema_url', text='bad json') - with pytest.raises(AvroEncoderError): + with pytest.raises(AvroInterpreterError): AvroEncoder('https://test_schema_url') def test_missing_key_error(self, requests_mock): requests_mock.get( 'https://test_schema_url', text=json.dumps({'field': 'value'})) - with pytest.raises(AvroEncoderError): + with pytest.raises(AvroInterpreterError): AvroEncoder('https://test_schema_url') def test_encode_record(self, test_instance): @@ -56,7 +56,7 @@ def test_encode_record(self, test_instance): def test_encode_record_error(self, test_instance): TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'} - with pytest.raises(AvroEncoderError): + with pytest.raises(AvroInterpreterError): test_instance.encode_record(TEST_RECORD) def test_encode_batch(self, test_instance): @@ -75,7 +75,7 @@ def test_encode_batch_error(self, test_instance): BAD_BATCH = [ {'patron_id': 123, 'library_branch': 'aa'}, {'patron_id': 456, 'bad_field': 'bad'}] - with pytest.raises(AvroEncoderError): + with pytest.raises(AvroInterpreterError): test_instance.encode_batch(BAD_BATCH) def test_decode_record(self, test_instance): 
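With the exception class renamed, callers now catch AvroInterpreterError rather than AvroEncoderError. A minimal sketch of the calling side, reusing the mocked fixture URL and record values from these tests:

    from nypl_py_utils.classes.avro_encoder import AvroEncoder, AvroInterpreterError

    encoder = AvroEncoder('https://test_schema_url')
    try:
        payload = encoder.encode_record({'patron_id': 123, 'library_branch': 'aa'})
    except AvroInterpreterError as e:
        print(e.message)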
@@ -86,5 +86,5 @@ def test_decode_record(self, test_instance): def test_decode_record_error(self, test_instance): TEST_ENCODED_RECORD = b'bad-encoding' - with pytest.raises(AvroEncoderError): + with pytest.raises(AvroInterpreterError): test_instance.decode_record(TEST_ENCODED_RECORD) From e20e33c83d21271a81f82000d7e026652c884c75 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Mon, 1 Jul 2024 12:00:59 -0500 Subject: [PATCH 02/14] Tests are only slightly working... --- .python-version | 2 +- CHANGELOG.md | 3 + pyproject.toml | 4 +- .../{avro_encoder.py => avro_client.py} | 41 ++- tests/test_avro_encoder.py | 90 ----- tests/test_avro_interpreter.py | 328 ++++++++++++++++++ 6 files changed, 362 insertions(+), 106 deletions(-) rename src/nypl_py_utils/classes/{avro_encoder.py => avro_client.py} (81%) delete mode 100644 tests/test_avro_encoder.py create mode 100644 tests/test_avro_interpreter.py diff --git a/.python-version b/.python-version index fd15561..9f3d4c1 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.9.16 \ No newline at end of file +3.9.16 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f1f08d..f7b9675 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +## v1.1.6 6/26/24 +- Generalized Avro functions and separated encoding/decoding behavior. + ## v1.1.5 6/6/24 - Use executemany instead of execute when appropriate in RedshiftClient.execute_transaction diff --git a/pyproject.toml b/pyproject.toml index 6939787..731fc5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [] "Bug Tracker" = "https://github.com/NYPL/python-utils/issues" [project.optional-dependencies] -avro-encoder = [ +avro-client = [ "avro>=1.11.1", "requests>=2.28.1" ] @@ -67,7 +67,7 @@ research-catalog-identifier-helper = [ "requests>=2.28.1" ] development = [ - "nypl_py_utils[avro-encoder,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,postgresql-pool-client,redshift-client,s3-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]", + "nypl_py_utils[avro-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,postgresql-pool-client,redshift-client,s3-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]", "flake8>=6.0.0", "freezegun>=1.2.2", "mock>=4.0.3", diff --git a/src/nypl_py_utils/classes/avro_encoder.py b/src/nypl_py_utils/classes/avro_client.py similarity index 81% rename from src/nypl_py_utils/classes/avro_encoder.py rename to src/nypl_py_utils/classes/avro_client.py index 53dd260..f4ba11e 100644 --- a/src/nypl_py_utils/classes/avro_encoder.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -1,13 +1,14 @@ import avro.schema +import base64 import requests -from avro.errors import AvroException +from avro.schema import AvroException from avro.io import BinaryDecoder, BinaryEncoder, DatumReader, DatumWriter from io import BytesIO from nypl_py_utils.functions.log_helper import create_log from requests.exceptions import JSONDecodeError, RequestException -class AvroInterpreter: +class AvroClient: """ Base class for Avro schema interaction. Takes as input the Platform API endpoint from which to fetch the schema in JSON format. 
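The endpoint is expected to return the Avro schema as a JSON-encoded string under data.schema, the same shape as the _TEST_SCHEMA fixture in the tests, e.g.:

    # response.json() from the schema endpoint (fields elided)
    {'data': {'schema': '{"name": "TestSchema", "type": "record", "fields": [...]}'}}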
@@ -32,7 +33,7 @@ def get_json_schema(self, platform_schema_url): self.logger.error( 'Failed to retrieve schema from {url}: {error}'.format( url=platform_schema_url, error=e)) - raise AvroInterpreterError( + raise AvroClientError( 'Failed to retrieve schema from {url}: {error}'.format( url=platform_schema_url, error=e)) from None @@ -43,12 +44,12 @@ def get_json_schema(self, platform_schema_url): self.logger.error( 'Retrieved schema is malformed: {errorType} {errorMessage}' .format(errorType=type(e), errorMessage=e)) - raise AvroInterpreterError( + raise AvroClientError( 'Retrieved schema is malformed: {errorType} {errorMessage}' .format(errorType=type(e), errorMessage=e)) from None -class AvroEncoder(AvroInterpreter): +class AvroEncoder(AvroClient): """ Class for encoding records using an Avro schema. Takes as input the Platform API endpoint from which to fetch the schema in JSON format. @@ -71,7 +72,7 @@ def encode_record(self, record): return output_stream.getvalue() except AvroException as e: self.logger.error('Failed to encode record: {}'.format(e)) - raise AvroInterpreterError( + raise AvroClientError( 'Failed to encode record: {}'.format(e)) from None def encode_batch(self, record_list): @@ -95,26 +96,40 @@ def encode_batch(self, record_list): output_stream.truncate(0) except AvroException as e: self.logger.error('Failed to encode record: {}'.format(e)) - raise AvroInterpreterError( + raise AvroClientError( 'Failed to encode record: {}'.format(e)) from None return encoded_records -class AvroDecoder(AvroInterpreter): +class AvroDecoder(AvroClient): """ Class for decoding records using an Avro schema. Takes as input the Platform API endpoint from which to fetch the schema in JSON format. """ - def decode_record(self, record): + def decode_record(self, record, encoding="binary"): """ - Decodes a single record represented as a byte string using the given - Avro schema. + Decodes a single record represented either as a byte or + base64 string, using the given Avro schema. Returns a dictionary where each key is a field in the schema. 
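Both call forms, using the encoded values from the tests:

    decoder = AvroDecoder('https://test_schema_url')
    decoder.decode_record(b'\xf6\x01\x02\x04aa')  # Avro binary (default)
    decoder.decode_record(
        'eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==',
        encoding='base64')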
""" self.logger.debug('Decoding {rec} using {schema} schema'.format( rec=record, schema=self.schema.name)) + + if encoding == "base64": + return self._decode_base64(record) + elif encoding == "binary": + return self._decode_binary(record) + else: + self.logger.error('Failed to decode record due to encoding type: {}'.format(encoding)) + raise AvroClientError( + 'Invalid encoding type: {}'.format(encoding)) + + def _decode_base64(self, record): + return base64.b64decode(record).decode('utf-8') + + def _decode_binary(self, record): datum_reader = DatumReader(self.schema) with BytesIO(record) as input_stream: decoder = BinaryDecoder(input_stream) @@ -122,10 +137,10 @@ def decode_record(self, record): return datum_reader.read(decoder) except Exception as e: self.logger.error('Failed to decode record: {}'.format(e)) - raise AvroInterpreterError( + raise AvroClientError( 'Failed to decode record: {}'.format(e)) from None -class AvroInterpreterError(Exception): +class AvroClientError(Exception): def __init__(self, message=None): self.message = message \ No newline at end of file diff --git a/tests/test_avro_encoder.py b/tests/test_avro_encoder.py deleted file mode 100644 index a259656..0000000 --- a/tests/test_avro_encoder.py +++ /dev/null @@ -1,90 +0,0 @@ -import json -import pytest - -from nypl_py_utils.classes.avro_encoder import AvroEncoder, AvroInterpreterError -from requests.exceptions import ConnectTimeout - -_TEST_SCHEMA = {'data': {'schema': json.dumps({ - 'name': 'TestSchema', - 'type': 'record', - 'fields': [ - { - 'name': 'patron_id', - 'type': 'int' - }, - { - 'name': 'library_branch', - 'type': ['null', 'string'] - } - ] -})}} - - -class TestAvroEncoder: - - @pytest.fixture - def test_instance(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) - return AvroEncoder('https://test_schema_url') - - def test_get_json_schema(self, test_instance): - assert test_instance.schema == _TEST_SCHEMA['data']['schema'] - - def test_request_error(self, requests_mock): - requests_mock.get('https://test_schema_url', exc=ConnectTimeout) - with pytest.raises(AvroInterpreterError): - AvroEncoder('https://test_schema_url') - - def test_bad_json_error(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text='bad json') - with pytest.raises(AvroInterpreterError): - AvroEncoder('https://test_schema_url') - - def test_missing_key_error(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text=json.dumps({'field': 'value'})) - with pytest.raises(AvroInterpreterError): - AvroEncoder('https://test_schema_url') - - def test_encode_record(self, test_instance): - TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'} - encoded_record = test_instance.encode_record(TEST_RECORD) - assert type(encoded_record) is bytes - assert test_instance.decode_record(encoded_record) == TEST_RECORD - - def test_encode_record_error(self, test_instance): - TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'} - with pytest.raises(AvroInterpreterError): - test_instance.encode_record(TEST_RECORD) - - def test_encode_batch(self, test_instance): - TEST_BATCH = [ - {'patron_id': 123, 'library_branch': 'aa'}, - {'patron_id': 456, 'library_branch': None}, - {'patron_id': 789, 'library_branch': 'bb'}] - encoded_records = test_instance.encode_batch(TEST_BATCH) - assert len(encoded_records) == len(TEST_BATCH) - for i in range(3): - assert type(encoded_records[i]) is bytes - assert test_instance.decode_record( - encoded_records[i]) == TEST_BATCH[i] - - def 
test_encode_batch_error(self, test_instance): - BAD_BATCH = [ - {'patron_id': 123, 'library_branch': 'aa'}, - {'patron_id': 456, 'bad_field': 'bad'}] - with pytest.raises(AvroInterpreterError): - test_instance.encode_batch(BAD_BATCH) - - def test_decode_record(self, test_instance): - TEST_DECODED_RECORD = {'patron_id': 123, 'library_branch': 'aa'} - TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' - assert test_instance.decode_record( - TEST_ENCODED_RECORD) == TEST_DECODED_RECORD - - def test_decode_record_error(self, test_instance): - TEST_ENCODED_RECORD = b'bad-encoding' - with pytest.raises(AvroInterpreterError): - test_instance.decode_record(TEST_ENCODED_RECORD) diff --git a/tests/test_avro_interpreter.py b/tests/test_avro_interpreter.py new file mode 100644 index 0000000..d5a8b7a --- /dev/null +++ b/tests/test_avro_interpreter.py @@ -0,0 +1,328 @@ +import json +import pytest + +from src.nypl_py_utils.classes.avro_client import AvroDecoder, AvroEncoder, AvroClientError +from requests.exceptions import ConnectTimeout + +_TEST_SCHEMA = {'data': {'schema': json.dumps({ + 'name': 'TestSchema', + 'type': 'record', + 'fields': [ + { + 'name': 'patron_id', + 'type': 'int' + }, + { + 'name': 'library_branch', + 'type': ['null', 'string'] + } + ] +})}} + +_CIRC_TRANS_SCHEMA = {'data': {'schema': json.dumps({ + "name": "CircTransRecord", + "type": "record", + "fields": [ + { + "name": "id", + "type": "int", + "doc": "System-generated sequential ID." + }, + { + "name": "patron_id", + "type": [ + "null", + "int" + ], + "doc": "De-identified Patron ID for record" + }, + { + "name": "item_id", + "type": [ + "null", + "int" + ], + "doc": "Item ID for record" + }, + { + "name": "volume_id", + "type": [ + "null", + "int" + ], + "doc": "Volume ID for record" + }, + { + "name": "bib_id", + "type": [ + "null", + "int" + ], + "doc": "Bib ID for record" + }, + { + "name": "transaction_gmt", + "type": [ + "null", + "string" + ], + "doc": "Transaction date in UNIX format." + }, + { + "name": "application_name", + "type": "string", + "doc": "The name of the program that generated the transaction. Valid program names are: circ (includes transactions made using PC Circ) circa (for transactions written by selfcheckwebserver and in-house use [transaction codes 'u' and 's'], which use webpac to execute transactions.) milcirc milmyselfcheck readreq selfcheck" + }, + { + "name": "source_code", + "type": "string", + "doc": "The transaction source. Possible values are: local INN-Reach ILL" + }, + { + "name": "op_code", + "type": [ + "null", + "string" + ], + "doc": "Type of transaction: o = checkout i = checkin n = hold nb = bib hold ni = item hold nv = volume hold h = hold with recall hb = hold recall bib hi = hold recall item hv = hold recall volume f = filled hold r = renewal b = booking u = use count" + }, + { + "name": "stat_group_code_num", + "type": [ + "null", + "int" + ], + "doc": "The number of the terminal at which the transaction occurred or the user-specified statistics group number for PC-Circ transactions. Also stores the login's statistics group number for circulation transactions performed with the following Circa applications: checkout checkin count internal use" + }, + { + "name": "due_date_gmt", + "type": [ + "null", + "string" + ], + "doc": "Due date in UNIX format. The application of this date depends on the op_code for the transaction. The due date is not included for bookings (op_code b) or filled holds (op_code f). For op_code 'i' (checkin), this is the original due date. 
For op_code 'r' (renewal), this is the renewal due date. For op_code 'o' (checkouts), this is the item due date. For op_codes 'n' (holds) and 'h' (holds with recall), a non-zero entry indicates that the hold is for a checked-out item that is due on the specified date." + }, + { + "name": "count_type_code_num", + "type": [ + "null", + "int" + ], + "doc": "Indicates the type of use count (for op_code 'u'): Code Number Count Type 1 INTL USE (fixflds 93) 2 COPY USE (fixflds 94) 3 IUSE3 (fixflds 74) 4 PIUSE: generated by the system" + }, + { + "name": "itype_code_num", + "type": [ + "null", + "int" + ], + "doc": "Item type code. (Defined by the library.)" + }, + { + "name": "icode1", + "type": [ + "null", + "int" + ], + "doc": "Item code 1. (Defined by the library.)" + }, + { + "name": "icode2", + "type": [ + "null", + "string" + ], + "doc": "Item code 2. (Defined by the library.)" + }, + { + "name": "item_location_code", + "type": [ + "null", + "string" + ], + "doc": "A five-character location code, right-padded with spaces, from the associated item record." + }, + { + "name": "item_agency_code_num", + "type": [ + "null", + "int" + ], + "doc": "A one-character AGENCY code from the associated item record." + }, + { + "name": "ptype_code", + "type": [ + "null", + "string" + ], + "doc": "Patron type code. (Defined by the library.)" + }, + { + "name": "pcode1", + "type": [ + "null", + "string" + ], + "doc": "Patron code 1. (Defined by the library.)" + }, + { + "name": "pcode2", + "type": [ + "null", + "string" + ], + "doc": "Patron code 2. (Defined by the library.)" + }, + { + "name": "pcode3", + "type": [ + "null", + "int" + ], + "doc": "Patron code 3. (Defined by the library.)" + }, + { + "name": "pcode4", + "type": [ + "null", + "int" + ], + "doc": "Patron code 4. (Defined by the library.)" + }, + { + "name": "patron_home_library_code", + "type": [ + "null", + "string" + ], + "doc": "A five-character location code, right-padded with spaces, from the associated patron record." + }, + { + "name": "patron_agency_code_num", + "type": [ + "null", + "int" + ], + "doc": "A one-character AGENCY code from the associated patron record." 
+ }, + { + "name": "loanrule_code_num", + "type": [ + "null", + "int" + ] + } + ] +})}} + + +_TEST_RECORDS = { + "records": [ + { + "recordId": "789", + "data": "oAICoAICoAICoAICoAICLDIwMTctMTEtMTQgMTE6NDM6NDktMDUMc2llcnJhCmxvY2FsAgJvAhICLDIwMTctMTItMDUgMDQ6MDA6MDAtMDUCAAKUAgIAAgItAgpld2EwbgIAAgQxMAICLQICLQICAgACCmV3ICAgAgACCA==", + "approximateArrivalTimestamp": "2012-04-23T18:25:43.511Z" + }, + { + "recordId": "123", + "data": "oAICoAICoAICoAICoAICLDIwMTctMTEtMTQgMTE6NDM6NTAtMDUMc2llcnJhCmxvY2FsAgJvAlgCLDIwMTctMTItMDUgMDQ6MDA6MDAtMDUCAAKSAwIAAgItAgpmd2owYQIAAgQ2MAICLQICcgIEAgACCmZ3ICAgAgACCg==", + "approximateArrivalTimestamp": "2012-04-23T18:25:43.511Z" + }, + { + "recordId": "456", + "data": "lgsCSDlhNmZiYmU5LWJkMTAtNDA2Ny05ZmVhLWEwODM4ZGU2YzUyNwIGNzE1Ah4yMzQ1Njc4OTA5ODc2NTQCHDMyMTAxMDk2MTE1MjE1AgZQVUwAAAIyMjAxNy0xMC0wNFQxNjo0MToyNS0wNDowMAA=", + "approximateArrivalTimestamp": "2012-04-23T18:25:43.511Z" + } + ], + "region": "us-east-1", + "deliveryStreamArn": "arn:aws:kinesis:EXAMPLE", + "invocationId": "invocationIdExample" +} + + +class TestAvroEncoder: + + @pytest.fixture + def test_avro_encoder_instance(self, requests_mock): + requests_mock.get( + 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + return AvroEncoder('https://test_schema_url') + + @pytest.fixture + def test_avro_decoder_instance(self, requests_mock): + # requests_mock.get( + # 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + requests_mock.get( + 'https://test_schema_url', text=json.dumps(_CIRC_TRANS_SCHEMA)) + return AvroDecoder('https://test_schema_url') + + def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance): + assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data']['schema'] + assert test_avro_decoder_instance.schema == _CIRC_TRANS_SCHEMA['data']['schema'] + + def test_request_error(self, requests_mock): + requests_mock.get('https://test_schema_url', exc=ConnectTimeout) + with pytest.raises(AvroClientError): + AvroEncoder('https://test_schema_url') + + def test_bad_json_error(self, requests_mock): + requests_mock.get( + 'https://test_schema_url', text='bad json') + with pytest.raises(AvroClientError): + AvroEncoder('https://test_schema_url') + + def test_missing_key_error(self, requests_mock): + requests_mock.get( + 'https://test_schema_url', text=json.dumps({'field': 'value'})) + with pytest.raises(AvroClientError): + AvroEncoder('https://test_schema_url') + + def test_encode_record(self, test_avro_encoder_instance, test_avro_decoder_instance): + TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'} + encoded_record = test_avro_encoder_instance.encode_record(TEST_RECORD) + assert type(encoded_record) is bytes + assert test_avro_decoder_instance.decode_record(encoded_record) == TEST_RECORD + + def test_encode_record_error(self, test_avro_encoder_instance): + TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'} + with pytest.raises(AvroClientError): + test_avro_encoder_instance.encode_record(TEST_RECORD) + + def test_encode_batch(self, test_avro_encoder_instance, test_avro_decoder_instance): + TEST_BATCH = [ + {'patron_id': 123, 'library_branch': 'aa'}, + {'patron_id': 456, 'library_branch': None}, + {'patron_id': 789, 'library_branch': 'bb'}] + encoded_records = test_avro_encoder_instance.encode_batch(TEST_BATCH) + assert len(encoded_records) == len(TEST_BATCH) + for i in range(3): + assert type(encoded_records[i]) is bytes + assert test_avro_decoder_instance.decode_record( + encoded_records[i]) == TEST_BATCH[i] + + def 
test_encode_batch_error(self, test_avro_encoder_instance): + BAD_BATCH = [ + {'patron_id': 123, 'library_branch': 'aa'}, + {'patron_id': 456, 'bad_field': 'bad'}] + with pytest.raises(AvroClientError): + test_avro_encoder_instance.encode_batch(BAD_BATCH) + + def test_decode_record(self, test_avro_decoder_instance): + TEST_DECODED_RECORD = {'patron_id': 123, 'library_branch': 'aa'} + TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' + assert test_avro_decoder_instance.decode_record( + TEST_ENCODED_RECORD) == TEST_DECODED_RECORD + + def test_decode_record_error(self, test_avro_decoder_instance): + TEST_ENCODED_RECORD = b'bad-encoding' + with pytest.raises(AvroClientError): + test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) + + def test_decode_records(self, test_avro_decoder_instance): + test_record = _TEST_RECORDS.get('records') + result = "" + if (type(test_record) is list): + data = json.dumps(test_record[0]["data"]) + result = test_avro_decoder_instance.decode_record(record=data, encoding="base64") + print(result) From f284f9e9220604469aca2ba16a678e1f74f49ead Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Mon, 1 Jul 2024 18:56:59 -0500 Subject: [PATCH 03/14] It's working --- .gitignore | 1 + pyproject.toml | 8 + src/nypl_py_utils/classes/avro_client.py | 18 +- tests/test_avro_client.py | 102 +++++++ tests/test_avro_interpreter.py | 328 ----------------------- 5 files changed, 125 insertions(+), 332 deletions(-) create mode 100644 tests/test_avro_client.py delete mode 100644 tests/test_avro_interpreter.py diff --git a/.gitignore b/.gitignore index 28aac37..bd867e1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__/ *env/ *.py[cod] *$py.class +*.pytest_cache \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 731fc5f..534e902 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,3 +75,11 @@ development = [ "pytest-mock>=3.10.0", "requests-mock>=1.10.0" ] + +[tool.pytest.ini_options] +minversion = "8.0" +addopts = "-ra -q" +pythonpath = "src" +testpaths = [ + "tests" +] \ No newline at end of file diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py index f4ba11e..1bdd72d 100644 --- a/src/nypl_py_utils/classes/avro_client.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -1,5 +1,6 @@ import avro.schema import base64 +import json import requests from avro.schema import AvroException @@ -114,8 +115,8 @@ def decode_record(self, record, encoding="binary"): Returns a dictionary where each key is a field in the schema. 
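Any other encoding value raises AvroClientError ('hex' below is just an arbitrary unsupported value):

    try:
        decoder.decode_record(b'\xf6\x01\x02\x04aa', encoding='hex')
    except AvroClientError as e:
        print(e.message)  # Invalid encoding type: hex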
""" - self.logger.debug('Decoding {rec} using {schema} schema'.format( - rec=record, schema=self.schema.name)) + self.logger.info('Decoding {rec} of type {type} using {schema} schema'.format( + rec=record, type=encoding, schema=self.schema.name)) if encoding == "base64": return self._decode_base64(record) @@ -125,9 +126,18 @@ def decode_record(self, record, encoding="binary"): self.logger.error('Failed to decode record due to encoding type: {}'.format(encoding)) raise AvroClientError( 'Invalid encoding type: {}'.format(encoding)) - + def _decode_base64(self, record): - return base64.b64decode(record).decode('utf-8') + decoded_data = base64.b64decode(record).decode("utf-8") + try: + return json.loads(decoded_data) + except Exception as e: + if isinstance(decoded_data, bytes): + self._decode_binary(decoded_data) + else: + self.logger.error('Failed to decode record: {}'.format(e)) + raise AvroClientError( + 'Failed to decode record: {}'.format(e)) from None def _decode_binary(self, record): datum_reader = DatumReader(self.schema) diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py new file mode 100644 index 0000000..7e26981 --- /dev/null +++ b/tests/test_avro_client.py @@ -0,0 +1,102 @@ +import json +import pytest + +from nypl_py_utils.classes.avro_client import AvroDecoder, AvroEncoder, AvroClientError +from requests.exceptions import ConnectTimeout + +_TEST_SCHEMA = {'data': {'schema': json.dumps({ + 'name': 'TestSchema', + 'type': 'record', + 'fields': [ + { + 'name': 'patron_id', + 'type': 'int' + }, + { + 'name': 'library_branch', + 'type': ['null', 'string'] + } + ] +})}} + +class TestAvroClient: + + @pytest.fixture + def test_avro_encoder_instance(self, requests_mock): + requests_mock.get( + 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + return AvroEncoder('https://test_schema_url') + + @pytest.fixture + def test_avro_decoder_instance(self, requests_mock): + requests_mock.get( + 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + return AvroDecoder('https://test_schema_url') + + def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance): + assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data']['schema'] + assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data']['schema'] + + def test_request_error(self, requests_mock): + requests_mock.get('https://test_schema_url', exc=ConnectTimeout) + with pytest.raises(AvroClientError): + AvroEncoder('https://test_schema_url') + + def test_bad_json_error(self, requests_mock): + requests_mock.get( + 'https://test_schema_url', text='bad json') + with pytest.raises(AvroClientError): + AvroEncoder('https://test_schema_url') + + def test_missing_key_error(self, requests_mock): + requests_mock.get( + 'https://test_schema_url', text=json.dumps({'field': 'value'})) + with pytest.raises(AvroClientError): + AvroEncoder('https://test_schema_url') + + def test_encode_record(self, test_avro_encoder_instance, test_avro_decoder_instance): + TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'} + encoded_record = test_avro_encoder_instance.encode_record(TEST_RECORD) + assert type(encoded_record) is bytes + assert test_avro_decoder_instance.decode_record(encoded_record) == TEST_RECORD + + def test_encode_record_error(self, test_avro_encoder_instance): + TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'} + with pytest.raises(AvroClientError): + test_avro_encoder_instance.encode_record(TEST_RECORD) + + def test_encode_batch(self, test_avro_encoder_instance, 
test_avro_decoder_instance): + TEST_BATCH = [ + {'patron_id': 123, 'library_branch': 'aa'}, + {'patron_id': 456, 'library_branch': None}, + {'patron_id': 789, 'library_branch': 'bb'}] + encoded_records = test_avro_encoder_instance.encode_batch(TEST_BATCH) + assert len(encoded_records) == len(TEST_BATCH) + for i in range(3): + assert type(encoded_records[i]) is bytes + assert test_avro_decoder_instance.decode_record( + encoded_records[i]) == TEST_BATCH[i] + + def test_encode_batch_error(self, test_avro_encoder_instance): + BAD_BATCH = [ + {'patron_id': 123, 'library_branch': 'aa'}, + {'patron_id': 456, 'bad_field': 'bad'}] + with pytest.raises(AvroClientError): + test_avro_encoder_instance.encode_batch(BAD_BATCH) + + def test_decode_record_binary(self, test_avro_decoder_instance): + TEST_DECODED_RECORD = {"patron_id": 123, "library_branch": "aa"} + TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' + assert test_avro_decoder_instance.decode_record( + TEST_ENCODED_RECORD) == TEST_DECODED_RECORD + + def test_decode_record_b64(self, test_avro_decoder_instance): + TEST_DECODED_RECORD = {"patron_id'": 123, "library_branch": "aa"} + TEST_ENCODED_RECORD = "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==" + assert test_avro_decoder_instance.decode_record( + TEST_ENCODED_RECORD, "base64") == TEST_DECODED_RECORD + + def test_decode_record_error(self, test_avro_decoder_instance): + TEST_ENCODED_RECORD = b'bad-encoding' + with pytest.raises(AvroClientError): + test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) diff --git a/tests/test_avro_interpreter.py b/tests/test_avro_interpreter.py deleted file mode 100644 index d5a8b7a..0000000 --- a/tests/test_avro_interpreter.py +++ /dev/null @@ -1,328 +0,0 @@ -import json -import pytest - -from src.nypl_py_utils.classes.avro_client import AvroDecoder, AvroEncoder, AvroClientError -from requests.exceptions import ConnectTimeout - -_TEST_SCHEMA = {'data': {'schema': json.dumps({ - 'name': 'TestSchema', - 'type': 'record', - 'fields': [ - { - 'name': 'patron_id', - 'type': 'int' - }, - { - 'name': 'library_branch', - 'type': ['null', 'string'] - } - ] -})}} - -_CIRC_TRANS_SCHEMA = {'data': {'schema': json.dumps({ - "name": "CircTransRecord", - "type": "record", - "fields": [ - { - "name": "id", - "type": "int", - "doc": "System-generated sequential ID." - }, - { - "name": "patron_id", - "type": [ - "null", - "int" - ], - "doc": "De-identified Patron ID for record" - }, - { - "name": "item_id", - "type": [ - "null", - "int" - ], - "doc": "Item ID for record" - }, - { - "name": "volume_id", - "type": [ - "null", - "int" - ], - "doc": "Volume ID for record" - }, - { - "name": "bib_id", - "type": [ - "null", - "int" - ], - "doc": "Bib ID for record" - }, - { - "name": "transaction_gmt", - "type": [ - "null", - "string" - ], - "doc": "Transaction date in UNIX format." - }, - { - "name": "application_name", - "type": "string", - "doc": "The name of the program that generated the transaction. Valid program names are: circ (includes transactions made using PC Circ) circa (for transactions written by selfcheckwebserver and in-house use [transaction codes 'u' and 's'], which use webpac to execute transactions.) milcirc milmyselfcheck readreq selfcheck" - }, - { - "name": "source_code", - "type": "string", - "doc": "The transaction source. 
Possible values are: local INN-Reach ILL" - }, - { - "name": "op_code", - "type": [ - "null", - "string" - ], - "doc": "Type of transaction: o = checkout i = checkin n = hold nb = bib hold ni = item hold nv = volume hold h = hold with recall hb = hold recall bib hi = hold recall item hv = hold recall volume f = filled hold r = renewal b = booking u = use count" - }, - { - "name": "stat_group_code_num", - "type": [ - "null", - "int" - ], - "doc": "The number of the terminal at which the transaction occurred or the user-specified statistics group number for PC-Circ transactions. Also stores the login's statistics group number for circulation transactions performed with the following Circa applications: checkout checkin count internal use" - }, - { - "name": "due_date_gmt", - "type": [ - "null", - "string" - ], - "doc": "Due date in UNIX format. The application of this date depends on the op_code for the transaction. The due date is not included for bookings (op_code b) or filled holds (op_code f). For op_code 'i' (checkin), this is the original due date. For op_code 'r' (renewal), this is the renewal due date. For op_code 'o' (checkouts), this is the item due date. For op_codes 'n' (holds) and 'h' (holds with recall), a non-zero entry indicates that the hold is for a checked-out item that is due on the specified date." - }, - { - "name": "count_type_code_num", - "type": [ - "null", - "int" - ], - "doc": "Indicates the type of use count (for op_code 'u'): Code Number Count Type 1 INTL USE (fixflds 93) 2 COPY USE (fixflds 94) 3 IUSE3 (fixflds 74) 4 PIUSE: generated by the system" - }, - { - "name": "itype_code_num", - "type": [ - "null", - "int" - ], - "doc": "Item type code. (Defined by the library.)" - }, - { - "name": "icode1", - "type": [ - "null", - "int" - ], - "doc": "Item code 1. (Defined by the library.)" - }, - { - "name": "icode2", - "type": [ - "null", - "string" - ], - "doc": "Item code 2. (Defined by the library.)" - }, - { - "name": "item_location_code", - "type": [ - "null", - "string" - ], - "doc": "A five-character location code, right-padded with spaces, from the associated item record." - }, - { - "name": "item_agency_code_num", - "type": [ - "null", - "int" - ], - "doc": "A one-character AGENCY code from the associated item record." - }, - { - "name": "ptype_code", - "type": [ - "null", - "string" - ], - "doc": "Patron type code. (Defined by the library.)" - }, - { - "name": "pcode1", - "type": [ - "null", - "string" - ], - "doc": "Patron code 1. (Defined by the library.)" - }, - { - "name": "pcode2", - "type": [ - "null", - "string" - ], - "doc": "Patron code 2. (Defined by the library.)" - }, - { - "name": "pcode3", - "type": [ - "null", - "int" - ], - "doc": "Patron code 3. (Defined by the library.)" - }, - { - "name": "pcode4", - "type": [ - "null", - "int" - ], - "doc": "Patron code 4. (Defined by the library.)" - }, - { - "name": "patron_home_library_code", - "type": [ - "null", - "string" - ], - "doc": "A five-character location code, right-padded with spaces, from the associated patron record." - }, - { - "name": "patron_agency_code_num", - "type": [ - "null", - "int" - ], - "doc": "A one-character AGENCY code from the associated patron record." 
- }, - { - "name": "loanrule_code_num", - "type": [ - "null", - "int" - ] - } - ] -})}} - - -_TEST_RECORDS = { - "records": [ - { - "recordId": "789", - "data": "oAICoAICoAICoAICoAICLDIwMTctMTEtMTQgMTE6NDM6NDktMDUMc2llcnJhCmxvY2FsAgJvAhICLDIwMTctMTItMDUgMDQ6MDA6MDAtMDUCAAKUAgIAAgItAgpld2EwbgIAAgQxMAICLQICLQICAgACCmV3ICAgAgACCA==", - "approximateArrivalTimestamp": "2012-04-23T18:25:43.511Z" - }, - { - "recordId": "123", - "data": "oAICoAICoAICoAICoAICLDIwMTctMTEtMTQgMTE6NDM6NTAtMDUMc2llcnJhCmxvY2FsAgJvAlgCLDIwMTctMTItMDUgMDQ6MDA6MDAtMDUCAAKSAwIAAgItAgpmd2owYQIAAgQ2MAICLQICcgIEAgACCmZ3ICAgAgACCg==", - "approximateArrivalTimestamp": "2012-04-23T18:25:43.511Z" - }, - { - "recordId": "456", - "data": "lgsCSDlhNmZiYmU5LWJkMTAtNDA2Ny05ZmVhLWEwODM4ZGU2YzUyNwIGNzE1Ah4yMzQ1Njc4OTA5ODc2NTQCHDMyMTAxMDk2MTE1MjE1AgZQVUwAAAIyMjAxNy0xMC0wNFQxNjo0MToyNS0wNDowMAA=", - "approximateArrivalTimestamp": "2012-04-23T18:25:43.511Z" - } - ], - "region": "us-east-1", - "deliveryStreamArn": "arn:aws:kinesis:EXAMPLE", - "invocationId": "invocationIdExample" -} - - -class TestAvroEncoder: - - @pytest.fixture - def test_avro_encoder_instance(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) - return AvroEncoder('https://test_schema_url') - - @pytest.fixture - def test_avro_decoder_instance(self, requests_mock): - # requests_mock.get( - # 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) - requests_mock.get( - 'https://test_schema_url', text=json.dumps(_CIRC_TRANS_SCHEMA)) - return AvroDecoder('https://test_schema_url') - - def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance): - assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data']['schema'] - assert test_avro_decoder_instance.schema == _CIRC_TRANS_SCHEMA['data']['schema'] - - def test_request_error(self, requests_mock): - requests_mock.get('https://test_schema_url', exc=ConnectTimeout) - with pytest.raises(AvroClientError): - AvroEncoder('https://test_schema_url') - - def test_bad_json_error(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text='bad json') - with pytest.raises(AvroClientError): - AvroEncoder('https://test_schema_url') - - def test_missing_key_error(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text=json.dumps({'field': 'value'})) - with pytest.raises(AvroClientError): - AvroEncoder('https://test_schema_url') - - def test_encode_record(self, test_avro_encoder_instance, test_avro_decoder_instance): - TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'} - encoded_record = test_avro_encoder_instance.encode_record(TEST_RECORD) - assert type(encoded_record) is bytes - assert test_avro_decoder_instance.decode_record(encoded_record) == TEST_RECORD - - def test_encode_record_error(self, test_avro_encoder_instance): - TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'} - with pytest.raises(AvroClientError): - test_avro_encoder_instance.encode_record(TEST_RECORD) - - def test_encode_batch(self, test_avro_encoder_instance, test_avro_decoder_instance): - TEST_BATCH = [ - {'patron_id': 123, 'library_branch': 'aa'}, - {'patron_id': 456, 'library_branch': None}, - {'patron_id': 789, 'library_branch': 'bb'}] - encoded_records = test_avro_encoder_instance.encode_batch(TEST_BATCH) - assert len(encoded_records) == len(TEST_BATCH) - for i in range(3): - assert type(encoded_records[i]) is bytes - assert test_avro_decoder_instance.decode_record( - encoded_records[i]) == TEST_BATCH[i] - - def 
test_encode_batch_error(self, test_avro_encoder_instance): - BAD_BATCH = [ - {'patron_id': 123, 'library_branch': 'aa'}, - {'patron_id': 456, 'bad_field': 'bad'}] - with pytest.raises(AvroClientError): - test_avro_encoder_instance.encode_batch(BAD_BATCH) - - def test_decode_record(self, test_avro_decoder_instance): - TEST_DECODED_RECORD = {'patron_id': 123, 'library_branch': 'aa'} - TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' - assert test_avro_decoder_instance.decode_record( - TEST_ENCODED_RECORD) == TEST_DECODED_RECORD - - def test_decode_record_error(self, test_avro_decoder_instance): - TEST_ENCODED_RECORD = b'bad-encoding' - with pytest.raises(AvroClientError): - test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) - - def test_decode_records(self, test_avro_decoder_instance): - test_record = _TEST_RECORDS.get('records') - result = "" - if (type(test_record) is list): - data = json.dumps(test_record[0]["data"]) - result = test_avro_decoder_instance.decode_record(record=data, encoding="base64") - print(result) From f94b52ecb9368102dd69048afbf20118b1519f4a Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Mon, 1 Jul 2024 23:04:35 -0500 Subject: [PATCH 04/14] update imports --- src/nypl_py_utils/classes/avro_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py index 1bdd72d..ae2c988 100644 --- a/src/nypl_py_utils/classes/avro_client.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -3,7 +3,7 @@ import json import requests -from avro.schema import AvroException +from avro.errors import AvroException from avro.io import BinaryDecoder, BinaryEncoder, DatumReader, DatumWriter from io import BytesIO from nypl_py_utils.functions.log_helper import create_log From 69f12e5faff41be6c8e6e070a37e5682c66bcad4 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Tue, 2 Jul 2024 09:18:57 -0500 Subject: [PATCH 05/14] Reformatted with black --- CHANGELOG.md | 2 +- src/nypl_py_utils/classes/avro_client.py | 97 +++++---- src/nypl_py_utils/classes/kinesis_client.py | 61 +++--- src/nypl_py_utils/classes/kms_client.py | 26 ++- src/nypl_py_utils/classes/mysql_client.py | 38 ++-- .../classes/oauth2_api_client.py | 84 ++++---- .../classes/postgresql_client.py | 44 ++-- .../classes/postgresql_pool_client.py | 86 ++++---- src/nypl_py_utils/classes/redshift_client.py | 64 +++--- src/nypl_py_utils/classes/s3_client.py | 52 +++-- src/nypl_py_utils/functions/config_helper.py | 20 +- src/nypl_py_utils/functions/log_helper.py | 15 +- .../functions/obfuscation_helper.py | 14 +- .../research_catalog_identifier_helper.py | 65 +++--- tests/test_avro_client.py | 112 +++++----- tests/test_config_helper.py | 70 +++--- tests/test_kinesis_client.py | 135 +++++++----- tests/test_kms_client.py | 20 +- tests/test_log_helper.py | 54 +++-- tests/test_mysql_client.py | 51 ++--- tests/test_oauth2_api_client.py | 199 ++++++++++-------- tests/test_obfuscation_helper.py | 19 +- tests/test_postgresql_client.py | 45 ++-- tests/test_postgresql_pool_client.py | 86 ++++---- tests/test_redshift_client.py | 115 +++++----- ...test_research_catalog_identifier_helper.py | 102 +++++---- tests/test_s3_client.py | 14 +- 27 files changed, 942 insertions(+), 748 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f7b9675..c776943 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog ## v1.1.6 6/26/24 -- Generalized Avro functions and separated encoding/decoding behavior. 
+- Generalized Avro functions and separated encoding/decoding behavior ## v1.1.5 6/6/24 - Use executemany instead of execute when appropriate in RedshiftClient.execute_transaction diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py index ae2c988..6aab3b4 100644 --- a/src/nypl_py_utils/classes/avro_client.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -9,46 +9,53 @@ from nypl_py_utils.functions.log_helper import create_log from requests.exceptions import JSONDecodeError, RequestException + class AvroClient: """ Base class for Avro schema interaction. Takes as input the Platform API endpoint from which to fetch the schema in JSON format. """ - + def __init__(self, platform_schema_url): - self.logger = create_log('avro_encoder') - self.schema = avro.schema.parse( - self.get_json_schema(platform_schema_url)) - + self.logger = create_log("avro_encoder") + self.schema = avro.schema.parse(self.get_json_schema(platform_schema_url)) + def get_json_schema(self, platform_schema_url): """ Fetches a JSON response from the input Platform API endpoint and interprets it as an Avro schema. """ - self.logger.info('Fetching Avro schema from {}'.format( - platform_schema_url)) + self.logger.info("Fetching Avro schema from {}".format(platform_schema_url)) try: response = requests.get(platform_schema_url) response.raise_for_status() except RequestException as e: self.logger.error( - 'Failed to retrieve schema from {url}: {error}'.format( - url=platform_schema_url, error=e)) + "Failed to retrieve schema from {url}: {error}".format( + url=platform_schema_url, error=e + ) + ) raise AvroClientError( - 'Failed to retrieve schema from {url}: {error}'.format( - url=platform_schema_url, error=e)) from None + "Failed to retrieve schema from {url}: {error}".format( + url=platform_schema_url, error=e + ) + ) from None try: json_response = response.json() - return json_response['data']['schema'] + return json_response["data"]["schema"] except (JSONDecodeError, KeyError) as e: self.logger.error( - 'Retrieved schema is malformed: {errorType} {errorMessage}' - .format(errorType=type(e), errorMessage=e)) + "Retrieved schema is malformed: {errorType} {errorMessage}".format( + errorType=type(e), errorMessage=e + ) + ) raise AvroClientError( - 'Retrieved schema is malformed: {errorType} {errorMessage}' - .format(errorType=type(e), errorMessage=e)) from None - + "Retrieved schema is malformed: {errorType} {errorMessage}".format( + errorType=type(e), errorMessage=e + ) + ) from None + class AvroEncoder(AvroClient): """ @@ -63,8 +70,8 @@ def encode_record(self, record): Returns the encoded record as a byte string. """ self.logger.debug( - 'Encoding record using {schema} schema'.format( - schema=self.schema.name)) + "Encoding record using {schema} schema".format(schema=self.schema.name) + ) datum_writer = DatumWriter(self.schema) with BytesIO() as output_stream: encoder = BinaryEncoder(output_stream) @@ -72,9 +79,8 @@ def encode_record(self, record): datum_writer.write(record, encoder) return output_stream.getvalue() except AvroException as e: - self.logger.error('Failed to encode record: {}'.format(e)) - raise AvroClientError( - 'Failed to encode record: {}'.format(e)) from None + self.logger.error("Failed to encode record: {}".format(e)) + raise AvroClientError("Failed to encode record: {}".format(e)) from None def encode_batch(self, record_list): """ @@ -83,8 +89,10 @@ def encode_batch(self, record_list): Returns a list of byte strings where each string is an encoded record. 
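A sketch using the batch from the tests:

    encoder = AvroEncoder("https://test_schema_url")
    encoded = encoder.encode_batch(
        [
            {"patron_id": 123, "library_branch": "aa"},
            {"patron_id": 456, "library_branch": None},
            {"patron_id": 789, "library_branch": "bb"},
        ]
    )
    # len(encoded) == 3; each element is a bytes object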
""" self.logger.info( - 'Encoding ({num_rec}) records using {schema} schema'.format( - num_rec=len(record_list), schema=self.schema.name)) + "Encoding ({num_rec}) records using {schema} schema".format( + num_rec=len(record_list), schema=self.schema.name + ) + ) encoded_records = [] datum_writer = DatumWriter(self.schema) with BytesIO() as output_stream: @@ -96,9 +104,10 @@ def encode_batch(self, record_list): output_stream.seek(0) output_stream.truncate(0) except AvroException as e: - self.logger.error('Failed to encode record: {}'.format(e)) + self.logger.error("Failed to encode record: {}".format(e)) raise AvroClientError( - 'Failed to encode record: {}'.format(e)) from None + "Failed to encode record: {}".format(e) + ) from None return encoded_records @@ -110,34 +119,37 @@ class AvroDecoder(AvroClient): def decode_record(self, record, encoding="binary"): """ - Decodes a single record represented either as a byte or + Decodes a single record represented either as a byte or base64 string, using the given Avro schema. Returns a dictionary where each key is a field in the schema. """ - self.logger.info('Decoding {rec} of type {type} using {schema} schema'.format( - rec=record, type=encoding, schema=self.schema.name)) - + self.logger.info( + "Decoding {rec} of type {type} using {schema} schema".format( + rec=record, type=encoding, schema=self.schema.name + ) + ) + if encoding == "base64": return self._decode_base64(record) elif encoding == "binary": return self._decode_binary(record) else: - self.logger.error('Failed to decode record due to encoding type: {}'.format(encoding)) - raise AvroClientError( - 'Invalid encoding type: {}'.format(encoding)) - + self.logger.error( + "Failed to decode record due to encoding type: {}".format(encoding) + ) + raise AvroClientError("Invalid encoding type: {}".format(encoding)) + def _decode_base64(self, record): - decoded_data = base64.b64decode(record).decode("utf-8") + decoded_data = base64.b64decode(record) try: return json.loads(decoded_data) except Exception as e: if isinstance(decoded_data, bytes): - self._decode_binary(decoded_data) + return self._decode_binary(decoded_data) else: - self.logger.error('Failed to decode record: {}'.format(e)) - raise AvroClientError( - 'Failed to decode record: {}'.format(e)) from None + self.logger.error("Failed to decode record: {}".format(e)) + raise AvroClientError("Failed to decode record: {}".format(e)) from None def _decode_binary(self, record): datum_reader = DatumReader(self.schema) @@ -146,11 +158,10 @@ def _decode_binary(self, record): try: return datum_reader.read(decoder) except Exception as e: - self.logger.error('Failed to decode record: {}'.format(e)) - raise AvroClientError( - 'Failed to decode record: {}'.format(e)) from None + self.logger.error("Failed to decode record: {}".format(e)) + raise AvroClientError("Failed to decode record: {}".format(e)) from None class AvroClientError(Exception): def __init__(self, message=None): - self.message = message \ No newline at end of file + self.message = message diff --git a/src/nypl_py_utils/classes/kinesis_client.py b/src/nypl_py_utils/classes/kinesis_client.py index 1c25b2b..ce28b53 100644 --- a/src/nypl_py_utils/classes/kinesis_client.py +++ b/src/nypl_py_utils/classes/kinesis_client.py @@ -17,20 +17,19 @@ class KinesisClient: """ def __init__(self, stream_arn, batch_size, max_retries=5): - self.logger = create_log('kinesis_client') + self.logger = create_log("kinesis_client") self.stream_arn = stream_arn self.batch_size = batch_size self.max_retries = 
max_retries try: self.kinesis_client = boto3.client( - 'kinesis', region_name=os.environ.get('AWS_REGION', - 'us-east-1')) + "kinesis", region_name=os.environ.get("AWS_REGION", "us-east-1") + ) except ClientError as e: - self.logger.error( - 'Could not create Kinesis client: {err}'.format(err=e)) + self.logger.error("Could not create Kinesis client: {err}".format(err=e)) raise KinesisClientError( - 'Could not create Kinesis client: {err}'.format(err=e) + "Could not create Kinesis client: {err}".format(err=e) ) from None def close(self): @@ -45,10 +44,11 @@ def send_records(self, records): """ records_sent_since_pause = 0 for i in range(0, len(records), self.batch_size): - encoded_batch = records[i:i + self.batch_size] - kinesis_records = [{'Data': record, 'PartitionKey': - str(int(time.time() * 1000000000))} - for record in encoded_batch] + encoded_batch = records[i : i + self.batch_size] + kinesis_records = [ + {"Data": record, "PartitionKey": str(int(time.time() * 1000000000))} + for record in encoded_batch + ] if records_sent_since_pause + len(encoded_batch) > 1000: records_sent_since_pause = 0 @@ -63,32 +63,41 @@ def _send_kinesis_format_records(self, kinesis_records, call_count): """ if call_count > self.max_retries: self.logger.error( - 'Failed to send records to Kinesis {} times in a row'.format( - call_count-1)) + "Failed to send records to Kinesis {} times in a row".format( + call_count - 1 + ) + ) raise KinesisClientError( - 'Failed to send records to Kinesis {} times in a row'.format( - call_count-1)) from None + "Failed to send records to Kinesis {} times in a row".format( + call_count - 1 + ) + ) from None try: self.logger.info( - 'Sending ({count}) records to {arn} Kinesis stream'.format( - count=len(kinesis_records), arn=self.stream_arn)) + "Sending ({count}) records to {arn} Kinesis stream".format( + count=len(kinesis_records), arn=self.stream_arn + ) + ) response = self.kinesis_client.put_records( - Records=kinesis_records, StreamARN=self.stream_arn) - if response['FailedRecordCount'] > 0: + Records=kinesis_records, StreamARN=self.stream_arn + ) + if response["FailedRecordCount"] > 0: self.logger.warning( - 'Failed to send {} records to Kinesis'.format( - response['FailedRecordCount'])) + "Failed to send {} records to Kinesis".format( + response["FailedRecordCount"] + ) + ) failed_records = [] - for i in range(len(response['Records'])): - if 'ErrorCode' in response['Records'][i]: + for i in range(len(response["Records"])): + if "ErrorCode" in response["Records"][i]: failed_records.append(kinesis_records[i]) - self._send_kinesis_format_records(failed_records, call_count+1) + self._send_kinesis_format_records(failed_records, call_count + 1) except ClientError as e: - self.logger.error( - 'Error sending records to Kinesis: {}'.format(e)) + self.logger.error("Error sending records to Kinesis: {}".format(e)) raise KinesisClientError( - 'Error sending records to Kinesis: {}'.format(e)) from None + "Error sending records to Kinesis: {}".format(e) + ) from None class KinesisClientError(Exception): diff --git a/src/nypl_py_utils/classes/kms_client.py b/src/nypl_py_utils/classes/kms_client.py index 26ecdef..abf1684 100644 --- a/src/nypl_py_utils/classes/kms_client.py +++ b/src/nypl_py_utils/classes/kms_client.py @@ -11,16 +11,17 @@ class KmsClient: """Client for interacting with a KMS client""" def __init__(self): - self.logger = create_log('kms_client') + self.logger = create_log("kms_client") try: self.kms_client = boto3.client( - 'kms', region_name=os.environ.get('AWS_REGION', 
'us-east-1')) + "kms", region_name=os.environ.get("AWS_REGION", "us-east-1") + ) except ClientError as e: - self.logger.error( - 'Could not create KMS client: {err}'.format(err=e)) + self.logger.error("Could not create KMS client: {err}".format(err=e)) raise KmsClientError( - 'Could not create KMS client: {err}'.format(err=e)) from None + "Could not create KMS client: {err}".format(err=e) + ) from None def close(self): self.kms_client.close() @@ -30,16 +31,19 @@ def decrypt(self, encrypted_text): This method takes a base 64 KMS-encoded string and uses the KMS client to decrypt it into a usable string. """ - self.logger.debug('Decrypting \'{}\''.format(encrypted_text)) + self.logger.debug("Decrypting '{}'".format(encrypted_text)) try: decoded_text = b64decode(encrypted_text) return self.kms_client.decrypt(CiphertextBlob=decoded_text)[ - 'Plaintext'].decode('utf-8') + "Plaintext" + ].decode("utf-8") except (ClientError, base64Error, TypeError) as e: - self.logger.error('Could not decrypt \'{val}\': {err}'.format( - val=encrypted_text, err=e)) - raise KmsClientError('Could not decrypt \'{val}\': {err}'.format( - val=encrypted_text, err=e)) from None + self.logger.error( + "Could not decrypt '{val}': {err}".format(val=encrypted_text, err=e) + ) + raise KmsClientError( + "Could not decrypt '{val}': {err}".format(val=encrypted_text, err=e) + ) from None class KmsClientError(Exception): diff --git a/src/nypl_py_utils/classes/mysql_client.py b/src/nypl_py_utils/classes/mysql_client.py index 94bb3c7..7828c24 100644 --- a/src/nypl_py_utils/classes/mysql_client.py +++ b/src/nypl_py_utils/classes/mysql_client.py @@ -7,7 +7,7 @@ class MySQLClient: """Client for managing connections to a MySQL database""" def __init__(self, host, port, database, user, password): - self.logger = create_log('mysql_client') + self.logger = create_log("mysql_client") self.conn = None self.host = host self.port = port @@ -28,7 +28,7 @@ def connect(self, **kwargs): Whether to automatically commit each query rather than running them as part of a transaction. By default False. """ - self.logger.info('Connecting to {} database'.format(self.database)) + self.logger.info("Connecting to {} database".format(self.database)) try: self.conn = mysql.connector.connect( host=self.host, @@ -36,14 +36,19 @@ def connect(self, **kwargs): database=self.database, user=self.user, password=self.password, - **kwargs) + **kwargs, + ) except mysql.connector.Error as e: self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) + "Error connecting to {name} database: {error}".format( + name=self.database, error=e + ) + ) raise MySQLClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) from None + "Error connecting to {name} database: {error}".format( + name=self.database, error=e + ) + ) from None def execute_query(self, query, query_params=None, **kwargs): """ @@ -71,8 +76,8 @@ def execute_query(self, query, query_params=None, **kwargs): or dictionaries (based on the dictionary input) if there's something to return (even if the result set is empty). 
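A minimal sketch (host, credentials, table, and column are hypothetical):

    client = MySQLClient("localhost", 3306, "library_db", "user", "password")
    client.connect()
    rows = client.execute_query(
        "SELECT * FROM patrons WHERE patron_id = %s", query_params=(123,)
    )
    client.close_connection()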
""" - self.logger.info('Querying {} database'.format(self.database)) - self.logger.debug('Executing query {}'.format(query)) + self.logger.info("Querying {} database".format(self.database)) + self.logger.debug("Executing query {}".format(query)) try: cursor = self.conn.cursor(**kwargs) cursor.execute(query, query_params) @@ -84,18 +89,21 @@ def execute_query(self, query, query_params=None, **kwargs): except Exception as e: self.conn.rollback() self.logger.error( - ('Error executing {name} database query \'{query}\': {error}') - .format(name=self.database, query=query, error=e)) + ("Error executing {name} database query '{query}': {error}").format( + name=self.database, query=query, error=e + ) + ) raise MySQLClientError( - ('Error executing {name} database query \'{query}\': {error}') - .format(name=self.database, query=query, error=e)) from None + ("Error executing {name} database query '{query}': {error}").format( + name=self.database, query=query, error=e + ) + ) from None finally: cursor.close() def close_connection(self): """Closes the database connection""" - self.logger.debug('Closing {} database connection'.format( - self.database)) + self.logger.debug("Closing {} database connection".format(self.database)) self.conn.close() diff --git a/src/nypl_py_utils/classes/oauth2_api_client.py b/src/nypl_py_utils/classes/oauth2_api_client.py index 5a47efb..5e9761f 100644 --- a/src/nypl_py_utils/classes/oauth2_api_client.py +++ b/src/nypl_py_utils/classes/oauth2_api_client.py @@ -15,20 +15,24 @@ class Oauth2ApiClient: API where empty responses are returned intermittently. """ - def __init__(self, client_id=None, client_secret=None, base_url=None, - token_url=None, with_retries=False): - self.client_id = client_id \ - or os.environ.get('NYPL_API_CLIENT_ID', None) - self.client_secret = client_secret \ - or os.environ.get('NYPL_API_CLIENT_SECRET', None) - self.token_url = token_url \ - or os.environ.get('NYPL_API_TOKEN_URL', None) - self.base_url = base_url \ - or os.environ.get('NYPL_API_BASE_URL', None) + def __init__( + self, + client_id=None, + client_secret=None, + base_url=None, + token_url=None, + with_retries=False, + ): + self.client_id = client_id or os.environ.get("NYPL_API_CLIENT_ID", None) + self.client_secret = client_secret or os.environ.get( + "NYPL_API_CLIENT_SECRET", None + ) + self.token_url = token_url or os.environ.get("NYPL_API_TOKEN_URL", None) + self.base_url = base_url or os.environ.get("NYPL_API_BASE_URL", None) self.oauth_client = None - self.logger = create_log('oauth2_api_client') + self.logger = create_log("oauth2_api_client") self.with_retries = with_retries @@ -36,7 +40,7 @@ def get(self, request_path, **kwargs): """ Issue an HTTP GET on the given request_path """ - resp = self._do_http_method('GET', request_path, **kwargs) + resp = self._do_http_method("GET", request_path, **kwargs) # This try/except block is to handle one of at least two possible # Sierra server errors. 
One is an empty response, and another is a # response with a 200 status code but response text in HTML declaring @@ -46,25 +50,28 @@ def get(self, request_path, **kwargs): except Exception: # build default server error response resp = Response() - resp.message = 'Oauth2 Client: Bad response from OauthClient' + resp.message = "Oauth2 Client: Bad response from OauthClient" resp.status_code = 500 - self.logger.warning(f'Get request using path {request_path} \ -returned response text:\n{resp.text}') + self.logger.warning( + f"Get request using path {request_path} \ +returned response text:\n{resp.text}" + ) # if client has specified that we want to retry failed requests and # we haven't hit max retries if self.with_retries is True: - retries = kwargs.get('retries', 0) + 1 + retries = kwargs.get("retries", 0) + 1 if retries < 3: self.logger.warning( - f'Retrying get request due to empty response from\ -Oauth2 Client using path: {request_path}. Retry #{retries}') + f"Retrying get request due to empty response from\ +Oauth2 Client using path: {request_path}. Retry #{retries}" + ) sleep(pow(2, retries - 1)) - kwargs['retries'] = retries + kwargs["retries"] = retries # try request again resp = self.get(request_path, **kwargs) else: - resp.message = 'Oauth2 Client: Request failed after 3 \ - empty responses received from Oauth2 Client' + resp.message = "Oauth2 Client: Request failed after 3 \ + empty responses received from Oauth2 Client" # Return request. If retries returned real data, it will be here, # otherwise it will be the default 500 response generated earlier. return resp @@ -73,21 +80,21 @@ def post(self, request_path, json, **kwargs): """ Issue an HTTP POST on the given request_path with given JSON body """ - kwargs['json'] = json - return self._do_http_method('POST', request_path, **kwargs) + kwargs["json"] = json + return self._do_http_method("POST", request_path, **kwargs) def patch(self, request_path, json, **kwargs): """ Issue an HTTP PATCH on the given request_path with given JSON body """ - kwargs['json'] = json - return self._do_http_method('PATCH', request_path, **kwargs) + kwargs["json"] = json + return self._do_http_method("PATCH", request_path, **kwargs) def delete(self, request_path, **kwargs): """ Issue an HTTP DELETE on the given request_path """ - return self._do_http_method('DELETE', request_path, **kwargs) + return self._do_http_method("DELETE", request_path, **kwargs) def _do_http_method(self, method, request_path, **kwargs): """ @@ -96,25 +103,26 @@ def _do_http_method(self, method, request_path, **kwargs): if not self.oauth_client: self._create_oauth_client() - url = f'{self.base_url}/{request_path}' - self.logger.debug(f'{method} {url}') + url = f"{self.base_url}/{request_path}" + self.logger.debug(f"{method} {url}") try: # Build kwargs cleaned of local variables: - kwargs_cleaned = {k: kwargs[k] for k in kwargs - if not k.startswith('_do_http_method_')} + kwargs_cleaned = { + k: kwargs[k] for k in kwargs if not k.startswith("_do_http_method_") + } resp = self.oauth_client.request(method, url, **kwargs_cleaned) resp.raise_for_status() return resp except TokenExpiredError: - self.logger.debug('TokenExpiredError encountered') + self.logger.debug("TokenExpiredError encountered") # Raise error after 3 successive token refreshes - kwargs['_do_http_method_token_refreshes'] = \ - kwargs.get('_do_http_method_token_refreshes', 0) + 1 - if kwargs['_do_http_method_token_refreshes'] > 3: - raise Oauth2ApiClientError('Exhausted token refreshes') \ - from None + 
kwargs["_do_http_method_token_refreshes"] = ( + kwargs.get("_do_http_method_token_refreshes", 0) + 1 + ) + if kwargs["_do_http_method_token_refreshes"] > 3: + raise Oauth2ApiClientError("Exhausted token refreshes") from None self._generate_access_token() return self._do_http_method(method, request_path, **kwargs) @@ -131,11 +139,11 @@ def _generate_access_token(self): """ Fetch and store a fresh token """ - self.logger.debug(f'Refreshing token via @{self.token_url}') + self.logger.debug(f"Refreshing token via @{self.token_url}") self.oauth_client.fetch_token( token_url=self.token_url, client_id=self.client_id, - client_secret=self.client_secret + client_secret=self.client_secret, ) diff --git a/src/nypl_py_utils/classes/postgresql_client.py b/src/nypl_py_utils/classes/postgresql_client.py index 05c7a97..569a203 100644 --- a/src/nypl_py_utils/classes/postgresql_client.py +++ b/src/nypl_py_utils/classes/postgresql_client.py @@ -7,12 +7,11 @@ class PostgreSQLClient: """Client for managing individual connections to a PostgreSQL database""" def __init__(self, host, port, db_name, user, password): - self.logger = create_log('postgresql_client') + self.logger = create_log("postgresql_client") self.conn = None - self.conn_info = ('postgresql://{user}:{password}@{host}:{port}/' - '{db_name}').format(user=user, password=password, - host=host, port=port, - db_name=db_name) + self.conn_info = ( + "postgresql://{user}:{password}@{host}:{port}/" "{db_name}" + ).format(user=user, password=password, host=host, port=port, db_name=db_name) self.db_name = db_name @@ -33,16 +32,20 @@ def connect(self, **kwargs): returned. Defaults to tuple_row, which returns the rows as a list of tuples. """ - self.logger.info('Connecting to {} database'.format(self.db_name)) + self.logger.info("Connecting to {} database".format(self.db_name)) try: self.conn = psycopg.connect(self.conn_info, **kwargs) except psycopg.Error as e: self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) + "Error connecting to {name} database: {error}".format( + name=self.db_name, error=e + ) + ) raise PostgreSQLClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) from None + "Error connecting to {name} database: {error}".format( + name=self.db_name, error=e + ) + ) from None def execute_query(self, query, query_params=None, **kwargs): """ @@ -65,8 +68,8 @@ def execute_query(self, query, query_params=None, **kwargs): based on the connection's row_factory if there's something to return (even if the result set is empty). 
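
        E.g. (an illustrative call; the table name here is hypothetical):
            postgresql_client.execute_query(
                "INSERT INTO branches VALUES (%s, %s)",
                query_params=("aa", "Mid-Manhattan"))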
""" - self.logger.info('Querying {} database'.format(self.db_name)) - self.logger.debug('Executing query {}'.format(query)) + self.logger.info("Querying {} database".format(self.db_name)) + self.logger.debug("Executing query {}".format(query)) try: cursor = self.conn.cursor() cursor.execute(query, query_params, **kwargs) @@ -75,20 +78,21 @@ def execute_query(self, query, query_params=None, **kwargs): except Exception as e: self.conn.rollback() self.logger.error( - ('Error executing {name} database query \'{query}\': ' - '{error}').format( - name=self.db_name, query=query, error=e)) + ("Error executing {name} database query '{query}': " "{error}").format( + name=self.db_name, query=query, error=e + ) + ) raise PostgreSQLClientError( - ('Error executing {name} database query \'{query}\': ' - '{error}').format( - name=self.db_name, query=query, error=e)) from None + ("Error executing {name} database query '{query}': " "{error}").format( + name=self.db_name, query=query, error=e + ) + ) from None finally: cursor.close() def close_connection(self): """Closes the database connection""" - self.logger.debug('Closing {} database connection'.format( - self.db_name)) + self.logger.debug("Closing {} database connection".format(self.db_name)) self.conn.close() diff --git a/src/nypl_py_utils/classes/postgresql_pool_client.py b/src/nypl_py_utils/classes/postgresql_pool_client.py index beaf589..47a2ca9 100644 --- a/src/nypl_py_utils/classes/postgresql_pool_client.py +++ b/src/nypl_py_utils/classes/postgresql_pool_client.py @@ -8,8 +8,9 @@ class PostgreSQLPoolClient: """Client for managing a connection pool to a PostgreSQL database""" - def __init__(self, host, port, db_name, user, password, conn_timeout=300.0, - **kwargs): + def __init__( + self, host, port, db_name, user, password, conn_timeout=300.0, **kwargs + ): """ Creates (but does not open) a connection pool. @@ -32,25 +33,30 @@ def __init__(self, host, port, db_name, user, password, conn_timeout=300.0, min_size connections, which will stay open until manually closed. 
""" - self.logger = create_log('postgresql_client') - self.conn_info = ('postgresql://{user}:{password}@{host}:{port}/' - '{db_name}').format(user=user, password=password, - host=host, port=port, - db_name=db_name) + self.logger = create_log("postgresql_client") + self.conn_info = ( + "postgresql://{user}:{password}@{host}:{port}/" "{db_name}" + ).format(user=user, password=password, host=host, port=port, db_name=db_name) self.db_name = db_name self.kwargs = kwargs - self.kwargs['min_size'] = kwargs.get('min_size', 0) - self.kwargs['max_size'] = kwargs.get('max_size', 1) - self.kwargs['max_idle'] = kwargs.get('max_idle', 90.0) - - if self.kwargs['max_idle'] > 150.0: - self.logger.error(( - 'max_idle is too high -- values over 150 seconds are unsafe ' - 'and may lead to connection leakages in ECS')) - raise PostgreSQLPoolClientError(( - 'max_idle is too high -- values over 150 seconds are unsafe ' - 'and may lead to connection leakages in ECS')) from None + self.kwargs["min_size"] = kwargs.get("min_size", 0) + self.kwargs["max_size"] = kwargs.get("max_size", 1) + self.kwargs["max_idle"] = kwargs.get("max_idle", 90.0) + + if self.kwargs["max_idle"] > 150.0: + self.logger.error( + ( + "max_idle is too high -- values over 150 seconds are unsafe " + "and may lead to connection leakages in ECS" + ) + ) + raise PostgreSQLPoolClientError( + ( + "max_idle is too high -- values over 150 seconds are unsafe " + "and may lead to connection leakages in ECS" + ) + ) from None self.pool = ConnectionPool(self.conn_info, open=False, **self.kwargs) @@ -65,22 +71,24 @@ def connect(self, timeout=300.0): The number of seconds to try connecting before throwing an error. Defaults to 300 seconds. """ - self.logger.info('Connecting to {} database'.format(self.db_name)) + self.logger.info("Connecting to {} database".format(self.db_name)) try: if self.pool is None: - self.pool = ConnectionPool( - self.conn_info, open=False, **self.kwargs) + self.pool = ConnectionPool(self.conn_info, open=False, **self.kwargs) self.pool.open(wait=True, timeout=timeout) except psycopg.Error as e: self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) + "Error connecting to {name} database: {error}".format( + name=self.db_name, error=e + ) + ) raise PostgreSQLPoolClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.db_name, error=e)) from None + "Error connecting to {name} database: {error}".format( + name=self.db_name, error=e + ) + ) from None - def execute_query(self, query, query_params=None, row_factory=tuple_row, - **kwargs): + def execute_query(self, query, query_params=None, row_factory=tuple_row, **kwargs): """ Requests a connection from the pool and uses it to execute an arbitrary query. After the query is complete, either commits it or rolls it back, @@ -106,28 +114,28 @@ def execute_query(self, query, query_params=None, row_factory=tuple_row, based on the row_factory input if there's something to return (even if the result set is empty). 
""" - self.logger.info('Querying {} database'.format(self.db_name)) - self.logger.debug('Executing query {}'.format(query)) + self.logger.info("Querying {} database".format(self.db_name)) + self.logger.debug("Executing query {}".format(query)) with self.pool.connection() as conn: try: conn.row_factory = row_factory cursor = conn.execute(query, query_params, **kwargs) - return (None if cursor.description is None - else cursor.fetchall()) + return None if cursor.description is None else cursor.fetchall() except Exception as e: self.logger.error( - ('Error executing {name} database query \'{query}\': ' - '{error}').format( - name=self.db_name, query=query, error=e)) + ( + "Error executing {name} database query '{query}': " "{error}" + ).format(name=self.db_name, query=query, error=e) + ) raise PostgreSQLPoolClientError( - ('Error executing {name} database query \'{query}\': ' - '{error}').format( - name=self.db_name, query=query, error=e)) from None + ( + "Error executing {name} database query '{query}': " "{error}" + ).format(name=self.db_name, query=query, error=e) + ) from None def close_pool(self): """Closes the connection pool""" - self.logger.debug('Closing {} database connection pool'.format( - self.db_name)) + self.logger.debug("Closing {} database connection pool".format(self.db_name)) self.pool.close() self.pool = None diff --git a/src/nypl_py_utils/classes/redshift_client.py b/src/nypl_py_utils/classes/redshift_client.py index 17c4558..2fc3ad9 100644 --- a/src/nypl_py_utils/classes/redshift_client.py +++ b/src/nypl_py_utils/classes/redshift_client.py @@ -8,7 +8,7 @@ class RedshiftClient: """Client for managing connections to Redshift""" def __init__(self, host, database, user, password): - self.logger = create_log('redshift_client') + self.logger = create_log("redshift_client") self.conn = None self.host = host self.database = database @@ -17,21 +17,26 @@ def __init__(self, host, database, user, password): def connect(self): """Connects to a Redshift database using the given credentials""" - self.logger.info('Connecting to {} database'.format(self.database)) + self.logger.info("Connecting to {} database".format(self.database)) try: self.conn = redshift_connector.connect( host=self.host, database=self.database, user=self.user, password=self.password, - sslmode='verify-full') + sslmode="verify-full", + ) except ClientError as e: self.logger.error( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) + "Error connecting to {name} database: {error}".format( + name=self.database, error=e + ) + ) raise RedshiftClientError( - 'Error connecting to {name} database: {error}'.format( - name=self.database, error=e)) from None + "Error connecting to {name} database: {error}".format( + name=self.database, error=e + ) + ) from None def execute_query(self, query, dataframe=False): """ @@ -51,8 +56,8 @@ def execute_query(self, query, dataframe=False): A list of tuples or a pandas DataFrame (based on the `dataframe` input) """ - self.logger.info('Querying {} database'.format(self.database)) - self.logger.debug('Executing query {}'.format(query)) + self.logger.info("Querying {} database".format(self.database)) + self.logger.debug("Executing query {}".format(query)) try: cursor = self.conn.cursor() cursor.execute(query) @@ -63,11 +68,15 @@ def execute_query(self, query, dataframe=False): except Exception as e: self.conn.rollback() self.logger.error( - ('Error executing {name} database query \'{query}\': {error}') - .format(name=self.database, query=query, error=e)) 
+ ("Error executing {name} database query '{query}': {error}").format( + name=self.database, query=query, error=e + ) + ) raise RedshiftClientError( - ('Error executing {name} database query \'{query}\': {error}') - .format(name=self.database, query=query, error=e)) from None + ("Error executing {name} database query '{query}': {error}").format( + name=self.database, query=query, error=e + ) + ) from None finally: cursor.close() @@ -88,37 +97,40 @@ def execute_transaction(self, queries): "INSERT INTO x VALUES (%s, %s)", [(1, "a"), (2, "b")]) """ - self.logger.info('Executing transaction against {} database'.format( - self.database)) + self.logger.info( + "Executing transaction against {} database".format(self.database) + ) try: cursor = self.conn.cursor() - cursor.execute('BEGIN TRANSACTION;') + cursor.execute("BEGIN TRANSACTION;") for query in queries: - self.logger.debug('Executing query {}'.format(query)) + self.logger.debug("Executing query {}".format(query)) if query[1] is not None and all( - isinstance(el, tuple) or isinstance(el, list) - for el in query[1] + isinstance(el, tuple) or isinstance(el, list) for el in query[1] ): cursor.executemany(query[0], query[1]) else: cursor.execute(query[0], query[1]) - cursor.execute('END TRANSACTION;') + cursor.execute("END TRANSACTION;") self.conn.commit() except Exception as e: self.conn.rollback() self.logger.error( - ('Error executing {name} database transaction: {error}') - .format(name=self.database, error=e)) + ("Error executing {name} database transaction: {error}").format( + name=self.database, error=e + ) + ) raise RedshiftClientError( - ('Error executing {name} database transaction: {error}') - .format(name=self.database, error=e)) from None + ("Error executing {name} database transaction: {error}").format( + name=self.database, error=e + ) + ) from None finally: cursor.close() def close_connection(self): """Closes the database connection""" - self.logger.debug('Closing {} database connection'.format( - self.database)) + self.logger.debug("Closing {} database connection".format(self.database)) self.conn.close() diff --git a/src/nypl_py_utils/classes/s3_client.py b/src/nypl_py_utils/classes/s3_client.py index af71531..4536624 100644 --- a/src/nypl_py_utils/classes/s3_client.py +++ b/src/nypl_py_utils/classes/s3_client.py @@ -15,56 +15,66 @@ class S3Client: """ def __init__(self, bucket, resource): - self.logger = create_log('s3_client') + self.logger = create_log("s3_client") self.bucket = bucket self.resource = resource try: self.s3_client = boto3.client( - 's3', region_name=os.environ.get('AWS_REGION', 'us-east-1')) + "s3", region_name=os.environ.get("AWS_REGION", "us-east-1") + ) except ClientError as e: - self.logger.error( - 'Could not create S3 client: {err}'.format(err=e)) + self.logger.error("Could not create S3 client: {err}".format(err=e)) raise S3ClientError( - 'Could not create S3 client: {err}'.format(err=e)) from None + "Could not create S3 client: {err}".format(err=e) + ) from None def close(self): self.s3_client.close() def fetch_cache(self): """Fetches a JSON file from S3 and returns the resulting dictionary""" - self.logger.info('Fetching {file} from S3 bucket {bucket}'.format( - file=self.resource, bucket=self.bucket)) + self.logger.info( + "Fetching {file} from S3 bucket {bucket}".format( + file=self.resource, bucket=self.bucket + ) + ) try: output_stream = BytesIO() - self.s3_client.download_fileobj( - self.bucket, self.resource, output_stream) + self.s3_client.download_fileobj(self.bucket, self.resource, 
output_stream) return json.loads(output_stream.getvalue()) except ClientError as e: self.logger.error( - 'Error retrieving {file} from S3 bucket {bucket}: {error}' - .format(file=self.resource, bucket=self.bucket, error=e)) + "Error retrieving {file} from S3 bucket {bucket}: {error}".format( + file=self.resource, bucket=self.bucket, error=e + ) + ) raise S3ClientError( - 'Error retrieving {file} from S3 bucket {bucket}: {error}' - .format(file=self.resource, bucket=self.bucket, error=e) + "Error retrieving {file} from S3 bucket {bucket}: {error}".format( + file=self.resource, bucket=self.bucket, error=e + ) ) from None def set_cache(self, state): """Writes a dictionary to JSON and uploads the resulting file to S3""" self.logger.info( - 'Setting {file} in S3 bucket {bucket} to {state}'.format( - file=self.resource, bucket=self.bucket, state=state)) + "Setting {file} in S3 bucket {bucket} to {state}".format( + file=self.resource, bucket=self.bucket, state=state + ) + ) try: input_stream = BytesIO(json.dumps(state).encode()) - self.s3_client.upload_fileobj( - input_stream, self.bucket, self.resource) + self.s3_client.upload_fileobj(input_stream, self.bucket, self.resource) except ClientError as e: self.logger.error( - 'Error uploading {file} to S3 bucket {bucket}: {error}' - .format(file=self.resource, bucket=self.bucket, error=e)) + "Error uploading {file} to S3 bucket {bucket}: {error}".format( + file=self.resource, bucket=self.bucket, error=e + ) + ) raise S3ClientError( - 'Error uploading {file} to S3 bucket {bucket}: {error}' - .format(file=self.resource, bucket=self.bucket, error=e) + "Error uploading {file} to S3 bucket {bucket}: {error}".format( + file=self.resource, bucket=self.bucket, error=e + ) ) from None diff --git a/src/nypl_py_utils/functions/config_helper.py b/src/nypl_py_utils/functions/config_helper.py index 7edb5ea..c0192e5 100644 --- a/src/nypl_py_utils/functions/config_helper.py +++ b/src/nypl_py_utils/functions/config_helper.py @@ -5,7 +5,7 @@ from nypl_py_utils.classes.kms_client import KmsClient from nypl_py_utils.functions.log_helper import create_log -logger = create_log('config_helper') +logger = create_log("config_helper") def load_env_file(run_type, file_string): @@ -30,29 +30,31 @@ def load_env_file(run_type, file_string): env_dict = None open_file = file_string.format(run_type) - logger.info('Loading env file {}'.format(open_file)) + logger.info("Loading env file {}".format(open_file)) try: - with open(open_file, 'r') as env_stream: + with open(open_file, "r") as env_stream: try: env_dict = yaml.safe_load(env_stream) except yaml.YAMLError: - logger.error('Invalid YAML file: {}'.format(open_file)) + logger.error("Invalid YAML file: {}".format(open_file)) raise ConfigHelperError( - 'Invalid YAML file: {}'.format(open_file)) from None + "Invalid YAML file: {}".format(open_file) + ) from None except FileNotFoundError: - logger.error('Could not find config file {}'.format(open_file)) + logger.error("Could not find config file {}".format(open_file)) raise ConfigHelperError( - 'Could not find config file {}'.format(open_file)) from None + "Could not find config file {}".format(open_file) + ) from None if env_dict: - for key, value in env_dict.get('PLAINTEXT_VARIABLES', {}).items(): + for key, value in env_dict.get("PLAINTEXT_VARIABLES", {}).items(): if type(value) is list: os.environ[key] = json.dumps(value) else: os.environ[key] = str(value) kms_client = KmsClient() - for key, value in env_dict.get('ENCRYPTED_VARIABLES', {}).items(): + for key, value in 
env_dict.get("ENCRYPTED_VARIABLES", {}).items(): if type(value) is list: decrypted_list = [kms_client.decrypt(v) for v in value] os.environ[key] = json.dumps(decrypted_list) diff --git a/src/nypl_py_utils/functions/log_helper.py b/src/nypl_py_utils/functions/log_helper.py index 7d7bf78..7eb7b83 100644 --- a/src/nypl_py_utils/functions/log_helper.py +++ b/src/nypl_py_utils/functions/log_helper.py @@ -3,11 +3,11 @@ import sys levels = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, } @@ -18,13 +18,12 @@ def create_log(module): console_log = logging.StreamHandler(stream=sys.stdout) - log_level = os.environ.get('LOG_LEVEL', 'info').lower() + log_level = os.environ.get("LOG_LEVEL", "info").lower() logger.setLevel(levels[log_level]) console_log.setLevel(levels[log_level]) - formatter = logging.Formatter( - '%(asctime)s | %(name)s | %(levelname)s: %(message)s') + formatter = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s: %(message)s") console_log.setFormatter(formatter) logger.addHandler(console_log) diff --git a/src/nypl_py_utils/functions/obfuscation_helper.py b/src/nypl_py_utils/functions/obfuscation_helper.py index 4209f86..e3cc4ae 100644 --- a/src/nypl_py_utils/functions/obfuscation_helper.py +++ b/src/nypl_py_utils/functions/obfuscation_helper.py @@ -3,7 +3,7 @@ from nypl_py_utils.functions.log_helper import create_log -logger = create_log('obfuscation_helper') +logger = create_log("obfuscation_helper") def obfuscate(input): @@ -16,11 +16,11 @@ def obfuscate(input): but is converted to a string before being obfuscated. The obfuscation salt is read from the `BCRYPT_SALT` environment variable. """ - logger.debug('Obfuscating input \'{}\' with environment salt'.format( - input)) - hash = bcrypt.hashpw(str(input).encode(), - os.environ['BCRYPT_SALT'].encode()).decode() - return hash.split(os.environ['BCRYPT_SALT'])[-1] + logger.debug("Obfuscating input '{}' with environment salt".format(input)) + hash = bcrypt.hashpw( + str(input).encode(), os.environ["BCRYPT_SALT"].encode() + ).decode() + return hash.split(os.environ["BCRYPT_SALT"])[-1] def obfuscate_with_salt(input, salt): @@ -28,6 +28,6 @@ def obfuscate_with_salt(input, salt): This method is the same as `obfuscate` above but takes the obfuscation salt as a string input. 
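
        E.g. (an illustrative call; any valid bcrypt salt string works):
            obfuscate_with_salt(12345, bcrypt.gensalt().decode())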
""" - logger.debug('Obfuscating input \'{}\' with custom salt'.format(input)) + logger.debug("Obfuscating input '{}' with custom salt".format(input)) hash = bcrypt.hashpw(str(input).encode(), salt.encode()).decode() return hash.split(salt)[-1] diff --git a/src/nypl_py_utils/functions/research_catalog_identifier_helper.py b/src/nypl_py_utils/functions/research_catalog_identifier_helper.py index 4079faf..00afc63 100644 --- a/src/nypl_py_utils/functions/research_catalog_identifier_helper.py +++ b/src/nypl_py_utils/functions/research_catalog_identifier_helper.py @@ -16,60 +16,51 @@ def parse_research_catalog_identifier(identifier: str): - id: The numeric string id """ if not isinstance(identifier, str): - raise ResearchCatalogIdentifierError( - f'Invalid RC identifier: {identifier}') + raise ResearchCatalogIdentifierError(f"Invalid RC identifier: {identifier}") # Extract prefix from the identifier: - match = re.match(r'^([a-z]+)', identifier) + match = re.match(r"^([a-z]+)", identifier) if match is None: - raise ResearchCatalogIdentifierError( - f'Invalid RC identifier: {identifier}') + raise ResearchCatalogIdentifierError(f"Invalid RC identifier: {identifier}") prefix = match[0] # The id is the identifier without the prefix: - id = identifier.replace(prefix, '') + id = identifier.replace(prefix, "") nyplType = None nyplSource = None # Look up nyplType and nyplSource in nypl-core based on the prefix: for _nyplSource, mapping in nypl_core_source_mapping().items(): - if mapping.get('bibPrefix') == prefix: - nyplType = 'bib' - elif mapping.get('itemPrefix') == prefix: - nyplType = 'item' - elif mapping.get('holdingPrefix') == prefix: - nyplType = 'holding' + if mapping.get("bibPrefix") == prefix: + nyplType = "bib" + elif mapping.get("itemPrefix") == prefix: + nyplType = "item" + elif mapping.get("holdingPrefix") == prefix: + nyplType = "holding" if nyplType is not None: nyplSource = _nyplSource break if nyplSource is None: - raise ResearchCatalogIdentifierError( - f'Invalid RC identifier: {identifier}') + raise ResearchCatalogIdentifierError(f"Invalid RC identifier: {identifier}") - return { - 'nyplSource': nyplSource, - 'nyplType': nyplType, - 'id': id - } + return {"nyplSource": nyplSource, "nyplType": nyplType, "id": id} -def research_catalog_id_prefix(nyplSource: str, nyplType='bib'): +def research_catalog_id_prefix(nyplSource: str, nyplType="bib"): """ Given a nyplSource (e.g. 'sierra-nypl') and nyplType (e.g. 'item'), returns the relevant prefix used in the RC identifier (e.g. 
'i') """ if nypl_core_source_mapping().get(nyplSource) is None: - raise ResearchCatalogIdentifierError( - f'Invalid nyplSource: {nyplSource}') + raise ResearchCatalogIdentifierError(f"Invalid nyplSource: {nyplSource}") if not isinstance(nyplType, str): - raise ResearchCatalogIdentifierError( - f'Invalid nyplType: {nyplType}') + raise ResearchCatalogIdentifierError(f"Invalid nyplType: {nyplType}") - prefixKey = f'{nyplType}Prefix' + prefixKey = f"{nyplType}Prefix" if nypl_core_source_mapping()[nyplSource].get(prefixKey) is None: - raise ResearchCatalogIdentifierError(f'Invalid nyplType: {nyplType}') + raise ResearchCatalogIdentifierError(f"Invalid nyplType: {nyplType}") return nypl_core_source_mapping()[nyplSource][prefixKey] @@ -78,29 +69,33 @@ def nypl_core_source_mapping(): """ Builds a nypl-source-mapping by retrieving the mapping from NYPL-Core """ - name = 'nypl-core-source-mapping' + name = "nypl-core-source-mapping" if not CACHE.get(name) is None: return CACHE[name] - url = os.environ.get('NYPL_CORE_SOURCE_MAPPING_URL', - 'https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json') # noqa + url = os.environ.get( + "NYPL_CORE_SOURCE_MAPPING_URL", + "https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json", + ) # noqa try: response = requests.get(url) response.raise_for_status() except RequestException as e: raise ResearchCatalogIdentifierError( - 'Failed to retrieve nypl-core source-mapping file from {url}:' - ' {errorType} {errorMessage}' - .format(url=url, errorType=type(e), errorMessage=e)) from None + "Failed to retrieve nypl-core source-mapping file from {url}:" + " {errorType} {errorMessage}".format( + url=url, errorType=type(e), errorMessage=e + ) + ) from None try: CACHE[name] = response.json() return CACHE[name] except (JSONDecodeError, KeyError) as e: raise ResearchCatalogIdentifierError( - 'Failed to parse nypl-core source-mapping file: {errorType}' - ' {errorMessage}' - .format(errorType=type(e), errorMessage=e)) from None + "Failed to parse nypl-core source-mapping file: {errorType}" + " {errorMessage}".format(errorType=type(e), errorMessage=e) + ) from None class ResearchCatalogIdentifierError(Exception): diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py index 7e26981..7a4946d 100644 --- a/tests/test_avro_client.py +++ b/tests/test_avro_client.py @@ -4,99 +4,113 @@ from nypl_py_utils.classes.avro_client import AvroDecoder, AvroEncoder, AvroClientError from requests.exceptions import ConnectTimeout -_TEST_SCHEMA = {'data': {'schema': json.dumps({ - 'name': 'TestSchema', - 'type': 'record', - 'fields': [ - { - 'name': 'patron_id', - 'type': 'int' - }, - { - 'name': 'library_branch', - 'type': ['null', 'string'] - } - ] -})}} +_TEST_SCHEMA = { + "data": { + "schema": json.dumps( + { + "name": "TestSchema", + "type": "record", + "fields": [ + {"name": "patron_id", "type": "int"}, + {"name": "library_branch", "type": ["null", "string"]}, + ], + } + ) + } +} class TestAvroClient: @pytest.fixture def test_avro_encoder_instance(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) - return AvroEncoder('https://test_schema_url') + requests_mock.get("https://test_schema_url", text=json.dumps(_TEST_SCHEMA)) + return AvroEncoder("https://test_schema_url") @pytest.fixture def test_avro_decoder_instance(self, requests_mock): - requests_mock.get( - 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) - return 
AvroDecoder('https://test_schema_url')
+        requests_mock.get("https://test_schema_url", text=json.dumps(_TEST_SCHEMA))
+        return AvroDecoder("https://test_schema_url")
 
-    def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance):
-        assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data']['schema']
-        assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data']['schema']
+    def test_get_json_schema(
+        self, test_avro_encoder_instance, test_avro_decoder_instance
+    ):
+        assert test_avro_encoder_instance.schema == _TEST_SCHEMA["data"]["schema"]
+        assert test_avro_decoder_instance.schema == _TEST_SCHEMA["data"]["schema"]
 
     def test_request_error(self, requests_mock):
-        requests_mock.get('https://test_schema_url', exc=ConnectTimeout)
+        requests_mock.get("https://test_schema_url", exc=ConnectTimeout)
         with pytest.raises(AvroClientError):
-            AvroEncoder('https://test_schema_url')
+            AvroEncoder("https://test_schema_url")
 
     def test_bad_json_error(self, requests_mock):
-        requests_mock.get(
-            'https://test_schema_url', text='bad json')
+        requests_mock.get("https://test_schema_url", text="bad json")
         with pytest.raises(AvroClientError):
-            AvroEncoder('https://test_schema_url')
+            AvroEncoder("https://test_schema_url")
 
     def test_missing_key_error(self, requests_mock):
         requests_mock.get(
-            'https://test_schema_url', text=json.dumps({'field': 'value'}))
+            "https://test_schema_url", text=json.dumps({"field": "value"})
+        )
         with pytest.raises(AvroClientError):
-            AvroEncoder('https://test_schema_url')
+            AvroEncoder("https://test_schema_url")
 
-    def test_encode_record(self, test_avro_encoder_instance, test_avro_decoder_instance):
-        TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'}
+    def test_encode_record(
+        self, test_avro_encoder_instance, test_avro_decoder_instance
+    ):
+        TEST_RECORD = {"patron_id": 123, "library_branch": "aa"}
         encoded_record = test_avro_encoder_instance.encode_record(TEST_RECORD)
         assert type(encoded_record) is bytes
         assert test_avro_decoder_instance.decode_record(encoded_record) == TEST_RECORD
 
     def test_encode_record_error(self, test_avro_encoder_instance):
-        TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'}
+        TEST_RECORD = {"patron_id": 123, "bad_field": "bad"}
         with pytest.raises(AvroClientError):
             test_avro_encoder_instance.encode_record(TEST_RECORD)
 
     def test_encode_batch(self, test_avro_encoder_instance, test_avro_decoder_instance):
         TEST_BATCH = [
-            {'patron_id': 123, 'library_branch': 'aa'},
-            {'patron_id': 456, 'library_branch': None},
-            {'patron_id': 789, 'library_branch': 'bb'}]
+            {"patron_id": 123, "library_branch": "aa"},
+            {"patron_id": 456, "library_branch": None},
+            {"patron_id": 789, "library_branch": "bb"},
+        ]
         encoded_records = test_avro_encoder_instance.encode_batch(TEST_BATCH)
         assert len(encoded_records) == len(TEST_BATCH)
         for i in range(3):
             assert type(encoded_records[i]) is bytes
-            assert test_avro_decoder_instance.decode_record(
-                encoded_records[i]) == TEST_BATCH[i]
+            assert (
+                test_avro_decoder_instance.decode_record(encoded_records[i])
+                == TEST_BATCH[i]
+            )
 
     def test_encode_batch_error(self, test_avro_encoder_instance):
         BAD_BATCH = [
-            {'patron_id': 123, 'library_branch': 'aa'},
-            {'patron_id': 456, 'bad_field': 'bad'}]
+            {"patron_id": 123, "library_branch": "aa"},
+            {"patron_id": 456, "bad_field": "bad"},
+        ]
         with pytest.raises(AvroClientError):
             test_avro_encoder_instance.encode_batch(BAD_BATCH)
 
     def test_decode_record_binary(self, 
test_avro_decoder_instance): TEST_DECODED_RECORD = {"patron_id": 123, "library_branch": "aa"} - TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' - assert test_avro_decoder_instance.decode_record( - TEST_ENCODED_RECORD) == TEST_DECODED_RECORD - + TEST_ENCODED_RECORD = b"\xf6\x01\x02\x04aa" + assert ( + test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) + == TEST_DECODED_RECORD + ) + def test_decode_record_b64(self, test_avro_decoder_instance): TEST_DECODED_RECORD = {"patron_id'": 123, "library_branch": "aa"} - TEST_ENCODED_RECORD = "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==" - assert test_avro_decoder_instance.decode_record( - TEST_ENCODED_RECORD, "base64") == TEST_DECODED_RECORD + TEST_ENCODED_RECORD = ( + "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==" + ) + assert ( + test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD, "base64") + == TEST_DECODED_RECORD + ) def test_decode_record_error(self, test_avro_decoder_instance): - TEST_ENCODED_RECORD = b'bad-encoding' + TEST_ENCODED_RECORD = b"bad-encoding" with pytest.raises(AvroClientError): - test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) + test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) \ No newline at end of file diff --git a/tests/test_config_helper.py b/tests/test_config_helper.py index feadfe5..9c3b149 100644 --- a/tests/test_config_helper.py +++ b/tests/test_config_helper.py @@ -1,15 +1,18 @@ import os import pytest -from nypl_py_utils.functions.config_helper import ( - load_env_file, ConfigHelperError) +from nypl_py_utils.functions.config_helper import load_env_file, ConfigHelperError _TEST_VARIABLE_NAMES = [ - 'TEST_STRING', 'TEST_INT', 'TEST_LIST', 'TEST_ENCRYPTED_VARIABLE_1', - 'TEST_ENCRYPTED_VARIABLE_2', 'TEST_ENCRYPTED_LIST'] + "TEST_STRING", + "TEST_INT", + "TEST_LIST", + "TEST_ENCRYPTED_VARIABLE_1", + "TEST_ENCRYPTED_VARIABLE_2", + "TEST_ENCRYPTED_LIST", +] -_TEST_CONFIG_CONTENTS = \ - '''--- +_TEST_CONFIG_CONTENTS = """--- PLAINTEXT_VARIABLES: TEST_STRING: string-variable TEST_INT: 1 @@ -22,7 +25,7 @@ TEST_ENCRYPTED_LIST: - test-encryption-3 - test-encryption-4 -...''' +...""" class TestConfigHelper: @@ -30,30 +33,42 @@ class TestConfigHelper: def test_load_env_file(self, mocker): mock_kms_client = mocker.MagicMock() mock_kms_client.decrypt.side_effect = [ - 'test-decryption-1', 'test-decryption-2', 'test-decryption-3', - 'test-decryption-4'] - mocker.patch('nypl_py_utils.functions.config_helper.KmsClient', - return_value=mock_kms_client) + "test-decryption-1", + "test-decryption-2", + "test-decryption-3", + "test-decryption-4", + ] + mocker.patch( + "nypl_py_utils.functions.config_helper.KmsClient", + return_value=mock_kms_client, + ) mock_file_open = mocker.patch( - 'builtins.open', mocker.mock_open(read_data=_TEST_CONFIG_CONTENTS)) + "builtins.open", mocker.mock_open(read_data=_TEST_CONFIG_CONTENTS) + ) for key in _TEST_VARIABLE_NAMES: assert key not in os.environ - load_env_file('test-env', 'test-path/{}.yaml') + load_env_file("test-env", "test-path/{}.yaml") - mock_file_open.assert_called_once_with('test-path/test-env.yaml', 'r') - mock_kms_client.decrypt.assert_has_calls([ - mocker.call('test-encryption-1'), mocker.call('test-encryption-2'), - mocker.call('test-encryption-3'), mocker.call('test-encryption-4')] + mock_file_open.assert_called_once_with("test-path/test-env.yaml", "r") + mock_kms_client.decrypt.assert_has_calls( + [ + mocker.call("test-encryption-1"), + mocker.call("test-encryption-2"), + mocker.call("test-encryption-3"), + 
mocker.call("test-encryption-4"), + ] ) mock_kms_client.close.assert_called_once() - assert os.environ['TEST_STRING'] == 'string-variable' - assert os.environ['TEST_INT'] == '1' - assert os.environ['TEST_LIST'] == '["string-var", 2]' - assert os.environ['TEST_ENCRYPTED_VARIABLE_1'] == 'test-decryption-1' - assert os.environ['TEST_ENCRYPTED_VARIABLE_2'] == 'test-decryption-2' - assert os.environ['TEST_ENCRYPTED_LIST'] == \ - '["test-decryption-3", "test-decryption-4"]' + assert os.environ["TEST_STRING"] == "string-variable" + assert os.environ["TEST_INT"] == "1" + assert os.environ["TEST_LIST"] == '["string-var", 2]' + assert os.environ["TEST_ENCRYPTED_VARIABLE_1"] == "test-decryption-1" + assert os.environ["TEST_ENCRYPTED_VARIABLE_2"] == "test-decryption-2" + assert ( + os.environ["TEST_ENCRYPTED_LIST"] + == '["test-decryption-3", "test-decryption-4"]' + ) for key in _TEST_VARIABLE_NAMES: if key in os.environ: @@ -61,10 +76,9 @@ def test_load_env_file(self, mocker): def test_missing_file_error(self): with pytest.raises(ConfigHelperError): - load_env_file('bad-env', 'bad-path/{}.yaml') + load_env_file("bad-env", "bad-path/{}.yaml") def test_bad_yaml(self, mocker): - mocker.patch( - 'builtins.open', mocker.mock_open(read_data='bad yaml: [')) + mocker.patch("builtins.open", mocker.mock_open(read_data="bad yaml: [")) with pytest.raises(ConfigHelperError): - load_env_file('test-env', 'test-path/{}.not_yaml') + load_env_file("test-env", "test-path/{}.not_yaml") diff --git a/tests/test_kinesis_client.py b/tests/test_kinesis_client.py index 820de77..be43d55 100644 --- a/tests/test_kinesis_client.py +++ b/tests/test_kinesis_client.py @@ -1,96 +1,119 @@ import pytest from freezegun import freeze_time -from nypl_py_utils.classes.kinesis_client import ( - KinesisClient, KinesisClientError) +from nypl_py_utils.classes.kinesis_client import KinesisClient, KinesisClientError -_TEST_DATETIME_KEY = '1672531200000000000' +_TEST_DATETIME_KEY = "1672531200000000000" _TEST_KINESIS_RECORDS = [ - {'Data': b'a', 'PartitionKey': _TEST_DATETIME_KEY}, - {'Data': b'b', 'PartitionKey': _TEST_DATETIME_KEY}, - {'Data': b'c', 'PartitionKey': _TEST_DATETIME_KEY}, - {'Data': b'd', 'PartitionKey': _TEST_DATETIME_KEY}, - {'Data': b'e', 'PartitionKey': _TEST_DATETIME_KEY} + {"Data": b"a", "PartitionKey": _TEST_DATETIME_KEY}, + {"Data": b"b", "PartitionKey": _TEST_DATETIME_KEY}, + {"Data": b"c", "PartitionKey": _TEST_DATETIME_KEY}, + {"Data": b"d", "PartitionKey": _TEST_DATETIME_KEY}, + {"Data": b"e", "PartitionKey": _TEST_DATETIME_KEY}, ] -@freeze_time('2023-01-01') +@freeze_time("2023-01-01") class TestKinesisClient: @pytest.fixture def test_instance(self, mocker): - mocker.patch('boto3.client') - return KinesisClient('test_stream_arn', 2) + mocker.patch("boto3.client") + return KinesisClient("test_stream_arn", 2) def test_send_records(self, test_instance, mocker): - MOCK_RECORDS = [b'a', b'b', b'c', b'd', b'e'] + MOCK_RECORDS = [b"a", b"b", b"c", b"d", b"e"] mocked_send_method = mocker.patch( - 'nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records') # noqa: E501 - mock_sleep = mocker.patch('time.sleep', return_value=None) + "nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records" + ) # noqa: E501 + mock_sleep = mocker.patch("time.sleep", return_value=None) test_instance.send_records(MOCK_RECORDS) - mocked_send_method.assert_has_calls([ - mocker.call([_TEST_KINESIS_RECORDS[0], - _TEST_KINESIS_RECORDS[1]], 1), - mocker.call([_TEST_KINESIS_RECORDS[2], - 
_TEST_KINESIS_RECORDS[3]], 1), - mocker.call([_TEST_KINESIS_RECORDS[4]], 1)]) + mocked_send_method.assert_has_calls( + [ + mocker.call([_TEST_KINESIS_RECORDS[0], _TEST_KINESIS_RECORDS[1]], 1), + mocker.call([_TEST_KINESIS_RECORDS[2], _TEST_KINESIS_RECORDS[3]], 1), + mocker.call([_TEST_KINESIS_RECORDS[4]], 1), + ] + ) mock_sleep.assert_not_called() def test_send_records_with_pause(self, mocker): - mocker.patch('boto3.client') - test_instance = KinesisClient('test_stream_arn', 500) + mocker.patch("boto3.client") + test_instance = KinesisClient("test_stream_arn", 500) - MOCK_RECORDS = [b'a'] * 2200 + MOCK_RECORDS = [b"a"] * 2200 mocked_send_method = mocker.patch( - 'nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records') # noqa: E501 - mock_sleep = mocker.patch('time.sleep', return_value=None) + "nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records" + ) # noqa: E501 + mock_sleep = mocker.patch("time.sleep", return_value=None) test_instance.send_records(MOCK_RECORDS) - mocked_send_method.assert_has_calls([ - mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]]*200, 1)]) + mocked_send_method.assert_has_calls( + [ + mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]] * 200, 1), + ] + ) assert mock_sleep.call_count == 2 def test_send_kinesis_format_records(self, test_instance): - test_instance.kinesis_client.put_records.return_value = { - 'FailedRecordCount': 0} + test_instance.kinesis_client.put_records.return_value = {"FailedRecordCount": 0} test_instance._send_kinesis_format_records(_TEST_KINESIS_RECORDS, 1) test_instance.kinesis_client.put_records.assert_called_once_with( - Records=_TEST_KINESIS_RECORDS, StreamARN='test_stream_arn') + Records=_TEST_KINESIS_RECORDS, StreamARN="test_stream_arn" + ) - def test_send_kinesis_format_records_with_failures( - self, test_instance, mocker): + def test_send_kinesis_format_records_with_failures(self, test_instance, mocker): test_instance.kinesis_client.put_records.side_effect = [ - {'FailedRecordCount': 2, 'Records': [ - 'record0', {'ErrorCode': 1}, - 'record2', {'ErrorCode': 3}, - 'record4']}, - {'FailedRecordCount': 0}] + { + "FailedRecordCount": 2, + "Records": [ + "record0", + {"ErrorCode": 1}, + "record2", + {"ErrorCode": 3}, + "record4", + ], + }, + {"FailedRecordCount": 0}, + ] test_instance._send_kinesis_format_records(_TEST_KINESIS_RECORDS, 1) - test_instance.kinesis_client.put_records.assert_has_calls([ - mocker.call(Records=_TEST_KINESIS_RECORDS, - StreamARN='test_stream_arn'), - mocker.call(Records=[_TEST_KINESIS_RECORDS[1], - _TEST_KINESIS_RECORDS[3]], - StreamARN='test_stream_arn')]) + test_instance.kinesis_client.put_records.assert_has_calls( + [ + mocker.call(Records=_TEST_KINESIS_RECORDS, StreamARN="test_stream_arn"), + mocker.call( + Records=[_TEST_KINESIS_RECORDS[1], _TEST_KINESIS_RECORDS[3]], + StreamARN="test_stream_arn", + ), + ] + ) def test_send_kinesis_format_records_with_repeating_failures( - self, test_instance, mocker): + self, test_instance, mocker + ): test_instance.kinesis_client.put_records.side_effect = [ - {'FailedRecordCount': 5, 'Records': [ - {'ErrorCode': 0}, {'ErrorCode': 1}, {'ErrorCode': 
2}, - {'ErrorCode': 3}, {'ErrorCode': 4}]}] * 5 + { + "FailedRecordCount": 5, + "Records": [ + {"ErrorCode": 0}, + {"ErrorCode": 1}, + {"ErrorCode": 2}, + {"ErrorCode": 3}, + {"ErrorCode": 4}, + ], + } + ] * 5 with pytest.raises(KinesisClientError): - test_instance._send_kinesis_format_records( - _TEST_KINESIS_RECORDS, 1) + test_instance._send_kinesis_format_records(_TEST_KINESIS_RECORDS, 1) - test_instance.kinesis_client.put_records.assert_has_calls([ - mocker.call(Records=_TEST_KINESIS_RECORDS, - StreamARN='test_stream_arn')] * 5) + test_instance.kinesis_client.put_records.assert_has_calls( + [mocker.call(Records=_TEST_KINESIS_RECORDS, StreamARN="test_stream_arn")] + * 5 + ) diff --git a/tests/test_kms_client.py b/tests/test_kms_client.py index e500b03..fcc10a3 100644 --- a/tests/test_kms_client.py +++ b/tests/test_kms_client.py @@ -3,12 +3,12 @@ from base64 import b64encode from nypl_py_utils.classes.kms_client import KmsClient, KmsClientError -_TEST_ENCRYPTED_VALUE = b64encode(b'test-encrypted-value') +_TEST_ENCRYPTED_VALUE = b64encode(b"test-encrypted-value") _TEST_DECRYPTION = { - 'KeyId': 'test-key-id', - 'Plaintext': b'test-decrypted-value', - 'EncryptionAlgorithm': 'test-encryption-algorithm', - 'ResponseMetadata': {} + "KeyId": "test-key-id", + "Plaintext": b"test-decrypted-value", + "EncryptionAlgorithm": "test-encryption-algorithm", + "ResponseMetadata": {}, } @@ -16,16 +16,16 @@ class TestKmsClient: @pytest.fixture def test_instance(self, mocker): - mocker.patch('boto3.client') + mocker.patch("boto3.client") return KmsClient() def test_decrypt(self, test_instance): test_instance.kms_client.decrypt.return_value = _TEST_DECRYPTION assert test_instance.kms_client.decrypt.called_once_with( - CiphertextBlob=b'test-encrypted-value') - assert test_instance.decrypt( - _TEST_ENCRYPTED_VALUE) == 'test-decrypted-value' + CiphertextBlob=b"test-encrypted-value" + ) + assert test_instance.decrypt(_TEST_ENCRYPTED_VALUE) == "test-decrypted-value" def test_base64_error(self, test_instance): with pytest.raises(KmsClientError): - test_instance.decrypt('bad-b64') + test_instance.decrypt("bad-b64") diff --git a/tests/test_log_helper.py b/tests/test_log_helper.py index cf7f616..77c46b3 100644 --- a/tests/test_log_helper.py +++ b/tests/test_log_helper.py @@ -6,56 +6,64 @@ from nypl_py_utils.functions.log_helper import create_log -@freeze_time('2023-01-01 19:00:00') +@freeze_time("2023-01-01 19:00:00") class TestLogHelper: def test_default_logging(self, caplog): - logger = create_log('test_log') + logger = create_log("test_log") assert logger.getEffectiveLevel() == logging.INFO assert len(logger.handlers) == 1 - logger.info('Test info message') + logger.info("Test info message") # freeze_time changes the utc time, while the logger uses local time by # default, so force the logger to use utc time logger.handlers[0].formatter.converter = time.gmtime assert len(caplog.records) == 1 - assert logger.handlers[0].format(caplog.records[0]) == \ - '2023-01-01 19:00:00,000 | test_log | INFO: Test info message' + assert ( + logger.handlers[0].format(caplog.records[0]) + == "2023-01-01 19:00:00,000 | test_log | INFO: Test info message" + ) def test_logging_with_custom_log_level(self, caplog): - os.environ['LOG_LEVEL'] = 'error' - logger = create_log('test_log') + os.environ["LOG_LEVEL"] = "error" + logger = create_log("test_log") assert logger.getEffectiveLevel() == logging.ERROR - logger.info('Test info message') - logger.error('Test error message') + logger.info("Test info message") + logger.error("Test error 
message") assert len(caplog.records) == 1 # freeze_time changes the utc time, while the logger uses local time by # default, so force the logger to use utc time logger.handlers[0].formatter.converter = time.gmtime - assert logger.handlers[0].format(caplog.records[0]) == \ - '2023-01-01 19:00:00,000 | test_log | ERROR: Test error message' - del os.environ['LOG_LEVEL'] + assert ( + logger.handlers[0].format(caplog.records[0]) + == "2023-01-01 19:00:00,000 | test_log | ERROR: Test error message" + ) + del os.environ["LOG_LEVEL"] def test_logging_no_duplicates(self, caplog): - logger = create_log('test_log') - logger.info('Test info message') + logger = create_log("test_log") + logger.info("Test info message") # Test that logger uses the most recently set log level and doesn't # duplicate handlers/messages when create_log is called more than once. - os.environ['LOG_LEVEL'] = 'error' - logger = create_log('test_log') + os.environ["LOG_LEVEL"] = "error" + logger = create_log("test_log") assert logger.getEffectiveLevel() == logging.ERROR assert len(logger.handlers) == 1 - logger.info('Test info message 2') - logger.error('Test error message') + logger.info("Test info message 2") + logger.error("Test error message") assert len(caplog.records) == 2 # freeze_time changes the utc time, while the logger uses local time by # default, so force the logger to use utc time logger.handlers[0].formatter.converter = time.gmtime - assert logger.handlers[0].format(caplog.records[0]) == \ - '2023-01-01 19:00:00,000 | test_log | INFO: Test info message' - assert logger.handlers[0].format(caplog.records[1]) == \ - '2023-01-01 19:00:00,000 | test_log | ERROR: Test error message' - del os.environ['LOG_LEVEL'] + assert ( + logger.handlers[0].format(caplog.records[0]) + == "2023-01-01 19:00:00,000 | test_log | INFO: Test info message" + ) + assert ( + logger.handlers[0].format(caplog.records[1]) + == "2023-01-01 19:00:00,000 | test_log | ERROR: Test error message" + ) + del os.environ["LOG_LEVEL"] diff --git a/tests/test_mysql_client.py b/tests/test_mysql_client.py index 11bb94f..b793006 100644 --- a/tests/test_mysql_client.py +++ b/tests/test_mysql_client.py @@ -7,32 +7,34 @@ class TestMySQLClient: @pytest.fixture def mock_mysql_conn(self, mocker): - return mocker.patch('mysql.connector.connect') + return mocker.patch("mysql.connector.connect") @pytest.fixture def test_instance(self): - return MySQLClient('test_host', 'test_port', 'test_database', - 'test_user', 'test_password') + return MySQLClient( + "test_host", "test_port", "test_database", "test_user", "test_password" + ) def test_connect(self, mock_mysql_conn, test_instance): test_instance.connect() - mock_mysql_conn.assert_called_once_with(host='test_host', - port='test_port', - database='test_database', - user='test_user', - password='test_password') + mock_mysql_conn.assert_called_once_with( + host="test_host", + port="test_port", + database="test_database", + user="test_user", + password="test_password", + ) def test_execute_read_query(self, mock_mysql_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() - mock_cursor.description = [('description', None, None)] - mock_cursor.fetchall.return_value = [(1, 2, 3), ('a', 'b', 'c')] + mock_cursor.description = [("description", None, None)] + mock_cursor.fetchall.return_value = [(1, 2, 3), ("a", "b", "c")] test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query( - 'test query') == [(1, 2, 3), ('a', 'b', 'c')] - 
mock_cursor.execute.assert_called_once_with('test query', None) + assert test_instance.execute_query("test query") == [(1, 2, 3), ("a", "b", "c")] + mock_cursor.execute.assert_called_once_with("test query", None) test_instance.conn.commit.assert_not_called() mock_cursor.close.assert_called_once() @@ -43,28 +45,29 @@ def test_execute_write_query(self, mock_mysql_conn, test_instance, mocker): mock_cursor.description = None test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query('test query') is None - mock_cursor.execute.assert_called_once_with('test query', None) + assert test_instance.execute_query("test query") is None + mock_cursor.execute.assert_called_once_with("test query", None) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_write_query_with_params(self, mock_mysql_conn, - test_instance, mocker): + def test_execute_write_query_with_params( + self, mock_mysql_conn, test_instance, mocker + ): test_instance.connect() mock_cursor = mocker.MagicMock() mock_cursor.description = None test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query( - 'test query %s %s', query_params=('a', 1)) is None - mock_cursor.execute.assert_called_once_with('test query %s %s', - ('a', 1)) + assert ( + test_instance.execute_query("test query %s %s", query_params=("a", 1)) + is None + ) + mock_cursor.execute.assert_called_once_with("test query %s %s", ("a", 1)) test_instance.conn.commit.called_once() mock_cursor.close.assert_called_once() - def test_execute_query_with_exception( - self, mock_mysql_conn, test_instance, mocker): + def test_execute_query_with_exception(self, mock_mysql_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -72,7 +75,7 @@ def test_execute_query_with_exception( test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(MySQLClientError): - test_instance.execute_query('test query') + test_instance.execute_query("test query") test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() diff --git a/tests/test_oauth2_api_client.py b/tests/test_oauth2_api_client.py index b5fcd5c..b6377d9 100644 --- a/tests/test_oauth2_api_client.py +++ b/tests/test_oauth2_api_client.py @@ -5,19 +5,21 @@ from requests_oauthlib import OAuth2Session from requests import HTTPError, JSONDecodeError, Response -from nypl_py_utils.classes.oauth2_api_client import (Oauth2ApiClient, - Oauth2ApiClientError) +from nypl_py_utils.classes.oauth2_api_client import ( + Oauth2ApiClient, + Oauth2ApiClientError, +) _TOKEN_RESPONSE = { - 'access_token': 'super-secret-token', - 'expires_in': 1, - 'token_type': 'Bearer', - 'scope': ['offline_access', 'openid', 'login:staff', 'admin'], - 'id_token': 'super-secret-token' + "access_token": "super-secret-token", + "expires_in": 1, + "token_type": "Bearer", + "scope": ["offline_access", "openid", "login:staff", "admin"], + "id_token": "super-secret-token", } -BASE_URL = 'https://example.com/api/v0.1' -TOKEN_URL = 'https://oauth.example.com/oauth/token' +BASE_URL = "https://example.com/api/v0.1" +TOKEN_URL = "https://oauth.example.com/oauth/token" class MockEmptyResponse: @@ -30,7 +32,7 @@ def json(self): if self.empty: raise JSONDecodeError else: - return 'success' + return "success" class TestOauth2ApiClient: @@ -43,102 +45,112 @@ def token_server_post(self, requests_mock): @pytest.fixture def test_instance(self, requests_mock): - return Oauth2ApiClient(base_url=BASE_URL, - 
token_url=TOKEN_URL, - client_id='clientid', - client_secret='clientsecret' - ) + return Oauth2ApiClient( + base_url=BASE_URL, + token_url=TOKEN_URL, + client_id="clientid", + client_secret="clientsecret", + ) @pytest.fixture def test_instance_with_retries(self, requests_mock): - return Oauth2ApiClient(base_url=BASE_URL, - token_url=TOKEN_URL, - client_id='clientid', - client_secret='clientsecret', - with_retries=True - ) + return Oauth2ApiClient( + base_url=BASE_URL, + token_url=TOKEN_URL, + client_id="clientid", + client_secret="clientsecret", + with_retries=True, + ) def test_uses_env_vars(self): env = { - 'NYPL_API_CLIENT_ID': 'env client id', - 'NYPL_API_CLIENT_SECRET': 'env client secret', - 'NYPL_API_TOKEN_URL': 'env token url', - 'NYPL_API_BASE_URL': 'env base url' + "NYPL_API_CLIENT_ID": "env client id", + "NYPL_API_CLIENT_SECRET": "env client secret", + "NYPL_API_TOKEN_URL": "env token url", + "NYPL_API_BASE_URL": "env base url", } for key, value in env.items(): os.environ[key] = value client = Oauth2ApiClient() - assert client.client_id == 'env client id' - assert client.client_secret == 'env client secret' - assert client.token_url == 'env token url' - assert client.base_url == 'env base url' + assert client.client_id == "env client id" + assert client.client_secret == "env client secret" + assert client.token_url == "env token url" + assert client.base_url == "env base url" for key, value in env.items(): - os.environ[key] = '' + os.environ[key] = "" def test_generate_access_token(self, test_instance, token_server_post): test_instance._create_oauth_client() test_instance._generate_access_token() - assert test_instance.oauth_client.token['access_token']\ - == _TOKEN_RESPONSE['access_token'] + assert ( + test_instance.oauth_client.token["access_token"] + == _TOKEN_RESPONSE["access_token"] + ) def test_create_oauth_client(self, token_server_post, test_instance): test_instance._create_oauth_client() assert type(test_instance.oauth_client) is OAuth2Session - def test_do_http_method(self, requests_mock, token_server_post, - test_instance): - requests_mock.get(f'{BASE_URL}/foo', json={'foo': 'bar'}) + def test_do_http_method(self, requests_mock, token_server_post, test_instance): + requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) - requests_mock.get(f'{BASE_URL}/foo', json={'foo': 'bar'}) - resp = test_instance._do_http_method('GET', 'foo') + requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) + resp = test_instance._do_http_method("GET", "foo") assert resp.status_code == 200 - assert resp.json() == {'foo': 'bar'} + assert resp.json() == {"foo": "bar"} - def test_token_expiration(self, requests_mock, test_instance, - token_server_post, mocker): - api_get_mock = requests_mock.get(f'{BASE_URL}/foo', - json={'foo': 'bar'}) + def test_token_expiration( + self, requests_mock, test_instance, token_server_post, mocker + ): + api_get_mock = requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) # Perform first request: - test_instance._do_http_method('GET', 'foo') + test_instance._do_http_method("GET", "foo") # Expect this first call triggered a single token server call: assert len(token_server_post.request_history) == 1 # And the GET request used the supplied Bearer token: - assert api_get_mock.request_history[0]._request\ - .headers['Authorization'] == 'Bearer super-secret-token' + assert ( + api_get_mock.request_history[0]._request.headers["Authorization"] + == "Bearer super-secret-token" + ) # The token obtained above expires in 1s, so wait out expiration: 
time.sleep(1.1)

         # Register new token response:
         second_token_response = dict(_TOKEN_RESPONSE)
-        second_token_response['id_token'] = 'super-secret-second-token'
-        second_token_response['access_token'] = 'super-secret-second-token'
-        second_token_server_post = requests_mock\
-            .post(TOKEN_URL, text=json.dumps(second_token_response))
+        second_token_response["id_token"] = "super-secret-second-token"
+        second_token_response["access_token"] = "super-secret-second-token"
+        second_token_server_post = requests_mock.post(
+            TOKEN_URL, text=json.dumps(second_token_response)
+        )

         # Perform second request:
-        response = test_instance._do_http_method('GET', 'foo')
+        response = test_instance._do_http_method("GET", "foo")
         # Ensure we still return a plain requests Response object
         assert isinstance(response, Response)
         assert response.json() == {"foo": "bar"}
         # Expect a call on the second token server:
         assert len(second_token_server_post.request_history) == 1
         # Expect the second GET request to carry the new Bearer token:
-        assert api_get_mock.request_history[1]._request\
-            .headers['Authorization'] == 'Bearer super-secret-second-token'
+        assert (
+            api_get_mock.request_history[1]._request.headers["Authorization"]
+            == "Bearer super-secret-second-token"
+        )

-    def test_error_status_raises_error(self, requests_mock, test_instance,
-                                       token_server_post):
-        requests_mock.get(f'{BASE_URL}/foo', status_code=400)
+    def test_error_status_raises_error(
+        self, requests_mock, test_instance, token_server_post
+    ):
+        requests_mock.get(f"{BASE_URL}/foo", status_code=400)

         with pytest.raises(HTTPError):
-            test_instance._do_http_method('GET', 'foo')
+            test_instance._do_http_method("GET", "foo")

     def test_token_refresh_failure_raises_error(
-            self, requests_mock, test_instance, token_server_post):
+        self, requests_mock, test_instance, token_server_post
+    ):
         """
         Failure to fetch a token can raise a number of errors including:
          - requests.exceptions.HTTPError for invalid access_token
          - oauthlib.oauth2.rfc6749.errors.CustomOAuth2Error for invalid
            client_id, client_secret

         Ideally, the client should hit the token server url and receive
         a new valid token in response to token expiration. This test asserts
         that the client will not allow more than 3 successive retries.
""" - requests_mock.get(f'{BASE_URL}/foo', json={'foo': 'bar'}) + requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) token_response = dict(_TOKEN_RESPONSE) - token_response['expires_in'] = 0 - token_server_post = requests_mock\ - .post(TOKEN_URL, text=json.dumps(token_response)) + token_response["expires_in"] = 0 + token_server_post = requests_mock.post( + TOKEN_URL, text=json.dumps(token_response) + ) with pytest.raises(Oauth2ApiClientError): - test_instance._do_http_method('GET', 'foo') + test_instance._do_http_method("GET", "foo") # Expect 1 initial token fetch, plus 3 retries: assert len(token_server_post.request_history) == 4 - def test_bad_response_no_retries(self, requests_mock, test_instance, - mocker): - mocker.patch.object(test_instance, '_do_http_method', - return_value=MockEmptyResponse(empty=True)) - get_spy = mocker.spy(test_instance, 'get') - resp = test_instance.get('spaghetti') + def test_bad_response_no_retries(self, requests_mock, test_instance, mocker): + mocker.patch.object( + test_instance, "_do_http_method", return_value=MockEmptyResponse(empty=True) + ) + get_spy = mocker.spy(test_instance, "get") + resp = test_instance.get("spaghetti") assert get_spy.call_count == 1 assert resp.status_code == 500 - assert resp.message == 'Oauth2 Client: Bad response from OauthClient' - - def test_http_retry_fail(self, requests_mock, test_instance_with_retries, - mocker): - mocker.patch.object(test_instance_with_retries, '_do_http_method', - return_value=MockEmptyResponse(empty=True)) - get_spy = mocker.spy(test_instance_with_retries, 'get') - resp = test_instance_with_retries.get('spaghetti') + assert resp.message == "Oauth2 Client: Bad response from OauthClient" + + def test_http_retry_fail(self, requests_mock, test_instance_with_retries, mocker): + mocker.patch.object( + test_instance_with_retries, + "_do_http_method", + return_value=MockEmptyResponse(empty=True), + ) + get_spy = mocker.spy(test_instance_with_retries, "get") + resp = test_instance_with_retries.get("spaghetti") assert get_spy.call_count == 3 assert resp.status_code == 500 - assert resp.message == 'Oauth2 Client: Request failed after 3 \ - empty responses received from Oauth2 Client' - - def test_http_retry_success(self, requests_mock, - test_instance_with_retries, mocker): - mocker.patch.object(test_instance_with_retries, '_do_http_method', - side_effect=[MockEmptyResponse(empty=True), - MockEmptyResponse(empty=False, - status_code=200)]) - get_spy = mocker.spy(test_instance_with_retries, 'get') - resp = test_instance_with_retries.get('spaghetti') + assert ( + resp.message + == "Oauth2 Client: Request failed after 3 \ + empty responses received from Oauth2 Client" + ) + + def test_http_retry_success( + self, requests_mock, test_instance_with_retries, mocker + ): + mocker.patch.object( + test_instance_with_retries, + "_do_http_method", + side_effect=[ + MockEmptyResponse(empty=True), + MockEmptyResponse(empty=False, status_code=200), + ], + ) + get_spy = mocker.spy(test_instance_with_retries, "get") + resp = test_instance_with_retries.get("spaghetti") assert get_spy.call_count == 2 - assert resp.json() == 'success' + assert resp.json() == "success" diff --git a/tests/test_obfuscation_helper.py b/tests/test_obfuscation_helper.py index ed76261..112785f 100644 --- a/tests/test_obfuscation_helper.py +++ b/tests/test_obfuscation_helper.py @@ -1,19 +1,20 @@ import os -from nypl_py_utils.functions.obfuscation_helper import (obfuscate, - obfuscate_with_salt) +from nypl_py_utils.functions.obfuscation_helper import 
obfuscate, obfuscate_with_salt -_TEST_SALT_1 = '$2a$10$8AvAPrrUsmlBa50qgc683e' -_TEST_SALT_2 = '$2b$12$iuSSdD6F/nJ1GSXzesM8sO' +_TEST_SALT_1 = "$2a$10$8AvAPrrUsmlBa50qgc683e" +_TEST_SALT_2 = "$2b$12$iuSSdD6F/nJ1GSXzesM8sO" class TestObfuscationHelper: def test_obfuscation_with_environment_variable(self): - os.environ['BCRYPT_SALT'] = _TEST_SALT_1 - assert obfuscate('test_input') == 'UPMawmdZfleeSg5REsZbLbAivWl97O6' - del os.environ['BCRYPT_SALT'] + os.environ["BCRYPT_SALT"] = _TEST_SALT_1 + assert obfuscate("test_input") == "UPMawmdZfleeSg5REsZbLbAivWl97O6" + del os.environ["BCRYPT_SALT"] def test_obfuscation_with_custom_salt(self): - assert (obfuscate_with_salt('test_input', _TEST_SALT_2) == - 'SUXLCHnsRVt4Vj1PyP9KPEqADxtUj5.') + assert ( + obfuscate_with_salt("test_input", _TEST_SALT_2) + == "SUXLCHnsRVt4Vj1PyP9KPEqADxtUj5." + ) diff --git a/tests/test_postgresql_client.py b/tests/test_postgresql_client.py index 99e5042..2c32827 100644 --- a/tests/test_postgresql_client.py +++ b/tests/test_postgresql_client.py @@ -1,38 +1,40 @@ import pytest from nypl_py_utils.classes.postgresql_client import ( - PostgreSQLClient, PostgreSQLClientError) + PostgreSQLClient, + PostgreSQLClientError, +) class TestPostgreSQLClient: @pytest.fixture def mock_pg_conn(self, mocker): - return mocker.patch('psycopg.connect') + return mocker.patch("psycopg.connect") @pytest.fixture def test_instance(self): - return PostgreSQLClient('test_host', 'test_port', 'test_db_name', - 'test_user', 'test_password') + return PostgreSQLClient( + "test_host", "test_port", "test_db_name", "test_user", "test_password" + ) def test_connect(self, mock_pg_conn, test_instance): test_instance.connect() mock_pg_conn.assert_called_once_with( - 'postgresql://test_user:test_password@test_host:test_port/' + - 'test_db_name') + "postgresql://test_user:test_password@test_host:test_port/" + "test_db_name" + ) def test_execute_read_query(self, mock_pg_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() - mock_cursor.description = [('description', None, None)] + mock_cursor.description = [("description", None, None)] mock_cursor.execute.return_value = mock_cursor - mock_cursor.fetchall.return_value = [(1, 2, 3), ('a', 'b', 'c')] + mock_cursor.fetchall.return_value = [(1, 2, 3), ("a", "b", "c")] test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query( - 'test query') == [(1, 2, 3), ('a', 'b', 'c')] - mock_cursor.execute.assert_called_once_with('test query', None) + assert test_instance.execute_query("test query") == [(1, 2, 3), ("a", "b", "c")] + mock_cursor.execute.assert_called_once_with("test query", None) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() @@ -43,28 +45,27 @@ def test_execute_write_query(self, mock_pg_conn, test_instance, mocker): mock_cursor.description = None test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query('test query') is None - mock_cursor.execute.assert_called_once_with('test query', None) + assert test_instance.execute_query("test query") is None + mock_cursor.execute.assert_called_once_with("test query", None) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_write_query_with_params(self, mock_pg_conn, test_instance, - mocker): + def test_execute_write_query_with_params(self, mock_pg_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() mock_cursor.description = None 
test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query( - 'test query %s %s', query_params=('a', 1)) is None - mock_cursor.execute.assert_called_once_with('test query %s %s', - ('a', 1)) + assert ( + test_instance.execute_query("test query %s %s", query_params=("a", 1)) + is None + ) + mock_cursor.execute.assert_called_once_with("test query %s %s", ("a", 1)) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_query_with_exception( - self, mock_pg_conn, test_instance, mocker): + def test_execute_query_with_exception(self, mock_pg_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -72,7 +73,7 @@ def test_execute_query_with_exception( test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(PostgreSQLClientError): - test_instance.execute_query('test query') + test_instance.execute_query("test query") test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() diff --git a/tests/test_postgresql_pool_client.py b/tests/test_postgresql_pool_client.py index 82f22b6..5a57808 100644 --- a/tests/test_postgresql_pool_client.py +++ b/tests/test_postgresql_pool_client.py @@ -1,7 +1,9 @@ import pytest from nypl_py_utils.classes.postgresql_pool_client import ( - PostgreSQLPoolClient, PostgreSQLPoolClientError) + PostgreSQLPoolClient, + PostgreSQLPoolClientError, +) from psycopg import Error @@ -9,15 +11,16 @@ class TestPostgreSQLPoolClient: @pytest.fixture def test_instance(self, mocker): - mocker.patch('psycopg_pool.ConnectionPool.open') - mocker.patch('psycopg_pool.ConnectionPool.close') - return PostgreSQLPoolClient('test_host', 'test_port', 'test_db_name', - 'test_user', 'test_password') + mocker.patch("psycopg_pool.ConnectionPool.open") + mocker.patch("psycopg_pool.ConnectionPool.close") + return PostgreSQLPoolClient( + "test_host", "test_port", "test_db_name", "test_user", "test_password" + ) def test_init(self, test_instance): assert test_instance.pool.conninfo == ( - 'postgresql://test_user:test_password@test_host:test_port/' + - 'test_db_name') + "postgresql://test_user:test_password@test_host:test_port/" + "test_db_name" + ) assert test_instance.pool._opened is False assert test_instance.pool.min_size == 0 assert test_instance.pool.max_size == 1 @@ -25,21 +28,24 @@ def test_init(self, test_instance): def test_init_with_long_max_idle(self): with pytest.raises(PostgreSQLPoolClientError): PostgreSQLPoolClient( - 'test_host', 'test_port', 'test_db_name', 'test_user', - 'test_password', max_idle=300.0) + "test_host", + "test_port", + "test_db_name", + "test_user", + "test_password", + max_idle=300.0, + ) def test_connect(self, test_instance): test_instance.connect() - test_instance.pool.open.assert_called_once_with(wait=True, - timeout=300.0) + test_instance.pool.open.assert_called_once_with(wait=True, timeout=300.0) def test_connect_with_exception(self, mocker): - mocker.patch('psycopg_pool.ConnectionPool.open', - side_effect=Error()) + mocker.patch("psycopg_pool.ConnectionPool.open", side_effect=Error()) test_instance = PostgreSQLPoolClient( - 'test_host', 'test_port', 'test_db_name', 'test_user', - 'test_password') + "test_host", "test_port", "test_db_name", "test_user", "test_password" + ) with pytest.raises(PostgreSQLPoolClientError): test_instance.connect(timeout=1.0) @@ -48,18 +54,18 @@ def test_execute_read_query(self, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() - mock_cursor.description = 
[('description', None, None)] - mock_cursor.fetchall.return_value = [(1, 2, 3), ('a', 'b', 'c')] + mock_cursor.description = [("description", None, None)] + mock_cursor.fetchall.return_value = [(1, 2, 3), ("a", "b", "c")] mock_conn = mocker.MagicMock() mock_conn.execute.return_value = mock_cursor mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) + mocker.patch( + "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context + ) - assert test_instance.execute_query( - 'test query') == [(1, 2, 3), ('a', 'b', 'c')] - mock_conn.execute.assert_called_once_with('test query', None) + assert test_instance.execute_query("test query") == [(1, 2, 3), ("a", "b", "c")] + mock_conn.execute.assert_called_once_with("test query", None) mock_cursor.fetchall.assert_called_once() def test_execute_write_query(self, test_instance, mocker): @@ -71,11 +77,12 @@ def test_execute_write_query(self, test_instance, mocker): mock_conn.execute.return_value = mock_cursor mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) + mocker.patch( + "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context + ) - assert test_instance.execute_query('test query') is None - mock_conn.execute.assert_called_once_with('test query', None) + assert test_instance.execute_query("test query") is None + mock_conn.execute.assert_called_once_with("test query", None) def test_execute_write_query_with_params(self, test_instance, mocker): test_instance.connect() @@ -86,13 +93,15 @@ def test_execute_write_query_with_params(self, test_instance, mocker): mock_conn.execute.return_value = mock_cursor mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) + mocker.patch( + "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context + ) - assert test_instance.execute_query( - 'test query %s %s', query_params=('a', 1)) is None - mock_conn.execute.assert_called_once_with('test query %s %s', - ('a', 1)) + assert ( + test_instance.execute_query("test query %s %s", query_params=("a", 1)) + is None + ) + mock_conn.execute.assert_called_once_with("test query %s %s", ("a", 1)) def test_execute_query_with_exception(self, test_instance, mocker): test_instance.connect() @@ -101,11 +110,12 @@ def test_execute_query_with_exception(self, test_instance, mocker): mock_conn.execute.side_effect = Exception() mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch('psycopg_pool.ConnectionPool.connection', - return_value=mock_conn_context) + mocker.patch( + "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context + ) with pytest.raises(PostgreSQLPoolClientError): - test_instance.execute_query('test query') + test_instance.execute_query("test query") def test_close_pool(self, test_instance): test_instance.connect() @@ -116,6 +126,6 @@ def test_reopen_pool(self, test_instance, mocker): test_instance.connect() test_instance.close_pool() test_instance.connect() - test_instance.pool.open.assert_has_calls([ - mocker.call(wait=True, timeout=300), - mocker.call(wait=True, timeout=300)]) + test_instance.pool.open.assert_has_calls( + [mocker.call(wait=True, timeout=300), mocker.call(wait=True, 
timeout=300)] + ) diff --git a/tests/test_redshift_client.py b/tests/test_redshift_client.py index 7d6219d..d60b85a 100644 --- a/tests/test_redshift_client.py +++ b/tests/test_redshift_client.py @@ -1,55 +1,56 @@ import pytest -from nypl_py_utils.classes.redshift_client import ( - RedshiftClient, RedshiftClientError) +from nypl_py_utils.classes.redshift_client import RedshiftClient, RedshiftClientError class TestRedshiftClient: @pytest.fixture def mock_redshift_conn(self, mocker): - return mocker.patch('redshift_connector.connect') + return mocker.patch("redshift_connector.connect") @pytest.fixture def test_instance(self): - return RedshiftClient('test_host', 'test_database', 'test_user', - 'test_password') + return RedshiftClient( + "test_host", "test_database", "test_user", "test_password" + ) def test_connect(self, mock_redshift_conn, test_instance): test_instance.connect() - mock_redshift_conn.assert_called_once_with(host='test_host', - database='test_database', - user='test_user', - password='test_password', - sslmode='verify-full') + mock_redshift_conn.assert_called_once_with( + host="test_host", + database="test_database", + user="test_user", + password="test_password", + sslmode="verify-full", + ) def test_execute_query(self, mock_redshift_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() - mock_cursor.fetchall.return_value = [[1, 2, 3], ['a', 'b', 'c']] + mock_cursor.fetchall.return_value = [[1, 2, 3], ["a", "b", "c"]] test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query( - 'test query') == [[1, 2, 3], ['a', 'b', 'c']] - mock_cursor.execute.assert_called_once_with('test query') + assert test_instance.execute_query("test query") == [[1, 2, 3], ["a", "b", "c"]] + mock_cursor.execute.assert_called_once_with("test query") mock_cursor.fetchall.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_dataframe_query(self, mock_redshift_conn, test_instance, - mocker): + def test_execute_dataframe_query(self, mock_redshift_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() test_instance.conn.cursor.return_value = mock_cursor - test_instance.execute_query('test query', dataframe=True) - mock_cursor.execute.assert_called_once_with('test query') + test_instance.execute_query("test query", dataframe=True) + mock_cursor.execute.assert_called_once_with("test query") mock_cursor.fetch_dataframe.assert_called_once() mock_cursor.close.assert_called_once() def test_execute_query_with_exception( - self, mock_redshift_conn, test_instance, mocker): + self, mock_redshift_conn, test_instance, mocker + ): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -57,52 +58,66 @@ def test_execute_query_with_exception( test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(RedshiftClientError): - test_instance.execute_query('test query') + test_instance.execute_query("test query") test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_transaction(self, mock_redshift_conn, test_instance, - mocker): + def test_execute_transaction(self, mock_redshift_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() test_instance.conn.cursor.return_value = mock_cursor - test_instance.execute_transaction([('query 1', None), - ('query 2 %s %s', ('a', 1))]) - mock_cursor.execute.assert_has_calls([ - mocker.call('BEGIN TRANSACTION;'), - mocker.call('query 1', None), - mocker.call('query 2 %s 
%s', ('a', 1)), - mocker.call('END TRANSACTION;')]) + test_instance.execute_transaction( + [("query 1", None), ("query 2 %s %s", ("a", 1))] + ) + mock_cursor.execute.assert_has_calls( + [ + mocker.call("BEGIN TRANSACTION;"), + mocker.call("query 1", None), + mocker.call("query 2 %s %s", ("a", 1)), + mocker.call("END TRANSACTION;"), + ] + ) mock_cursor.executemany.assert_not_called() test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_transaction_with_many(self, mock_redshift_conn, - test_instance, mocker): + def test_execute_transaction_with_many( + self, mock_redshift_conn, test_instance, mocker + ): test_instance.connect() mock_cursor = mocker.MagicMock() test_instance.conn.cursor.return_value = mock_cursor - test_instance.execute_transaction([ - ('query 1', None), ('query 2 %s %s', (None, 1)), - ('query 3 %s %s', [(None, 10), ('b', 20)]), ('query 4', None)]) - mock_cursor.execute.assert_has_calls([ - mocker.call('BEGIN TRANSACTION;'), - mocker.call('query 1', None), - mocker.call('query 2 %s %s', (None, 1)), - mocker.call('query 4', None), - mocker.call('END TRANSACTION;')]) + test_instance.execute_transaction( + [ + ("query 1", None), + ("query 2 %s %s", (None, 1)), + ("query 3 %s %s", [(None, 10), ("b", 20)]), + ("query 4", None), + ] + ) + mock_cursor.execute.assert_has_calls( + [ + mocker.call("BEGIN TRANSACTION;"), + mocker.call("query 1", None), + mocker.call("query 2 %s %s", (None, 1)), + mocker.call("query 4", None), + mocker.call("END TRANSACTION;"), + ] + ) mock_cursor.executemany.assert_called_once_with( - 'query 3 %s %s', [(None, 10), ('b', 20)]) + "query 3 %s %s", [(None, 10), ("b", 20)] + ) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() def test_execute_transaction_with_exception( - self, mock_redshift_conn, test_instance, mocker): + self, mock_redshift_conn, test_instance, mocker + ): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -110,13 +125,15 @@ def test_execute_transaction_with_exception( test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(RedshiftClientError): - test_instance.execute_transaction( - [('query 1', None), ('query 2', None)]) - - mock_cursor.execute.assert_has_calls([ - mocker.call('BEGIN TRANSACTION;'), - mocker.call('query 1', None), - mocker.call('query 2', None)]) + test_instance.execute_transaction([("query 1", None), ("query 2", None)]) + + mock_cursor.execute.assert_has_calls( + [ + mocker.call("BEGIN TRANSACTION;"), + mocker.call("query 1", None), + mocker.call("query 2", None), + ] + ) test_instance.conn.commit.assert_not_called() test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() diff --git a/tests/test_research_catalog_identifier_helper.py b/tests/test_research_catalog_identifier_helper.py index bf7686f..fafbc88 100644 --- a/tests/test_research_catalog_identifier_helper.py +++ b/tests/test_research_catalog_identifier_helper.py @@ -2,31 +2,29 @@ import json from nypl_py_utils.functions.research_catalog_identifier_helper import ( - parse_research_catalog_identifier, research_catalog_id_prefix, - ResearchCatalogIdentifierError) + parse_research_catalog_identifier, + research_catalog_id_prefix, + ResearchCatalogIdentifierError, +) _TEST_MAPPING = { - 'sierra-nypl': { - 'organization': 'nyplOrg:0001', - 'bibPrefix': 'b', - 'holdingPrefix': 'h', - 'itemPrefix': 'i' - }, - 'recap-pul': { - 'organization': 'nyplOrg:0003', - 'bibPrefix': 'pb', - 'itemPrefix': 'pi' - }, - 'recap-cul': { - 
'organization': 'nyplOrg:0002', - 'bibPrefix': 'cb', - 'itemPrefix': 'ci' - }, - 'recap-hl': { - 'organization': 'nyplOrg:0004', - 'bibPrefix': 'hb', - 'itemPrefix': 'hi' - } + "sierra-nypl": { + "organization": "nyplOrg:0001", + "bibPrefix": "b", + "holdingPrefix": "h", + "itemPrefix": "i", + }, + "recap-pul": { + "organization": "nyplOrg:0003", + "bibPrefix": "pb", + "itemPrefix": "pi", + }, + "recap-cul": { + "organization": "nyplOrg:0002", + "bibPrefix": "cb", + "itemPrefix": "ci", + }, + "recap-hl": {"organization": "nyplOrg:0004", "bibPrefix": "hb", "itemPrefix": "hi"}, } @@ -34,39 +32,51 @@ class TestResearchCatalogIdentifierHelper: @pytest.fixture(autouse=True) def test_instance(self, requests_mock): requests_mock.get( - 'https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json', # noqa - text=json.dumps(_TEST_MAPPING)) + "https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json", # noqa + text=json.dumps(_TEST_MAPPING), + ) def test_parse_research_catalog_identifier_parses_valid(self): - assert parse_research_catalog_identifier('b1234') == \ - {'id': '1234', 'nyplSource': 'sierra-nypl', 'nyplType': 'bib'} - assert parse_research_catalog_identifier('cb1234') == \ - {'id': '1234', 'nyplSource': 'recap-cul', 'nyplType': 'bib'} - assert parse_research_catalog_identifier('pi1234') == \ - {'id': '1234', 'nyplSource': 'recap-pul', 'nyplType': 'item'} - assert parse_research_catalog_identifier('h1234') == \ - {'id': '1234', 'nyplSource': 'sierra-nypl', - 'nyplType': 'holding'} + assert parse_research_catalog_identifier("b1234") == { + "id": "1234", + "nyplSource": "sierra-nypl", + "nyplType": "bib", + } + assert parse_research_catalog_identifier("cb1234") == { + "id": "1234", + "nyplSource": "recap-cul", + "nyplType": "bib", + } + assert parse_research_catalog_identifier("pi1234") == { + "id": "1234", + "nyplSource": "recap-pul", + "nyplType": "item", + } + assert parse_research_catalog_identifier("h1234") == { + "id": "1234", + "nyplSource": "sierra-nypl", + "nyplType": "holding", + } def test_parse_research_catalog_identifier_fails_nonsense(self): - for invalidIdentifier in [None, 1234, 'z1234', '1234']: + for invalidIdentifier in [None, 1234, "z1234", "1234"]: with pytest.raises(ResearchCatalogIdentifierError): parse_research_catalog_identifier(invalidIdentifier) def test_research_catalog_id_prefix_parses_valid(self, mocker): - assert research_catalog_id_prefix('sierra-nypl') == 'b' - assert research_catalog_id_prefix('sierra-nypl', 'bib') == 'b' - assert research_catalog_id_prefix('sierra-nypl', 'item') == 'i' - assert research_catalog_id_prefix('sierra-nypl', 'holding') == 'h' - assert research_catalog_id_prefix('recap-pul', 'bib') == 'pb' - assert research_catalog_id_prefix('recap-hl', 'bib') == 'hb' - assert research_catalog_id_prefix('recap-hl', 'item') == 'hi' - assert research_catalog_id_prefix('recap-pul', 'item') == 'pi' + assert research_catalog_id_prefix("sierra-nypl") == "b" + assert research_catalog_id_prefix("sierra-nypl", "bib") == "b" + assert research_catalog_id_prefix("sierra-nypl", "item") == "i" + assert research_catalog_id_prefix("sierra-nypl", "holding") == "h" + assert research_catalog_id_prefix("recap-pul", "bib") == "pb" + assert research_catalog_id_prefix("recap-hl", "bib") == "hb" + assert research_catalog_id_prefix("recap-hl", "item") == "hi" + assert research_catalog_id_prefix("recap-pul", "item") == "pi" def test_research_catalog_id_prefix_fails_nonsense(self, 
mocker): - for invalidSource in ['sierra-cul', None, 'recap-nypl']: + for invalidSource in ["sierra-cul", None, "recap-nypl"]: with pytest.raises(ResearchCatalogIdentifierError): research_catalog_id_prefix(invalidSource) - for invalidType in [None, '...']: + for invalidType in [None, "..."]: with pytest.raises(ResearchCatalogIdentifierError): - research_catalog_id_prefix('sierra-nypl', invalidType) + research_catalog_id_prefix("sierra-nypl", invalidType) diff --git a/tests/test_s3_client.py b/tests/test_s3_client.py index bbb74e0..6c1a6e5 100644 --- a/tests/test_s3_client.py +++ b/tests/test_s3_client.py @@ -3,20 +3,20 @@ from nypl_py_utils.classes.s3_client import S3Client -_TEST_STATE = {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'} +_TEST_STATE = {"key1": "val1", "key2": "val2", "key3": "val3"} class TestS3Client: @pytest.fixture def test_instance(self, mocker): - mocker.patch('boto3.client') - return S3Client('test_s3_bucket', 'test_s3_resource') + mocker.patch("boto3.client") + return S3Client("test_s3_bucket", "test_s3_resource") def test_fetch_cache(self, test_instance): def mock_download(bucket, resource, stream): - assert bucket == 'test_s3_bucket' - assert resource == 'test_s3_resource' + assert bucket == "test_s3_bucket" + assert resource == "test_s3_resource" stream.write(json.dumps(_TEST_STATE).encode()) test_instance.s3_client.download_fileobj.side_effect = mock_download @@ -26,5 +26,5 @@ def test_set_cache(self, test_instance): test_instance.set_cache(_TEST_STATE) arguments = test_instance.s3_client.upload_fileobj.call_args.args assert arguments[0].getvalue() == json.dumps(_TEST_STATE).encode() - assert arguments[1] == 'test_s3_bucket' - assert arguments[2] == 'test_s3_resource' + assert arguments[1] == "test_s3_bucket" + assert arguments[2] == "test_s3_resource" From fe948d8a47bfe3c658693d22552a9887971feb94 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Tue, 2 Jul 2024 10:04:36 -0500 Subject: [PATCH 06/14] Revert "Reformatted with black" This reverts commit 69f12e5faff41be6c8e6e070a37e5682c66bcad4. 
--- CHANGELOG.md | 2 +- src/nypl_py_utils/classes/avro_client.py | 97 ++++----- src/nypl_py_utils/classes/kinesis_client.py | 61 +++--- src/nypl_py_utils/classes/kms_client.py | 26 +-- src/nypl_py_utils/classes/mysql_client.py | 38 ++-- .../classes/oauth2_api_client.py | 84 ++++---- .../classes/postgresql_client.py | 44 ++-- .../classes/postgresql_pool_client.py | 86 ++++---- src/nypl_py_utils/classes/redshift_client.py | 64 +++--- src/nypl_py_utils/classes/s3_client.py | 52 ++--- src/nypl_py_utils/functions/config_helper.py | 20 +- src/nypl_py_utils/functions/log_helper.py | 15 +- .../functions/obfuscation_helper.py | 14 +- .../research_catalog_identifier_helper.py | 65 +++--- tests/test_avro_client.py | 112 +++++----- tests/test_config_helper.py | 70 +++--- tests/test_kinesis_client.py | 135 +++++------- tests/test_kms_client.py | 20 +- tests/test_log_helper.py | 54 ++--- tests/test_mysql_client.py | 51 +++-- tests/test_oauth2_api_client.py | 199 ++++++++---------- tests/test_obfuscation_helper.py | 19 +- tests/test_postgresql_client.py | 45 ++-- tests/test_postgresql_pool_client.py | 86 ++++---- tests/test_redshift_client.py | 115 +++++----- ...test_research_catalog_identifier_helper.py | 102 ++++----- tests/test_s3_client.py | 14 +- 27 files changed, 748 insertions(+), 942 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c776943..f7b9675 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog ## v1.1.6 6/26/24 -- Generalized Avro functions and separated encoding/decoding behavior +- Generalized Avro functions and separated encoding/decoding behavior. ## v1.1.5 6/6/24 - Use executemany instead of execute when appropriate in RedshiftClient.execute_transaction diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py index 6aab3b4..ae2c988 100644 --- a/src/nypl_py_utils/classes/avro_client.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -9,53 +9,46 @@ from nypl_py_utils.functions.log_helper import create_log from requests.exceptions import JSONDecodeError, RequestException - class AvroClient: """ Base class for Avro schema interaction. Takes as input the Platform API endpoint from which to fetch the schema in JSON format. """ - + def __init__(self, platform_schema_url): - self.logger = create_log("avro_encoder") - self.schema = avro.schema.parse(self.get_json_schema(platform_schema_url)) - + self.logger = create_log('avro_encoder') + self.schema = avro.schema.parse( + self.get_json_schema(platform_schema_url)) + def get_json_schema(self, platform_schema_url): """ Fetches a JSON response from the input Platform API endpoint and interprets it as an Avro schema. 
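
        As a rough usage sketch (the schema URL below is a placeholder,
        not a documented endpoint), instantiating any subclass performs
        this fetch up front and surfaces failures as AvroClientError:

            from nypl_py_utils.classes.avro_client import (
                AvroClientError, AvroEncoder)

            try:
                encoder = AvroEncoder(
                    'https://platform.example.org/current-schemas/Item')
            except AvroClientError as e:
                print(e.message)  # fetch failed or schema was malformed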
""" - self.logger.info("Fetching Avro schema from {}".format(platform_schema_url)) + self.logger.info('Fetching Avro schema from {}'.format( + platform_schema_url)) try: response = requests.get(platform_schema_url) response.raise_for_status() except RequestException as e: self.logger.error( - "Failed to retrieve schema from {url}: {error}".format( - url=platform_schema_url, error=e - ) - ) + 'Failed to retrieve schema from {url}: {error}'.format( + url=platform_schema_url, error=e)) raise AvroClientError( - "Failed to retrieve schema from {url}: {error}".format( - url=platform_schema_url, error=e - ) - ) from None + 'Failed to retrieve schema from {url}: {error}'.format( + url=platform_schema_url, error=e)) from None try: json_response = response.json() - return json_response["data"]["schema"] + return json_response['data']['schema'] except (JSONDecodeError, KeyError) as e: self.logger.error( - "Retrieved schema is malformed: {errorType} {errorMessage}".format( - errorType=type(e), errorMessage=e - ) - ) + 'Retrieved schema is malformed: {errorType} {errorMessage}' + .format(errorType=type(e), errorMessage=e)) raise AvroClientError( - "Retrieved schema is malformed: {errorType} {errorMessage}".format( - errorType=type(e), errorMessage=e - ) - ) from None - + 'Retrieved schema is malformed: {errorType} {errorMessage}' + .format(errorType=type(e), errorMessage=e)) from None + class AvroEncoder(AvroClient): """ @@ -70,8 +63,8 @@ def encode_record(self, record): Returns the encoded record as a byte string. """ self.logger.debug( - "Encoding record using {schema} schema".format(schema=self.schema.name) - ) + 'Encoding record using {schema} schema'.format( + schema=self.schema.name)) datum_writer = DatumWriter(self.schema) with BytesIO() as output_stream: encoder = BinaryEncoder(output_stream) @@ -79,8 +72,9 @@ def encode_record(self, record): datum_writer.write(record, encoder) return output_stream.getvalue() except AvroException as e: - self.logger.error("Failed to encode record: {}".format(e)) - raise AvroClientError("Failed to encode record: {}".format(e)) from None + self.logger.error('Failed to encode record: {}'.format(e)) + raise AvroClientError( + 'Failed to encode record: {}'.format(e)) from None def encode_batch(self, record_list): """ @@ -89,10 +83,8 @@ def encode_batch(self, record_list): Returns a list of byte strings where each string is an encoded record. """ self.logger.info( - "Encoding ({num_rec}) records using {schema} schema".format( - num_rec=len(record_list), schema=self.schema.name - ) - ) + 'Encoding ({num_rec}) records using {schema} schema'.format( + num_rec=len(record_list), schema=self.schema.name)) encoded_records = [] datum_writer = DatumWriter(self.schema) with BytesIO() as output_stream: @@ -104,10 +96,9 @@ def encode_batch(self, record_list): output_stream.seek(0) output_stream.truncate(0) except AvroException as e: - self.logger.error("Failed to encode record: {}".format(e)) + self.logger.error('Failed to encode record: {}'.format(e)) raise AvroClientError( - "Failed to encode record: {}".format(e) - ) from None + 'Failed to encode record: {}'.format(e)) from None return encoded_records @@ -119,37 +110,34 @@ class AvroDecoder(AvroClient): def decode_record(self, record, encoding="binary"): """ - Decodes a single record represented either as a byte or + Decodes a single record represented either as a byte or base64 string, using the given Avro schema. Returns a dictionary where each key is a field in the schema. 
""" - self.logger.info( - "Decoding {rec} of type {type} using {schema} schema".format( - rec=record, type=encoding, schema=self.schema.name - ) - ) - + self.logger.info('Decoding {rec} of type {type} using {schema} schema'.format( + rec=record, type=encoding, schema=self.schema.name)) + if encoding == "base64": return self._decode_base64(record) elif encoding == "binary": return self._decode_binary(record) else: - self.logger.error( - "Failed to decode record due to encoding type: {}".format(encoding) - ) - raise AvroClientError("Invalid encoding type: {}".format(encoding)) - + self.logger.error('Failed to decode record due to encoding type: {}'.format(encoding)) + raise AvroClientError( + 'Invalid encoding type: {}'.format(encoding)) + def _decode_base64(self, record): - decoded_data = base64.b64decode(record) + decoded_data = base64.b64decode(record).decode("utf-8") try: return json.loads(decoded_data) except Exception as e: if isinstance(decoded_data, bytes): - return self._decode_binary(decoded_data) + self._decode_binary(decoded_data) else: - self.logger.error("Failed to decode record: {}".format(e)) - raise AvroClientError("Failed to decode record: {}".format(e)) from None + self.logger.error('Failed to decode record: {}'.format(e)) + raise AvroClientError( + 'Failed to decode record: {}'.format(e)) from None def _decode_binary(self, record): datum_reader = DatumReader(self.schema) @@ -158,10 +146,11 @@ def _decode_binary(self, record): try: return datum_reader.read(decoder) except Exception as e: - self.logger.error("Failed to decode record: {}".format(e)) - raise AvroClientError("Failed to decode record: {}".format(e)) from None + self.logger.error('Failed to decode record: {}'.format(e)) + raise AvroClientError( + 'Failed to decode record: {}'.format(e)) from None class AvroClientError(Exception): def __init__(self, message=None): - self.message = message + self.message = message \ No newline at end of file diff --git a/src/nypl_py_utils/classes/kinesis_client.py b/src/nypl_py_utils/classes/kinesis_client.py index ce28b53..1c25b2b 100644 --- a/src/nypl_py_utils/classes/kinesis_client.py +++ b/src/nypl_py_utils/classes/kinesis_client.py @@ -17,19 +17,20 @@ class KinesisClient: """ def __init__(self, stream_arn, batch_size, max_retries=5): - self.logger = create_log("kinesis_client") + self.logger = create_log('kinesis_client') self.stream_arn = stream_arn self.batch_size = batch_size self.max_retries = max_retries try: self.kinesis_client = boto3.client( - "kinesis", region_name=os.environ.get("AWS_REGION", "us-east-1") - ) + 'kinesis', region_name=os.environ.get('AWS_REGION', + 'us-east-1')) except ClientError as e: - self.logger.error("Could not create Kinesis client: {err}".format(err=e)) + self.logger.error( + 'Could not create Kinesis client: {err}'.format(err=e)) raise KinesisClientError( - "Could not create Kinesis client: {err}".format(err=e) + 'Could not create Kinesis client: {err}'.format(err=e) ) from None def close(self): @@ -44,11 +45,10 @@ def send_records(self, records): """ records_sent_since_pause = 0 for i in range(0, len(records), self.batch_size): - encoded_batch = records[i : i + self.batch_size] - kinesis_records = [ - {"Data": record, "PartitionKey": str(int(time.time() * 1000000000))} - for record in encoded_batch - ] + encoded_batch = records[i:i + self.batch_size] + kinesis_records = [{'Data': record, 'PartitionKey': + str(int(time.time() * 1000000000))} + for record in encoded_batch] if records_sent_since_pause + len(encoded_batch) > 1000: 
records_sent_since_pause = 0 @@ -63,41 +63,32 @@ def _send_kinesis_format_records(self, kinesis_records, call_count): """ if call_count > self.max_retries: self.logger.error( - "Failed to send records to Kinesis {} times in a row".format( - call_count - 1 - ) - ) + 'Failed to send records to Kinesis {} times in a row'.format( + call_count-1)) raise KinesisClientError( - "Failed to send records to Kinesis {} times in a row".format( - call_count - 1 - ) - ) from None + 'Failed to send records to Kinesis {} times in a row'.format( + call_count-1)) from None try: self.logger.info( - "Sending ({count}) records to {arn} Kinesis stream".format( - count=len(kinesis_records), arn=self.stream_arn - ) - ) + 'Sending ({count}) records to {arn} Kinesis stream'.format( + count=len(kinesis_records), arn=self.stream_arn)) response = self.kinesis_client.put_records( - Records=kinesis_records, StreamARN=self.stream_arn - ) - if response["FailedRecordCount"] > 0: + Records=kinesis_records, StreamARN=self.stream_arn) + if response['FailedRecordCount'] > 0: self.logger.warning( - "Failed to send {} records to Kinesis".format( - response["FailedRecordCount"] - ) - ) + 'Failed to send {} records to Kinesis'.format( + response['FailedRecordCount'])) failed_records = [] - for i in range(len(response["Records"])): - if "ErrorCode" in response["Records"][i]: + for i in range(len(response['Records'])): + if 'ErrorCode' in response['Records'][i]: failed_records.append(kinesis_records[i]) - self._send_kinesis_format_records(failed_records, call_count + 1) + self._send_kinesis_format_records(failed_records, call_count+1) except ClientError as e: - self.logger.error("Error sending records to Kinesis: {}".format(e)) + self.logger.error( + 'Error sending records to Kinesis: {}'.format(e)) raise KinesisClientError( - "Error sending records to Kinesis: {}".format(e) - ) from None + 'Error sending records to Kinesis: {}'.format(e)) from None class KinesisClientError(Exception): diff --git a/src/nypl_py_utils/classes/kms_client.py b/src/nypl_py_utils/classes/kms_client.py index abf1684..26ecdef 100644 --- a/src/nypl_py_utils/classes/kms_client.py +++ b/src/nypl_py_utils/classes/kms_client.py @@ -11,17 +11,16 @@ class KmsClient: """Client for interacting with a KMS client""" def __init__(self): - self.logger = create_log("kms_client") + self.logger = create_log('kms_client') try: self.kms_client = boto3.client( - "kms", region_name=os.environ.get("AWS_REGION", "us-east-1") - ) + 'kms', region_name=os.environ.get('AWS_REGION', 'us-east-1')) except ClientError as e: - self.logger.error("Could not create KMS client: {err}".format(err=e)) + self.logger.error( + 'Could not create KMS client: {err}'.format(err=e)) raise KmsClientError( - "Could not create KMS client: {err}".format(err=e) - ) from None + 'Could not create KMS client: {err}'.format(err=e)) from None def close(self): self.kms_client.close() @@ -31,19 +30,16 @@ def decrypt(self, encrypted_text): This method takes a base 64 KMS-encoded string and uses the KMS client to decrypt it into a usable string. 
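
        For example (the encrypted value is an assumed placeholder, and
        AWS credentials must already be configured in the environment):

            from nypl_py_utils.classes.kms_client import KmsClient

            kms_client = KmsClient()
            plaintext = kms_client.decrypt(encrypted_db_password)
            kms_client.close()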
""" - self.logger.debug("Decrypting '{}'".format(encrypted_text)) + self.logger.debug('Decrypting \'{}\''.format(encrypted_text)) try: decoded_text = b64decode(encrypted_text) return self.kms_client.decrypt(CiphertextBlob=decoded_text)[ - "Plaintext" - ].decode("utf-8") + 'Plaintext'].decode('utf-8') except (ClientError, base64Error, TypeError) as e: - self.logger.error( - "Could not decrypt '{val}': {err}".format(val=encrypted_text, err=e) - ) - raise KmsClientError( - "Could not decrypt '{val}': {err}".format(val=encrypted_text, err=e) - ) from None + self.logger.error('Could not decrypt \'{val}\': {err}'.format( + val=encrypted_text, err=e)) + raise KmsClientError('Could not decrypt \'{val}\': {err}'.format( + val=encrypted_text, err=e)) from None class KmsClientError(Exception): diff --git a/src/nypl_py_utils/classes/mysql_client.py b/src/nypl_py_utils/classes/mysql_client.py index 7828c24..94bb3c7 100644 --- a/src/nypl_py_utils/classes/mysql_client.py +++ b/src/nypl_py_utils/classes/mysql_client.py @@ -7,7 +7,7 @@ class MySQLClient: """Client for managing connections to a MySQL database""" def __init__(self, host, port, database, user, password): - self.logger = create_log("mysql_client") + self.logger = create_log('mysql_client') self.conn = None self.host = host self.port = port @@ -28,7 +28,7 @@ def connect(self, **kwargs): Whether to automatically commit each query rather than running them as part of a transaction. By default False. """ - self.logger.info("Connecting to {} database".format(self.database)) + self.logger.info('Connecting to {} database'.format(self.database)) try: self.conn = mysql.connector.connect( host=self.host, @@ -36,19 +36,14 @@ def connect(self, **kwargs): database=self.database, user=self.user, password=self.password, - **kwargs, - ) + **kwargs) except mysql.connector.Error as e: self.logger.error( - "Error connecting to {name} database: {error}".format( - name=self.database, error=e - ) - ) + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) raise MySQLClientError( - "Error connecting to {name} database: {error}".format( - name=self.database, error=e - ) - ) from None + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) from None def execute_query(self, query, query_params=None, **kwargs): """ @@ -76,8 +71,8 @@ def execute_query(self, query, query_params=None, **kwargs): or dictionaries (based on the dictionary input) if there's something to return (even if the result set is empty). 
""" - self.logger.info("Querying {} database".format(self.database)) - self.logger.debug("Executing query {}".format(query)) + self.logger.info('Querying {} database'.format(self.database)) + self.logger.debug('Executing query {}'.format(query)) try: cursor = self.conn.cursor(**kwargs) cursor.execute(query, query_params) @@ -89,21 +84,18 @@ def execute_query(self, query, query_params=None, **kwargs): except Exception as e: self.conn.rollback() self.logger.error( - ("Error executing {name} database query '{query}': {error}").format( - name=self.database, query=query, error=e - ) - ) + ('Error executing {name} database query \'{query}\': {error}') + .format(name=self.database, query=query, error=e)) raise MySQLClientError( - ("Error executing {name} database query '{query}': {error}").format( - name=self.database, query=query, error=e - ) - ) from None + ('Error executing {name} database query \'{query}\': {error}') + .format(name=self.database, query=query, error=e)) from None finally: cursor.close() def close_connection(self): """Closes the database connection""" - self.logger.debug("Closing {} database connection".format(self.database)) + self.logger.debug('Closing {} database connection'.format( + self.database)) self.conn.close() diff --git a/src/nypl_py_utils/classes/oauth2_api_client.py b/src/nypl_py_utils/classes/oauth2_api_client.py index 5e9761f..5a47efb 100644 --- a/src/nypl_py_utils/classes/oauth2_api_client.py +++ b/src/nypl_py_utils/classes/oauth2_api_client.py @@ -15,24 +15,20 @@ class Oauth2ApiClient: API where empty responses are returned intermittently. """ - def __init__( - self, - client_id=None, - client_secret=None, - base_url=None, - token_url=None, - with_retries=False, - ): - self.client_id = client_id or os.environ.get("NYPL_API_CLIENT_ID", None) - self.client_secret = client_secret or os.environ.get( - "NYPL_API_CLIENT_SECRET", None - ) - self.token_url = token_url or os.environ.get("NYPL_API_TOKEN_URL", None) - self.base_url = base_url or os.environ.get("NYPL_API_BASE_URL", None) + def __init__(self, client_id=None, client_secret=None, base_url=None, + token_url=None, with_retries=False): + self.client_id = client_id \ + or os.environ.get('NYPL_API_CLIENT_ID', None) + self.client_secret = client_secret \ + or os.environ.get('NYPL_API_CLIENT_SECRET', None) + self.token_url = token_url \ + or os.environ.get('NYPL_API_TOKEN_URL', None) + self.base_url = base_url \ + or os.environ.get('NYPL_API_BASE_URL', None) self.oauth_client = None - self.logger = create_log("oauth2_api_client") + self.logger = create_log('oauth2_api_client') self.with_retries = with_retries @@ -40,7 +36,7 @@ def get(self, request_path, **kwargs): """ Issue an HTTP GET on the given request_path """ - resp = self._do_http_method("GET", request_path, **kwargs) + resp = self._do_http_method('GET', request_path, **kwargs) # This try/except block is to handle one of at least two possible # Sierra server errors. 
One is an empty response, and another is a # response with a 200 status code but response text in HTML declaring @@ -50,28 +46,25 @@ def get(self, request_path, **kwargs): except Exception: # build default server error response resp = Response() - resp.message = "Oauth2 Client: Bad response from OauthClient" + resp.message = 'Oauth2 Client: Bad response from OauthClient' resp.status_code = 500 - self.logger.warning( - f"Get request using path {request_path} \ -returned response text:\n{resp.text}" - ) + self.logger.warning(f'Get request using path {request_path} \ +returned response text:\n{resp.text}') # if client has specified that we want to retry failed requests and # we haven't hit max retries if self.with_retries is True: - retries = kwargs.get("retries", 0) + 1 + retries = kwargs.get('retries', 0) + 1 if retries < 3: self.logger.warning( - f"Retrying get request due to empty response from\ -Oauth2 Client using path: {request_path}. Retry #{retries}" - ) + f'Retrying get request due to empty response from\ +Oauth2 Client using path: {request_path}. Retry #{retries}') sleep(pow(2, retries - 1)) - kwargs["retries"] = retries + kwargs['retries'] = retries # try request again resp = self.get(request_path, **kwargs) else: - resp.message = "Oauth2 Client: Request failed after 3 \ - empty responses received from Oauth2 Client" + resp.message = 'Oauth2 Client: Request failed after 3 \ + empty responses received from Oauth2 Client' # Return request. If retries returned real data, it will be here, # otherwise it will be the default 500 response generated earlier. return resp @@ -80,21 +73,21 @@ def post(self, request_path, json, **kwargs): """ Issue an HTTP POST on the given request_path with given JSON body """ - kwargs["json"] = json - return self._do_http_method("POST", request_path, **kwargs) + kwargs['json'] = json + return self._do_http_method('POST', request_path, **kwargs) def patch(self, request_path, json, **kwargs): """ Issue an HTTP PATCH on the given request_path with given JSON body """ - kwargs["json"] = json - return self._do_http_method("PATCH", request_path, **kwargs) + kwargs['json'] = json + return self._do_http_method('PATCH', request_path, **kwargs) def delete(self, request_path, **kwargs): """ Issue an HTTP DELETE on the given request_path """ - return self._do_http_method("DELETE", request_path, **kwargs) + return self._do_http_method('DELETE', request_path, **kwargs) def _do_http_method(self, method, request_path, **kwargs): """ @@ -103,26 +96,25 @@ def _do_http_method(self, method, request_path, **kwargs): if not self.oauth_client: self._create_oauth_client() - url = f"{self.base_url}/{request_path}" - self.logger.debug(f"{method} {url}") + url = f'{self.base_url}/{request_path}' + self.logger.debug(f'{method} {url}') try: # Build kwargs cleaned of local variables: - kwargs_cleaned = { - k: kwargs[k] for k in kwargs if not k.startswith("_do_http_method_") - } + kwargs_cleaned = {k: kwargs[k] for k in kwargs + if not k.startswith('_do_http_method_')} resp = self.oauth_client.request(method, url, **kwargs_cleaned) resp.raise_for_status() return resp except TokenExpiredError: - self.logger.debug("TokenExpiredError encountered") + self.logger.debug('TokenExpiredError encountered') # Raise error after 3 successive token refreshes - kwargs["_do_http_method_token_refreshes"] = ( - kwargs.get("_do_http_method_token_refreshes", 0) + 1 - ) - if kwargs["_do_http_method_token_refreshes"] > 3: - raise Oauth2ApiClientError("Exhausted token refreshes") from None + 
kwargs['_do_http_method_token_refreshes'] = \ + kwargs.get('_do_http_method_token_refreshes', 0) + 1 + if kwargs['_do_http_method_token_refreshes'] > 3: + raise Oauth2ApiClientError('Exhausted token refreshes') \ + from None self._generate_access_token() return self._do_http_method(method, request_path, **kwargs) @@ -139,11 +131,11 @@ def _generate_access_token(self): """ Fetch and store a fresh token """ - self.logger.debug(f"Refreshing token via @{self.token_url}") + self.logger.debug(f'Refreshing token via @{self.token_url}') self.oauth_client.fetch_token( token_url=self.token_url, client_id=self.client_id, - client_secret=self.client_secret, + client_secret=self.client_secret ) diff --git a/src/nypl_py_utils/classes/postgresql_client.py b/src/nypl_py_utils/classes/postgresql_client.py index 569a203..05c7a97 100644 --- a/src/nypl_py_utils/classes/postgresql_client.py +++ b/src/nypl_py_utils/classes/postgresql_client.py @@ -7,11 +7,12 @@ class PostgreSQLClient: """Client for managing individual connections to a PostgreSQL database""" def __init__(self, host, port, db_name, user, password): - self.logger = create_log("postgresql_client") + self.logger = create_log('postgresql_client') self.conn = None - self.conn_info = ( - "postgresql://{user}:{password}@{host}:{port}/" "{db_name}" - ).format(user=user, password=password, host=host, port=port, db_name=db_name) + self.conn_info = ('postgresql://{user}:{password}@{host}:{port}/' + '{db_name}').format(user=user, password=password, + host=host, port=port, + db_name=db_name) self.db_name = db_name @@ -32,20 +33,16 @@ def connect(self, **kwargs): returned. Defaults to tuple_row, which returns the rows as a list of tuples. """ - self.logger.info("Connecting to {} database".format(self.db_name)) + self.logger.info('Connecting to {} database'.format(self.db_name)) try: self.conn = psycopg.connect(self.conn_info, **kwargs) except psycopg.Error as e: self.logger.error( - "Error connecting to {name} database: {error}".format( - name=self.db_name, error=e - ) - ) + 'Error connecting to {name} database: {error}'.format( + name=self.db_name, error=e)) raise PostgreSQLClientError( - "Error connecting to {name} database: {error}".format( - name=self.db_name, error=e - ) - ) from None + 'Error connecting to {name} database: {error}'.format( + name=self.db_name, error=e)) from None def execute_query(self, query, query_params=None, **kwargs): """ @@ -68,8 +65,8 @@ def execute_query(self, query, query_params=None, **kwargs): based on the connection's row_factory if there's something to return (even if the result set is empty). 
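
        For instance (placeholder connection values; the %s parameters
        are bound by psycopg rather than interpolated into the string):

            client = PostgreSQLClient(
                'host', '5432', 'db', 'user', 'password')
            client.connect()
            client.execute_query(
                'INSERT INTO items VALUES (%s, %s)', query_params=(1, 'a'))
            client.close_connection()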
""" - self.logger.info("Querying {} database".format(self.db_name)) - self.logger.debug("Executing query {}".format(query)) + self.logger.info('Querying {} database'.format(self.db_name)) + self.logger.debug('Executing query {}'.format(query)) try: cursor = self.conn.cursor() cursor.execute(query, query_params, **kwargs) @@ -78,21 +75,20 @@ def execute_query(self, query, query_params=None, **kwargs): except Exception as e: self.conn.rollback() self.logger.error( - ("Error executing {name} database query '{query}': " "{error}").format( - name=self.db_name, query=query, error=e - ) - ) + ('Error executing {name} database query \'{query}\': ' + '{error}').format( + name=self.db_name, query=query, error=e)) raise PostgreSQLClientError( - ("Error executing {name} database query '{query}': " "{error}").format( - name=self.db_name, query=query, error=e - ) - ) from None + ('Error executing {name} database query \'{query}\': ' + '{error}').format( + name=self.db_name, query=query, error=e)) from None finally: cursor.close() def close_connection(self): """Closes the database connection""" - self.logger.debug("Closing {} database connection".format(self.db_name)) + self.logger.debug('Closing {} database connection'.format( + self.db_name)) self.conn.close() diff --git a/src/nypl_py_utils/classes/postgresql_pool_client.py b/src/nypl_py_utils/classes/postgresql_pool_client.py index 47a2ca9..beaf589 100644 --- a/src/nypl_py_utils/classes/postgresql_pool_client.py +++ b/src/nypl_py_utils/classes/postgresql_pool_client.py @@ -8,9 +8,8 @@ class PostgreSQLPoolClient: """Client for managing a connection pool to a PostgreSQL database""" - def __init__( - self, host, port, db_name, user, password, conn_timeout=300.0, **kwargs - ): + def __init__(self, host, port, db_name, user, password, conn_timeout=300.0, + **kwargs): """ Creates (but does not open) a connection pool. @@ -33,30 +32,25 @@ def __init__( min_size connections, which will stay open until manually closed. 
""" - self.logger = create_log("postgresql_client") - self.conn_info = ( - "postgresql://{user}:{password}@{host}:{port}/" "{db_name}" - ).format(user=user, password=password, host=host, port=port, db_name=db_name) + self.logger = create_log('postgresql_client') + self.conn_info = ('postgresql://{user}:{password}@{host}:{port}/' + '{db_name}').format(user=user, password=password, + host=host, port=port, + db_name=db_name) self.db_name = db_name self.kwargs = kwargs - self.kwargs["min_size"] = kwargs.get("min_size", 0) - self.kwargs["max_size"] = kwargs.get("max_size", 1) - self.kwargs["max_idle"] = kwargs.get("max_idle", 90.0) - - if self.kwargs["max_idle"] > 150.0: - self.logger.error( - ( - "max_idle is too high -- values over 150 seconds are unsafe " - "and may lead to connection leakages in ECS" - ) - ) - raise PostgreSQLPoolClientError( - ( - "max_idle is too high -- values over 150 seconds are unsafe " - "and may lead to connection leakages in ECS" - ) - ) from None + self.kwargs['min_size'] = kwargs.get('min_size', 0) + self.kwargs['max_size'] = kwargs.get('max_size', 1) + self.kwargs['max_idle'] = kwargs.get('max_idle', 90.0) + + if self.kwargs['max_idle'] > 150.0: + self.logger.error(( + 'max_idle is too high -- values over 150 seconds are unsafe ' + 'and may lead to connection leakages in ECS')) + raise PostgreSQLPoolClientError(( + 'max_idle is too high -- values over 150 seconds are unsafe ' + 'and may lead to connection leakages in ECS')) from None self.pool = ConnectionPool(self.conn_info, open=False, **self.kwargs) @@ -71,24 +65,22 @@ def connect(self, timeout=300.0): The number of seconds to try connecting before throwing an error. Defaults to 300 seconds. """ - self.logger.info("Connecting to {} database".format(self.db_name)) + self.logger.info('Connecting to {} database'.format(self.db_name)) try: if self.pool is None: - self.pool = ConnectionPool(self.conn_info, open=False, **self.kwargs) + self.pool = ConnectionPool( + self.conn_info, open=False, **self.kwargs) self.pool.open(wait=True, timeout=timeout) except psycopg.Error as e: self.logger.error( - "Error connecting to {name} database: {error}".format( - name=self.db_name, error=e - ) - ) + 'Error connecting to {name} database: {error}'.format( + name=self.db_name, error=e)) raise PostgreSQLPoolClientError( - "Error connecting to {name} database: {error}".format( - name=self.db_name, error=e - ) - ) from None + 'Error connecting to {name} database: {error}'.format( + name=self.db_name, error=e)) from None - def execute_query(self, query, query_params=None, row_factory=tuple_row, **kwargs): + def execute_query(self, query, query_params=None, row_factory=tuple_row, + **kwargs): """ Requests a connection from the pool and uses it to execute an arbitrary query. After the query is complete, either commits it or rolls it back, @@ -114,28 +106,28 @@ def execute_query(self, query, query_params=None, row_factory=tuple_row, **kwarg based on the row_factory input if there's something to return (even if the result set is empty). 
""" - self.logger.info("Querying {} database".format(self.db_name)) - self.logger.debug("Executing query {}".format(query)) + self.logger.info('Querying {} database'.format(self.db_name)) + self.logger.debug('Executing query {}'.format(query)) with self.pool.connection() as conn: try: conn.row_factory = row_factory cursor = conn.execute(query, query_params, **kwargs) - return None if cursor.description is None else cursor.fetchall() + return (None if cursor.description is None + else cursor.fetchall()) except Exception as e: self.logger.error( - ( - "Error executing {name} database query '{query}': " "{error}" - ).format(name=self.db_name, query=query, error=e) - ) + ('Error executing {name} database query \'{query}\': ' + '{error}').format( + name=self.db_name, query=query, error=e)) raise PostgreSQLPoolClientError( - ( - "Error executing {name} database query '{query}': " "{error}" - ).format(name=self.db_name, query=query, error=e) - ) from None + ('Error executing {name} database query \'{query}\': ' + '{error}').format( + name=self.db_name, query=query, error=e)) from None def close_pool(self): """Closes the connection pool""" - self.logger.debug("Closing {} database connection pool".format(self.db_name)) + self.logger.debug('Closing {} database connection pool'.format( + self.db_name)) self.pool.close() self.pool = None diff --git a/src/nypl_py_utils/classes/redshift_client.py b/src/nypl_py_utils/classes/redshift_client.py index 2fc3ad9..17c4558 100644 --- a/src/nypl_py_utils/classes/redshift_client.py +++ b/src/nypl_py_utils/classes/redshift_client.py @@ -8,7 +8,7 @@ class RedshiftClient: """Client for managing connections to Redshift""" def __init__(self, host, database, user, password): - self.logger = create_log("redshift_client") + self.logger = create_log('redshift_client') self.conn = None self.host = host self.database = database @@ -17,26 +17,21 @@ def __init__(self, host, database, user, password): def connect(self): """Connects to a Redshift database using the given credentials""" - self.logger.info("Connecting to {} database".format(self.database)) + self.logger.info('Connecting to {} database'.format(self.database)) try: self.conn = redshift_connector.connect( host=self.host, database=self.database, user=self.user, password=self.password, - sslmode="verify-full", - ) + sslmode='verify-full') except ClientError as e: self.logger.error( - "Error connecting to {name} database: {error}".format( - name=self.database, error=e - ) - ) + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) raise RedshiftClientError( - "Error connecting to {name} database: {error}".format( - name=self.database, error=e - ) - ) from None + 'Error connecting to {name} database: {error}'.format( + name=self.database, error=e)) from None def execute_query(self, query, dataframe=False): """ @@ -56,8 +51,8 @@ def execute_query(self, query, dataframe=False): A list of tuples or a pandas DataFrame (based on the `dataframe` input) """ - self.logger.info("Querying {} database".format(self.database)) - self.logger.debug("Executing query {}".format(query)) + self.logger.info('Querying {} database'.format(self.database)) + self.logger.debug('Executing query {}'.format(query)) try: cursor = self.conn.cursor() cursor.execute(query) @@ -68,15 +63,11 @@ def execute_query(self, query, dataframe=False): except Exception as e: self.conn.rollback() self.logger.error( - ("Error executing {name} database query '{query}': {error}").format( - name=self.database, query=query, error=e - ) 
- ) + ('Error executing {name} database query \'{query}\': {error}') + .format(name=self.database, query=query, error=e)) raise RedshiftClientError( - ("Error executing {name} database query '{query}': {error}").format( - name=self.database, query=query, error=e - ) - ) from None + ('Error executing {name} database query \'{query}\': {error}') + .format(name=self.database, query=query, error=e)) from None finally: cursor.close() @@ -97,40 +88,37 @@ def execute_transaction(self, queries): "INSERT INTO x VALUES (%s, %s)", [(1, "a"), (2, "b")]) """ - self.logger.info( - "Executing transaction against {} database".format(self.database) - ) + self.logger.info('Executing transaction against {} database'.format( + self.database)) try: cursor = self.conn.cursor() - cursor.execute("BEGIN TRANSACTION;") + cursor.execute('BEGIN TRANSACTION;') for query in queries: - self.logger.debug("Executing query {}".format(query)) + self.logger.debug('Executing query {}'.format(query)) if query[1] is not None and all( - isinstance(el, tuple) or isinstance(el, list) for el in query[1] + isinstance(el, tuple) or isinstance(el, list) + for el in query[1] ): cursor.executemany(query[0], query[1]) else: cursor.execute(query[0], query[1]) - cursor.execute("END TRANSACTION;") + cursor.execute('END TRANSACTION;') self.conn.commit() except Exception as e: self.conn.rollback() self.logger.error( - ("Error executing {name} database transaction: {error}").format( - name=self.database, error=e - ) - ) + ('Error executing {name} database transaction: {error}') + .format(name=self.database, error=e)) raise RedshiftClientError( - ("Error executing {name} database transaction: {error}").format( - name=self.database, error=e - ) - ) from None + ('Error executing {name} database transaction: {error}') + .format(name=self.database, error=e)) from None finally: cursor.close() def close_connection(self): """Closes the database connection""" - self.logger.debug("Closing {} database connection".format(self.database)) + self.logger.debug('Closing {} database connection'.format( + self.database)) self.conn.close() diff --git a/src/nypl_py_utils/classes/s3_client.py b/src/nypl_py_utils/classes/s3_client.py index 4536624..af71531 100644 --- a/src/nypl_py_utils/classes/s3_client.py +++ b/src/nypl_py_utils/classes/s3_client.py @@ -15,66 +15,56 @@ class S3Client: """ def __init__(self, bucket, resource): - self.logger = create_log("s3_client") + self.logger = create_log('s3_client') self.bucket = bucket self.resource = resource try: self.s3_client = boto3.client( - "s3", region_name=os.environ.get("AWS_REGION", "us-east-1") - ) + 's3', region_name=os.environ.get('AWS_REGION', 'us-east-1')) except ClientError as e: - self.logger.error("Could not create S3 client: {err}".format(err=e)) + self.logger.error( + 'Could not create S3 client: {err}'.format(err=e)) raise S3ClientError( - "Could not create S3 client: {err}".format(err=e) - ) from None + 'Could not create S3 client: {err}'.format(err=e)) from None def close(self): self.s3_client.close() def fetch_cache(self): """Fetches a JSON file from S3 and returns the resulting dictionary""" - self.logger.info( - "Fetching {file} from S3 bucket {bucket}".format( - file=self.resource, bucket=self.bucket - ) - ) + self.logger.info('Fetching {file} from S3 bucket {bucket}'.format( + file=self.resource, bucket=self.bucket)) try: output_stream = BytesIO() - self.s3_client.download_fileobj(self.bucket, self.resource, output_stream) + self.s3_client.download_fileobj( + self.bucket, self.resource, 
output_stream) return json.loads(output_stream.getvalue()) except ClientError as e: self.logger.error( - "Error retrieving {file} from S3 bucket {bucket}: {error}".format( - file=self.resource, bucket=self.bucket, error=e - ) - ) + 'Error retrieving {file} from S3 bucket {bucket}: {error}' + .format(file=self.resource, bucket=self.bucket, error=e)) raise S3ClientError( - "Error retrieving {file} from S3 bucket {bucket}: {error}".format( - file=self.resource, bucket=self.bucket, error=e - ) + 'Error retrieving {file} from S3 bucket {bucket}: {error}' + .format(file=self.resource, bucket=self.bucket, error=e) ) from None def set_cache(self, state): """Writes a dictionary to JSON and uploads the resulting file to S3""" self.logger.info( - "Setting {file} in S3 bucket {bucket} to {state}".format( - file=self.resource, bucket=self.bucket, state=state - ) - ) + 'Setting {file} in S3 bucket {bucket} to {state}'.format( + file=self.resource, bucket=self.bucket, state=state)) try: input_stream = BytesIO(json.dumps(state).encode()) - self.s3_client.upload_fileobj(input_stream, self.bucket, self.resource) + self.s3_client.upload_fileobj( + input_stream, self.bucket, self.resource) except ClientError as e: self.logger.error( - "Error uploading {file} to S3 bucket {bucket}: {error}".format( - file=self.resource, bucket=self.bucket, error=e - ) - ) + 'Error uploading {file} to S3 bucket {bucket}: {error}' + .format(file=self.resource, bucket=self.bucket, error=e)) raise S3ClientError( - "Error uploading {file} to S3 bucket {bucket}: {error}".format( - file=self.resource, bucket=self.bucket, error=e - ) + 'Error uploading {file} to S3 bucket {bucket}: {error}' + .format(file=self.resource, bucket=self.bucket, error=e) ) from None diff --git a/src/nypl_py_utils/functions/config_helper.py b/src/nypl_py_utils/functions/config_helper.py index c0192e5..7edb5ea 100644 --- a/src/nypl_py_utils/functions/config_helper.py +++ b/src/nypl_py_utils/functions/config_helper.py @@ -5,7 +5,7 @@ from nypl_py_utils.classes.kms_client import KmsClient from nypl_py_utils.functions.log_helper import create_log -logger = create_log("config_helper") +logger = create_log('config_helper') def load_env_file(run_type, file_string): @@ -30,31 +30,29 @@ def load_env_file(run_type, file_string): env_dict = None open_file = file_string.format(run_type) - logger.info("Loading env file {}".format(open_file)) + logger.info('Loading env file {}'.format(open_file)) try: - with open(open_file, "r") as env_stream: + with open(open_file, 'r') as env_stream: try: env_dict = yaml.safe_load(env_stream) except yaml.YAMLError: - logger.error("Invalid YAML file: {}".format(open_file)) + logger.error('Invalid YAML file: {}'.format(open_file)) raise ConfigHelperError( - "Invalid YAML file: {}".format(open_file) - ) from None + 'Invalid YAML file: {}'.format(open_file)) from None except FileNotFoundError: - logger.error("Could not find config file {}".format(open_file)) + logger.error('Could not find config file {}'.format(open_file)) raise ConfigHelperError( - "Could not find config file {}".format(open_file) - ) from None + 'Could not find config file {}'.format(open_file)) from None if env_dict: - for key, value in env_dict.get("PLAINTEXT_VARIABLES", {}).items(): + for key, value in env_dict.get('PLAINTEXT_VARIABLES', {}).items(): if type(value) is list: os.environ[key] = json.dumps(value) else: os.environ[key] = str(value) kms_client = KmsClient() - for key, value in env_dict.get("ENCRYPTED_VARIABLES", {}).items(): + for key, value in 
env_dict.get('ENCRYPTED_VARIABLES', {}).items(): if type(value) is list: decrypted_list = [kms_client.decrypt(v) for v in value] os.environ[key] = json.dumps(decrypted_list) diff --git a/src/nypl_py_utils/functions/log_helper.py b/src/nypl_py_utils/functions/log_helper.py index 7eb7b83..7d7bf78 100644 --- a/src/nypl_py_utils/functions/log_helper.py +++ b/src/nypl_py_utils/functions/log_helper.py @@ -3,11 +3,11 @@ import sys levels = { - "debug": logging.DEBUG, - "info": logging.INFO, - "warning": logging.WARNING, - "error": logging.ERROR, - "critical": logging.CRITICAL, + 'debug': logging.DEBUG, + 'info': logging.INFO, + 'warning': logging.WARNING, + 'error': logging.ERROR, + 'critical': logging.CRITICAL } @@ -18,12 +18,13 @@ def create_log(module): console_log = logging.StreamHandler(stream=sys.stdout) - log_level = os.environ.get("LOG_LEVEL", "info").lower() + log_level = os.environ.get('LOG_LEVEL', 'info').lower() logger.setLevel(levels[log_level]) console_log.setLevel(levels[log_level]) - formatter = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s: %(message)s") + formatter = logging.Formatter( + '%(asctime)s | %(name)s | %(levelname)s: %(message)s') console_log.setFormatter(formatter) logger.addHandler(console_log) diff --git a/src/nypl_py_utils/functions/obfuscation_helper.py b/src/nypl_py_utils/functions/obfuscation_helper.py index e3cc4ae..4209f86 100644 --- a/src/nypl_py_utils/functions/obfuscation_helper.py +++ b/src/nypl_py_utils/functions/obfuscation_helper.py @@ -3,7 +3,7 @@ from nypl_py_utils.functions.log_helper import create_log -logger = create_log("obfuscation_helper") +logger = create_log('obfuscation_helper') def obfuscate(input): @@ -16,11 +16,11 @@ def obfuscate(input): but is converted to a string before being obfuscated. The obfuscation salt is read from the `BCRYPT_SALT` environment variable. """ - logger.debug("Obfuscating input '{}' with environment salt".format(input)) - hash = bcrypt.hashpw( - str(input).encode(), os.environ["BCRYPT_SALT"].encode() - ).decode() - return hash.split(os.environ["BCRYPT_SALT"])[-1] + logger.debug('Obfuscating input \'{}\' with environment salt'.format( + input)) + hash = bcrypt.hashpw(str(input).encode(), + os.environ['BCRYPT_SALT'].encode()).decode() + return hash.split(os.environ['BCRYPT_SALT'])[-1] def obfuscate_with_salt(input, salt): @@ -28,6 +28,6 @@ def obfuscate_with_salt(input, salt): This method is the same as `obfuscate` above but takes the obfuscation salt as a string input. 
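
    Example (editor's sketch; the salt and expected output are taken from
    this patch's own test suite):
        obfuscate_with_salt('test_input', '$2b$12$iuSSdD6F/nJ1GSXzesM8sO')
        # -> 'SUXLCHnsRVt4Vj1PyP9KPEqADxtUj5.'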
""" - logger.debug("Obfuscating input '{}' with custom salt".format(input)) + logger.debug('Obfuscating input \'{}\' with custom salt'.format(input)) hash = bcrypt.hashpw(str(input).encode(), salt.encode()).decode() return hash.split(salt)[-1] diff --git a/src/nypl_py_utils/functions/research_catalog_identifier_helper.py b/src/nypl_py_utils/functions/research_catalog_identifier_helper.py index 00afc63..4079faf 100644 --- a/src/nypl_py_utils/functions/research_catalog_identifier_helper.py +++ b/src/nypl_py_utils/functions/research_catalog_identifier_helper.py @@ -16,51 +16,60 @@ def parse_research_catalog_identifier(identifier: str): - id: The numeric string id """ if not isinstance(identifier, str): - raise ResearchCatalogIdentifierError(f"Invalid RC identifier: {identifier}") + raise ResearchCatalogIdentifierError( + f'Invalid RC identifier: {identifier}') # Extract prefix from the identifier: - match = re.match(r"^([a-z]+)", identifier) + match = re.match(r'^([a-z]+)', identifier) if match is None: - raise ResearchCatalogIdentifierError(f"Invalid RC identifier: {identifier}") + raise ResearchCatalogIdentifierError( + f'Invalid RC identifier: {identifier}') prefix = match[0] # The id is the identifier without the prefix: - id = identifier.replace(prefix, "") + id = identifier.replace(prefix, '') nyplType = None nyplSource = None # Look up nyplType and nyplSource in nypl-core based on the prefix: for _nyplSource, mapping in nypl_core_source_mapping().items(): - if mapping.get("bibPrefix") == prefix: - nyplType = "bib" - elif mapping.get("itemPrefix") == prefix: - nyplType = "item" - elif mapping.get("holdingPrefix") == prefix: - nyplType = "holding" + if mapping.get('bibPrefix') == prefix: + nyplType = 'bib' + elif mapping.get('itemPrefix') == prefix: + nyplType = 'item' + elif mapping.get('holdingPrefix') == prefix: + nyplType = 'holding' if nyplType is not None: nyplSource = _nyplSource break if nyplSource is None: - raise ResearchCatalogIdentifierError(f"Invalid RC identifier: {identifier}") + raise ResearchCatalogIdentifierError( + f'Invalid RC identifier: {identifier}') - return {"nyplSource": nyplSource, "nyplType": nyplType, "id": id} + return { + 'nyplSource': nyplSource, + 'nyplType': nyplType, + 'id': id + } -def research_catalog_id_prefix(nyplSource: str, nyplType="bib"): +def research_catalog_id_prefix(nyplSource: str, nyplType='bib'): """ Given a nyplSource (e.g. 'sierra-nypl') and nyplType (e.g. 'item'), returns the relevant prefix used in the RC identifier (e.g. 
'i') """ if nypl_core_source_mapping().get(nyplSource) is None: - raise ResearchCatalogIdentifierError(f"Invalid nyplSource: {nyplSource}") + raise ResearchCatalogIdentifierError( + f'Invalid nyplSource: {nyplSource}') if not isinstance(nyplType, str): - raise ResearchCatalogIdentifierError(f"Invalid nyplType: {nyplType}") + raise ResearchCatalogIdentifierError( + f'Invalid nyplType: {nyplType}') - prefixKey = f"{nyplType}Prefix" + prefixKey = f'{nyplType}Prefix' if nypl_core_source_mapping()[nyplSource].get(prefixKey) is None: - raise ResearchCatalogIdentifierError(f"Invalid nyplType: {nyplType}") + raise ResearchCatalogIdentifierError(f'Invalid nyplType: {nyplType}') return nypl_core_source_mapping()[nyplSource][prefixKey] @@ -69,33 +78,29 @@ def nypl_core_source_mapping(): """ Builds a nypl-source-mapping by retrieving the mapping from NYPL-Core """ - name = "nypl-core-source-mapping" + name = 'nypl-core-source-mapping' if not CACHE.get(name) is None: return CACHE[name] - url = os.environ.get( - "NYPL_CORE_SOURCE_MAPPING_URL", - "https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json", - ) # noqa + url = os.environ.get('NYPL_CORE_SOURCE_MAPPING_URL', + 'https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json') # noqa try: response = requests.get(url) response.raise_for_status() except RequestException as e: raise ResearchCatalogIdentifierError( - "Failed to retrieve nypl-core source-mapping file from {url}:" - " {errorType} {errorMessage}".format( - url=url, errorType=type(e), errorMessage=e - ) - ) from None + 'Failed to retrieve nypl-core source-mapping file from {url}:' + ' {errorType} {errorMessage}' + .format(url=url, errorType=type(e), errorMessage=e)) from None try: CACHE[name] = response.json() return CACHE[name] except (JSONDecodeError, KeyError) as e: raise ResearchCatalogIdentifierError( - "Failed to parse nypl-core source-mapping file: {errorType}" - " {errorMessage}".format(errorType=type(e), errorMessage=e) - ) from None + 'Failed to parse nypl-core source-mapping file: {errorType}' + ' {errorMessage}' + .format(errorType=type(e), errorMessage=e)) from None class ResearchCatalogIdentifierError(Exception): diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py index 7a4946d..7e26981 100644 --- a/tests/test_avro_client.py +++ b/tests/test_avro_client.py @@ -4,113 +4,99 @@ from nypl_py_utils.classes.avro_client import AvroDecoder, AvroEncoder, AvroClientError from requests.exceptions import ConnectTimeout -_TEST_SCHEMA = { - "data": { - "schema": json.dumps( - { - "name": "TestSchema", - "type": "record", - "fields": [ - {"name": "patron_id", "type": "int"}, - {"name": "library_branch", "type": ["null", "string"]}, - ], - } - ) - } -} +_TEST_SCHEMA = {'data': {'schema': json.dumps({ + 'name': 'TestSchema', + 'type': 'record', + 'fields': [ + { + 'name': 'patron_id', + 'type': 'int' + }, + { + 'name': 'library_branch', + 'type': ['null', 'string'] + } + ] +})}} class TestAvroClient: @pytest.fixture def test_avro_encoder_instance(self, requests_mock): - requests_mock.get("https://test_schema_url", text=json.dumps(_TEST_SCHEMA)) - return AvroEncoder("https://test_schema_url") + requests_mock.get( + 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + return AvroEncoder('https://test_schema_url') @pytest.fixture def test_avro_decoder_instance(self, requests_mock): - requests_mock.get("https://test_schema_url", text=json.dumps(_TEST_SCHEMA)) - # 
requests_mock.get( - # 'https://test_schema_url', text=json.dumps(_CIRC_TRANS_SCHEMA)) - return AvroDecoder("https://test_schema_url") + requests_mock.get( + 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + return AvroDecoder('https://test_schema_url') - def test_get_json_schema( - self, test_avro_encoder_instance, test_avro_decoder_instance - ): - assert test_avro_encoder_instance.schema == _TEST_SCHEMA["data"]["schema"] - assert test_avro_decoder_instance.schema == _TEST_SCHEMA["data"]["schema"] + def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance): + assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data']['schema'] + assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data']['schema'] def test_request_error(self, requests_mock): - requests_mock.get("https://test_schema_url", exc=ConnectTimeout) + requests_mock.get('https://test_schema_url', exc=ConnectTimeout) with pytest.raises(AvroClientError): - AvroEncoder("https://test_schema_url") + AvroEncoder('https://test_schema_url') def test_bad_json_error(self, requests_mock): - requests_mock.get("https://test_schema_url", text="bad json") + requests_mock.get( + 'https://test_schema_url', text='bad json') with pytest.raises(AvroClientError): - AvroEncoder("https://test_schema_url") + AvroEncoder('https://test_schema_url') def test_missing_key_error(self, requests_mock): requests_mock.get( - "https://test_schema_url", text=json.dumps({"field": "value"}) - ) + 'https://test_schema_url', text=json.dumps({'field': 'value'})) with pytest.raises(AvroClientError): - AvroEncoder("https://test_schema_url") + AvroEncoder('https://test_schema_url') - def test_encode_record( - self, test_avro_encoder_instance, test_avro_decoder_instance - ): - TEST_RECORD = {"patron_id": 123, "library_branch": "aa"} + def test_encode_record(self, test_avro_encoder_instance, test_avro_decoder_instance): + TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'} encoded_record = test_avro_encoder_instance.encode_record(TEST_RECORD) assert type(encoded_record) is bytes assert test_avro_decoder_instance.decode_record(encoded_record) == TEST_RECORD def test_encode_record_error(self, test_avro_encoder_instance): - TEST_RECORD = {"patron_id": 123, "bad_field": "bad"} + TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'} with pytest.raises(AvroClientError): test_avro_encoder_instance.encode_record(TEST_RECORD) def test_encode_batch(self, test_avro_encoder_instance, test_avro_decoder_instance): TEST_BATCH = [ - {"patron_id": 123, "library_branch": "aa"}, - {"patron_id": 456, "library_branch": None}, - {"patron_id": 789, "library_branch": "bb"}, - ] + {'patron_id': 123, 'library_branch': 'aa'}, + {'patron_id': 456, 'library_branch': None}, + {'patron_id': 789, 'library_branch': 'bb'}] encoded_records = test_avro_encoder_instance.encode_batch(TEST_BATCH) assert len(encoded_records) == len(TEST_BATCH) for i in range(3): assert type(encoded_records[i]) is bytes - assert ( - test_avro_decoder_instance.decode_record(encoded_records[i]) - == TEST_BATCH[i] - ) + assert test_avro_decoder_instance.decode_record( + encoded_records[i]) == TEST_BATCH[i] def test_encode_batch_error(self, test_avro_encoder_instance): BAD_BATCH = [ - {"patron_id": 123, "library_branch": "aa"}, - {"patron_id": 456, "bad_field": "bad"}, - ] + {'patron_id': 123, 'library_branch': 'aa'}, + {'patron_id': 456, 'bad_field': 'bad'}] with pytest.raises(AvroClientError): test_avro_encoder_instance.encode_batch(BAD_BATCH) def test_decode_record_binary(self, 
test_avro_decoder_instance): TEST_DECODED_RECORD = {"patron_id": 123, "library_branch": "aa"} - TEST_ENCODED_RECORD = b"\xf6\x01\x02\x04aa" - assert ( - test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) - == TEST_DECODED_RECORD - ) - + TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' + assert test_avro_decoder_instance.decode_record( + TEST_ENCODED_RECORD) == TEST_DECODED_RECORD + def test_decode_record_b64(self, test_avro_decoder_instance): TEST_DECODED_RECORD = {"patron_id'": 123, "library_branch": "aa"} - TEST_ENCODED_RECORD = ( - "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==" - ) - assert ( - test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD, "base64") - == TEST_DECODED_RECORD - ) + TEST_ENCODED_RECORD = "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==" + assert test_avro_decoder_instance.decode_record( + TEST_ENCODED_RECORD, "base64") == TEST_DECODED_RECORD def test_decode_record_error(self, test_avro_decoder_instance): - TEST_ENCODED_RECORD = b"bad-encoding" + TEST_ENCODED_RECORD = b'bad-encoding' with pytest.raises(AvroClientError): - test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) \ No newline at end of file + test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) diff --git a/tests/test_config_helper.py b/tests/test_config_helper.py index 9c3b149..feadfe5 100644 --- a/tests/test_config_helper.py +++ b/tests/test_config_helper.py @@ -1,18 +1,15 @@ import os import pytest -from nypl_py_utils.functions.config_helper import load_env_file, ConfigHelperError +from nypl_py_utils.functions.config_helper import ( + load_env_file, ConfigHelperError) _TEST_VARIABLE_NAMES = [ - "TEST_STRING", - "TEST_INT", - "TEST_LIST", - "TEST_ENCRYPTED_VARIABLE_1", - "TEST_ENCRYPTED_VARIABLE_2", - "TEST_ENCRYPTED_LIST", -] + 'TEST_STRING', 'TEST_INT', 'TEST_LIST', 'TEST_ENCRYPTED_VARIABLE_1', + 'TEST_ENCRYPTED_VARIABLE_2', 'TEST_ENCRYPTED_LIST'] -_TEST_CONFIG_CONTENTS = """--- +_TEST_CONFIG_CONTENTS = \ + '''--- PLAINTEXT_VARIABLES: TEST_STRING: string-variable TEST_INT: 1 @@ -25,7 +22,7 @@ TEST_ENCRYPTED_LIST: - test-encryption-3 - test-encryption-4 -...""" +...''' class TestConfigHelper: @@ -33,42 +30,30 @@ class TestConfigHelper: def test_load_env_file(self, mocker): mock_kms_client = mocker.MagicMock() mock_kms_client.decrypt.side_effect = [ - "test-decryption-1", - "test-decryption-2", - "test-decryption-3", - "test-decryption-4", - ] - mocker.patch( - "nypl_py_utils.functions.config_helper.KmsClient", - return_value=mock_kms_client, - ) + 'test-decryption-1', 'test-decryption-2', 'test-decryption-3', + 'test-decryption-4'] + mocker.patch('nypl_py_utils.functions.config_helper.KmsClient', + return_value=mock_kms_client) mock_file_open = mocker.patch( - "builtins.open", mocker.mock_open(read_data=_TEST_CONFIG_CONTENTS) - ) + 'builtins.open', mocker.mock_open(read_data=_TEST_CONFIG_CONTENTS)) for key in _TEST_VARIABLE_NAMES: assert key not in os.environ - load_env_file("test-env", "test-path/{}.yaml") + load_env_file('test-env', 'test-path/{}.yaml') - mock_file_open.assert_called_once_with("test-path/test-env.yaml", "r") - mock_kms_client.decrypt.assert_has_calls( - [ - mocker.call("test-encryption-1"), - mocker.call("test-encryption-2"), - mocker.call("test-encryption-3"), - mocker.call("test-encryption-4"), - ] + mock_file_open.assert_called_once_with('test-path/test-env.yaml', 'r') + mock_kms_client.decrypt.assert_has_calls([ + mocker.call('test-encryption-1'), mocker.call('test-encryption-2'), + mocker.call('test-encryption-3'), 
mocker.call('test-encryption-4')] ) mock_kms_client.close.assert_called_once() - assert os.environ["TEST_STRING"] == "string-variable" - assert os.environ["TEST_INT"] == "1" - assert os.environ["TEST_LIST"] == '["string-var", 2]' - assert os.environ["TEST_ENCRYPTED_VARIABLE_1"] == "test-decryption-1" - assert os.environ["TEST_ENCRYPTED_VARIABLE_2"] == "test-decryption-2" - assert ( - os.environ["TEST_ENCRYPTED_LIST"] - == '["test-decryption-3", "test-decryption-4"]' - ) + assert os.environ['TEST_STRING'] == 'string-variable' + assert os.environ['TEST_INT'] == '1' + assert os.environ['TEST_LIST'] == '["string-var", 2]' + assert os.environ['TEST_ENCRYPTED_VARIABLE_1'] == 'test-decryption-1' + assert os.environ['TEST_ENCRYPTED_VARIABLE_2'] == 'test-decryption-2' + assert os.environ['TEST_ENCRYPTED_LIST'] == \ + '["test-decryption-3", "test-decryption-4"]' for key in _TEST_VARIABLE_NAMES: if key in os.environ: @@ -76,9 +61,10 @@ def test_load_env_file(self, mocker): def test_missing_file_error(self): with pytest.raises(ConfigHelperError): - load_env_file("bad-env", "bad-path/{}.yaml") + load_env_file('bad-env', 'bad-path/{}.yaml') def test_bad_yaml(self, mocker): - mocker.patch("builtins.open", mocker.mock_open(read_data="bad yaml: [")) + mocker.patch( + 'builtins.open', mocker.mock_open(read_data='bad yaml: [')) with pytest.raises(ConfigHelperError): - load_env_file("test-env", "test-path/{}.not_yaml") + load_env_file('test-env', 'test-path/{}.not_yaml') diff --git a/tests/test_kinesis_client.py b/tests/test_kinesis_client.py index be43d55..820de77 100644 --- a/tests/test_kinesis_client.py +++ b/tests/test_kinesis_client.py @@ -1,119 +1,96 @@ import pytest from freezegun import freeze_time -from nypl_py_utils.classes.kinesis_client import KinesisClient, KinesisClientError +from nypl_py_utils.classes.kinesis_client import ( + KinesisClient, KinesisClientError) -_TEST_DATETIME_KEY = "1672531200000000000" +_TEST_DATETIME_KEY = '1672531200000000000' _TEST_KINESIS_RECORDS = [ - {"Data": b"a", "PartitionKey": _TEST_DATETIME_KEY}, - {"Data": b"b", "PartitionKey": _TEST_DATETIME_KEY}, - {"Data": b"c", "PartitionKey": _TEST_DATETIME_KEY}, - {"Data": b"d", "PartitionKey": _TEST_DATETIME_KEY}, - {"Data": b"e", "PartitionKey": _TEST_DATETIME_KEY}, + {'Data': b'a', 'PartitionKey': _TEST_DATETIME_KEY}, + {'Data': b'b', 'PartitionKey': _TEST_DATETIME_KEY}, + {'Data': b'c', 'PartitionKey': _TEST_DATETIME_KEY}, + {'Data': b'd', 'PartitionKey': _TEST_DATETIME_KEY}, + {'Data': b'e', 'PartitionKey': _TEST_DATETIME_KEY} ] -@freeze_time("2023-01-01") +@freeze_time('2023-01-01') class TestKinesisClient: @pytest.fixture def test_instance(self, mocker): - mocker.patch("boto3.client") - return KinesisClient("test_stream_arn", 2) + mocker.patch('boto3.client') + return KinesisClient('test_stream_arn', 2) def test_send_records(self, test_instance, mocker): - MOCK_RECORDS = [b"a", b"b", b"c", b"d", b"e"] + MOCK_RECORDS = [b'a', b'b', b'c', b'd', b'e'] mocked_send_method = mocker.patch( - "nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records" - ) # noqa: E501 - mock_sleep = mocker.patch("time.sleep", return_value=None) + 'nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records') # noqa: E501 + mock_sleep = mocker.patch('time.sleep', return_value=None) test_instance.send_records(MOCK_RECORDS) - mocked_send_method.assert_has_calls( - [ - mocker.call([_TEST_KINESIS_RECORDS[0], _TEST_KINESIS_RECORDS[1]], 1), - mocker.call([_TEST_KINESIS_RECORDS[2], _TEST_KINESIS_RECORDS[3]], 
1), - mocker.call([_TEST_KINESIS_RECORDS[4]], 1), - ] - ) + mocked_send_method.assert_has_calls([ + mocker.call([_TEST_KINESIS_RECORDS[0], + _TEST_KINESIS_RECORDS[1]], 1), + mocker.call([_TEST_KINESIS_RECORDS[2], + _TEST_KINESIS_RECORDS[3]], 1), + mocker.call([_TEST_KINESIS_RECORDS[4]], 1)]) mock_sleep.assert_not_called() def test_send_records_with_pause(self, mocker): - mocker.patch("boto3.client") - test_instance = KinesisClient("test_stream_arn", 500) + mocker.patch('boto3.client') + test_instance = KinesisClient('test_stream_arn', 500) - MOCK_RECORDS = [b"a"] * 2200 + MOCK_RECORDS = [b'a'] * 2200 mocked_send_method = mocker.patch( - "nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records" - ) # noqa: E501 - mock_sleep = mocker.patch("time.sleep", return_value=None) + 'nypl_py_utils.classes.kinesis_client.KinesisClient._send_kinesis_format_records') # noqa: E501 + mock_sleep = mocker.patch('time.sleep', return_value=None) test_instance.send_records(MOCK_RECORDS) - mocked_send_method.assert_has_calls( - [ - mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]] * 500, 1), - mocker.call([_TEST_KINESIS_RECORDS[0]] * 200, 1), - ] - ) + mocked_send_method.assert_has_calls([ + mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]]*500, 1), + mocker.call([_TEST_KINESIS_RECORDS[0]]*200, 1)]) assert mock_sleep.call_count == 2 def test_send_kinesis_format_records(self, test_instance): - test_instance.kinesis_client.put_records.return_value = {"FailedRecordCount": 0} + test_instance.kinesis_client.put_records.return_value = { + 'FailedRecordCount': 0} test_instance._send_kinesis_format_records(_TEST_KINESIS_RECORDS, 1) test_instance.kinesis_client.put_records.assert_called_once_with( - Records=_TEST_KINESIS_RECORDS, StreamARN="test_stream_arn" - ) + Records=_TEST_KINESIS_RECORDS, StreamARN='test_stream_arn') - def test_send_kinesis_format_records_with_failures(self, test_instance, mocker): + def test_send_kinesis_format_records_with_failures( + self, test_instance, mocker): test_instance.kinesis_client.put_records.side_effect = [ - { - "FailedRecordCount": 2, - "Records": [ - "record0", - {"ErrorCode": 1}, - "record2", - {"ErrorCode": 3}, - "record4", - ], - }, - {"FailedRecordCount": 0}, - ] + {'FailedRecordCount': 2, 'Records': [ + 'record0', {'ErrorCode': 1}, + 'record2', {'ErrorCode': 3}, + 'record4']}, + {'FailedRecordCount': 0}] test_instance._send_kinesis_format_records(_TEST_KINESIS_RECORDS, 1) - test_instance.kinesis_client.put_records.assert_has_calls( - [ - mocker.call(Records=_TEST_KINESIS_RECORDS, StreamARN="test_stream_arn"), - mocker.call( - Records=[_TEST_KINESIS_RECORDS[1], _TEST_KINESIS_RECORDS[3]], - StreamARN="test_stream_arn", - ), - ] - ) + test_instance.kinesis_client.put_records.assert_has_calls([ + mocker.call(Records=_TEST_KINESIS_RECORDS, + StreamARN='test_stream_arn'), + mocker.call(Records=[_TEST_KINESIS_RECORDS[1], + _TEST_KINESIS_RECORDS[3]], + StreamARN='test_stream_arn')]) def test_send_kinesis_format_records_with_repeating_failures( - self, test_instance, mocker - ): + self, test_instance, mocker): test_instance.kinesis_client.put_records.side_effect = [ - { - "FailedRecordCount": 5, - "Records": [ - {"ErrorCode": 0}, - {"ErrorCode": 1}, - {"ErrorCode": 2}, - 
{"ErrorCode": 3}, - {"ErrorCode": 4}, - ], - } - ] * 5 + {'FailedRecordCount': 5, 'Records': [ + {'ErrorCode': 0}, {'ErrorCode': 1}, {'ErrorCode': 2}, + {'ErrorCode': 3}, {'ErrorCode': 4}]}] * 5 with pytest.raises(KinesisClientError): - test_instance._send_kinesis_format_records(_TEST_KINESIS_RECORDS, 1) + test_instance._send_kinesis_format_records( + _TEST_KINESIS_RECORDS, 1) - test_instance.kinesis_client.put_records.assert_has_calls( - [mocker.call(Records=_TEST_KINESIS_RECORDS, StreamARN="test_stream_arn")] - * 5 - ) + test_instance.kinesis_client.put_records.assert_has_calls([ + mocker.call(Records=_TEST_KINESIS_RECORDS, + StreamARN='test_stream_arn')] * 5) diff --git a/tests/test_kms_client.py b/tests/test_kms_client.py index fcc10a3..e500b03 100644 --- a/tests/test_kms_client.py +++ b/tests/test_kms_client.py @@ -3,12 +3,12 @@ from base64 import b64encode from nypl_py_utils.classes.kms_client import KmsClient, KmsClientError -_TEST_ENCRYPTED_VALUE = b64encode(b"test-encrypted-value") +_TEST_ENCRYPTED_VALUE = b64encode(b'test-encrypted-value') _TEST_DECRYPTION = { - "KeyId": "test-key-id", - "Plaintext": b"test-decrypted-value", - "EncryptionAlgorithm": "test-encryption-algorithm", - "ResponseMetadata": {}, + 'KeyId': 'test-key-id', + 'Plaintext': b'test-decrypted-value', + 'EncryptionAlgorithm': 'test-encryption-algorithm', + 'ResponseMetadata': {} } @@ -16,16 +16,16 @@ class TestKmsClient: @pytest.fixture def test_instance(self, mocker): - mocker.patch("boto3.client") + mocker.patch('boto3.client') return KmsClient() def test_decrypt(self, test_instance): test_instance.kms_client.decrypt.return_value = _TEST_DECRYPTION assert test_instance.kms_client.decrypt.called_once_with( - CiphertextBlob=b"test-encrypted-value" - ) - assert test_instance.decrypt(_TEST_ENCRYPTED_VALUE) == "test-decrypted-value" + CiphertextBlob=b'test-encrypted-value') + assert test_instance.decrypt( + _TEST_ENCRYPTED_VALUE) == 'test-decrypted-value' def test_base64_error(self, test_instance): with pytest.raises(KmsClientError): - test_instance.decrypt("bad-b64") + test_instance.decrypt('bad-b64') diff --git a/tests/test_log_helper.py b/tests/test_log_helper.py index 77c46b3..cf7f616 100644 --- a/tests/test_log_helper.py +++ b/tests/test_log_helper.py @@ -6,64 +6,56 @@ from nypl_py_utils.functions.log_helper import create_log -@freeze_time("2023-01-01 19:00:00") +@freeze_time('2023-01-01 19:00:00') class TestLogHelper: def test_default_logging(self, caplog): - logger = create_log("test_log") + logger = create_log('test_log') assert logger.getEffectiveLevel() == logging.INFO assert len(logger.handlers) == 1 - logger.info("Test info message") + logger.info('Test info message') # freeze_time changes the utc time, while the logger uses local time by # default, so force the logger to use utc time logger.handlers[0].formatter.converter = time.gmtime assert len(caplog.records) == 1 - assert ( - logger.handlers[0].format(caplog.records[0]) - == "2023-01-01 19:00:00,000 | test_log | INFO: Test info message" - ) + assert logger.handlers[0].format(caplog.records[0]) == \ + '2023-01-01 19:00:00,000 | test_log | INFO: Test info message' def test_logging_with_custom_log_level(self, caplog): - os.environ["LOG_LEVEL"] = "error" - logger = create_log("test_log") + os.environ['LOG_LEVEL'] = 'error' + logger = create_log('test_log') assert logger.getEffectiveLevel() == logging.ERROR - logger.info("Test info message") - logger.error("Test error message") + logger.info('Test info message') + logger.error('Test error message') 
assert len(caplog.records) == 1 # freeze_time changes the utc time, while the logger uses local time by # default, so force the logger to use utc time logger.handlers[0].formatter.converter = time.gmtime - assert ( - logger.handlers[0].format(caplog.records[0]) - == "2023-01-01 19:00:00,000 | test_log | ERROR: Test error message" - ) - del os.environ["LOG_LEVEL"] + assert logger.handlers[0].format(caplog.records[0]) == \ + '2023-01-01 19:00:00,000 | test_log | ERROR: Test error message' + del os.environ['LOG_LEVEL'] def test_logging_no_duplicates(self, caplog): - logger = create_log("test_log") - logger.info("Test info message") + logger = create_log('test_log') + logger.info('Test info message') # Test that logger uses the most recently set log level and doesn't # duplicate handlers/messages when create_log is called more than once. - os.environ["LOG_LEVEL"] = "error" - logger = create_log("test_log") + os.environ['LOG_LEVEL'] = 'error' + logger = create_log('test_log') assert logger.getEffectiveLevel() == logging.ERROR assert len(logger.handlers) == 1 - logger.info("Test info message 2") - logger.error("Test error message") + logger.info('Test info message 2') + logger.error('Test error message') assert len(caplog.records) == 2 # freeze_time changes the utc time, while the logger uses local time by # default, so force the logger to use utc time logger.handlers[0].formatter.converter = time.gmtime - assert ( - logger.handlers[0].format(caplog.records[0]) - == "2023-01-01 19:00:00,000 | test_log | INFO: Test info message" - ) - assert ( - logger.handlers[0].format(caplog.records[1]) - == "2023-01-01 19:00:00,000 | test_log | ERROR: Test error message" - ) - del os.environ["LOG_LEVEL"] + assert logger.handlers[0].format(caplog.records[0]) == \ + '2023-01-01 19:00:00,000 | test_log | INFO: Test info message' + assert logger.handlers[0].format(caplog.records[1]) == \ + '2023-01-01 19:00:00,000 | test_log | ERROR: Test error message' + del os.environ['LOG_LEVEL'] diff --git a/tests/test_mysql_client.py b/tests/test_mysql_client.py index b793006..11bb94f 100644 --- a/tests/test_mysql_client.py +++ b/tests/test_mysql_client.py @@ -7,34 +7,32 @@ class TestMySQLClient: @pytest.fixture def mock_mysql_conn(self, mocker): - return mocker.patch("mysql.connector.connect") + return mocker.patch('mysql.connector.connect') @pytest.fixture def test_instance(self): - return MySQLClient( - "test_host", "test_port", "test_database", "test_user", "test_password" - ) + return MySQLClient('test_host', 'test_port', 'test_database', + 'test_user', 'test_password') def test_connect(self, mock_mysql_conn, test_instance): test_instance.connect() - mock_mysql_conn.assert_called_once_with( - host="test_host", - port="test_port", - database="test_database", - user="test_user", - password="test_password", - ) + mock_mysql_conn.assert_called_once_with(host='test_host', + port='test_port', + database='test_database', + user='test_user', + password='test_password') def test_execute_read_query(self, mock_mysql_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() - mock_cursor.description = [("description", None, None)] - mock_cursor.fetchall.return_value = [(1, 2, 3), ("a", "b", "c")] + mock_cursor.description = [('description', None, None)] + mock_cursor.fetchall.return_value = [(1, 2, 3), ('a', 'b', 'c')] test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query("test query") == [(1, 2, 3), ("a", "b", "c")] - mock_cursor.execute.assert_called_once_with("test 
query", None) + assert test_instance.execute_query( + 'test query') == [(1, 2, 3), ('a', 'b', 'c')] + mock_cursor.execute.assert_called_once_with('test query', None) test_instance.conn.commit.assert_not_called() mock_cursor.close.assert_called_once() @@ -45,29 +43,28 @@ def test_execute_write_query(self, mock_mysql_conn, test_instance, mocker): mock_cursor.description = None test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query("test query") is None - mock_cursor.execute.assert_called_once_with("test query", None) + assert test_instance.execute_query('test query') is None + mock_cursor.execute.assert_called_once_with('test query', None) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_write_query_with_params( - self, mock_mysql_conn, test_instance, mocker - ): + def test_execute_write_query_with_params(self, mock_mysql_conn, + test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() mock_cursor.description = None test_instance.conn.cursor.return_value = mock_cursor - assert ( - test_instance.execute_query("test query %s %s", query_params=("a", 1)) - is None - ) - mock_cursor.execute.assert_called_once_with("test query %s %s", ("a", 1)) + assert test_instance.execute_query( + 'test query %s %s', query_params=('a', 1)) is None + mock_cursor.execute.assert_called_once_with('test query %s %s', + ('a', 1)) test_instance.conn.commit.called_once() mock_cursor.close.assert_called_once() - def test_execute_query_with_exception(self, mock_mysql_conn, test_instance, mocker): + def test_execute_query_with_exception( + self, mock_mysql_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -75,7 +72,7 @@ def test_execute_query_with_exception(self, mock_mysql_conn, test_instance, mock test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(MySQLClientError): - test_instance.execute_query("test query") + test_instance.execute_query('test query') test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() diff --git a/tests/test_oauth2_api_client.py b/tests/test_oauth2_api_client.py index b6377d9..b5fcd5c 100644 --- a/tests/test_oauth2_api_client.py +++ b/tests/test_oauth2_api_client.py @@ -5,21 +5,19 @@ from requests_oauthlib import OAuth2Session from requests import HTTPError, JSONDecodeError, Response -from nypl_py_utils.classes.oauth2_api_client import ( - Oauth2ApiClient, - Oauth2ApiClientError, -) +from nypl_py_utils.classes.oauth2_api_client import (Oauth2ApiClient, + Oauth2ApiClientError) _TOKEN_RESPONSE = { - "access_token": "super-secret-token", - "expires_in": 1, - "token_type": "Bearer", - "scope": ["offline_access", "openid", "login:staff", "admin"], - "id_token": "super-secret-token", + 'access_token': 'super-secret-token', + 'expires_in': 1, + 'token_type': 'Bearer', + 'scope': ['offline_access', 'openid', 'login:staff', 'admin'], + 'id_token': 'super-secret-token' } -BASE_URL = "https://example.com/api/v0.1" -TOKEN_URL = "https://oauth.example.com/oauth/token" +BASE_URL = 'https://example.com/api/v0.1' +TOKEN_URL = 'https://oauth.example.com/oauth/token' class MockEmptyResponse: @@ -32,7 +30,7 @@ def json(self): if self.empty: raise JSONDecodeError else: - return "success" + return 'success' class TestOauth2ApiClient: @@ -45,112 +43,102 @@ def token_server_post(self, requests_mock): @pytest.fixture def test_instance(self, requests_mock): - return Oauth2ApiClient( - base_url=BASE_URL, - 
token_url=TOKEN_URL, - client_id="clientid", - client_secret="clientsecret", - ) + return Oauth2ApiClient(base_url=BASE_URL, + token_url=TOKEN_URL, + client_id='clientid', + client_secret='clientsecret' + ) @pytest.fixture def test_instance_with_retries(self, requests_mock): - return Oauth2ApiClient( - base_url=BASE_URL, - token_url=TOKEN_URL, - client_id="clientid", - client_secret="clientsecret", - with_retries=True, - ) + return Oauth2ApiClient(base_url=BASE_URL, + token_url=TOKEN_URL, + client_id='clientid', + client_secret='clientsecret', + with_retries=True + ) def test_uses_env_vars(self): env = { - "NYPL_API_CLIENT_ID": "env client id", - "NYPL_API_CLIENT_SECRET": "env client secret", - "NYPL_API_TOKEN_URL": "env token url", - "NYPL_API_BASE_URL": "env base url", + 'NYPL_API_CLIENT_ID': 'env client id', + 'NYPL_API_CLIENT_SECRET': 'env client secret', + 'NYPL_API_TOKEN_URL': 'env token url', + 'NYPL_API_BASE_URL': 'env base url' } for key, value in env.items(): os.environ[key] = value client = Oauth2ApiClient() - assert client.client_id == "env client id" - assert client.client_secret == "env client secret" - assert client.token_url == "env token url" - assert client.base_url == "env base url" + assert client.client_id == 'env client id' + assert client.client_secret == 'env client secret' + assert client.token_url == 'env token url' + assert client.base_url == 'env base url' for key, value in env.items(): - os.environ[key] = "" + os.environ[key] = '' def test_generate_access_token(self, test_instance, token_server_post): test_instance._create_oauth_client() test_instance._generate_access_token() - assert ( - test_instance.oauth_client.token["access_token"] - == _TOKEN_RESPONSE["access_token"] - ) + assert test_instance.oauth_client.token['access_token']\ + == _TOKEN_RESPONSE['access_token'] def test_create_oauth_client(self, token_server_post, test_instance): test_instance._create_oauth_client() assert type(test_instance.oauth_client) is OAuth2Session - def test_do_http_method(self, requests_mock, token_server_post, test_instance): - requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) + def test_do_http_method(self, requests_mock, token_server_post, + test_instance): + requests_mock.get(f'{BASE_URL}/foo', json={'foo': 'bar'}) - requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) - resp = test_instance._do_http_method("GET", "foo") + requests_mock.get(f'{BASE_URL}/foo', json={'foo': 'bar'}) + resp = test_instance._do_http_method('GET', 'foo') assert resp.status_code == 200 - assert resp.json() == {"foo": "bar"} + assert resp.json() == {'foo': 'bar'} - def test_token_expiration( - self, requests_mock, test_instance, token_server_post, mocker - ): - api_get_mock = requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) + def test_token_expiration(self, requests_mock, test_instance, + token_server_post, mocker): + api_get_mock = requests_mock.get(f'{BASE_URL}/foo', + json={'foo': 'bar'}) # Perform first request: - test_instance._do_http_method("GET", "foo") + test_instance._do_http_method('GET', 'foo') # Expect this first call triggered a single token server call: assert len(token_server_post.request_history) == 1 # And the GET request used the supplied Bearer token: - assert ( - api_get_mock.request_history[0]._request.headers["Authorization"] - == "Bearer super-secret-token" - ) + assert api_get_mock.request_history[0]._request\ + .headers['Authorization'] == 'Bearer super-secret-token' # The token obtained above expires in 1s, so wait out expiration: time.sleep(1.1) # 
Register new token response: second_token_response = dict(_TOKEN_RESPONSE) - second_token_response["id_token"] = "super-secret-second-token" - second_token_response["access_token"] = "super-secret-second-token" - second_token_server_post = requests_mock.post( - TOKEN_URL, text=json.dumps(second_token_response) - ) + second_token_response['id_token'] = 'super-secret-second-token' + second_token_response['access_token'] = 'super-secret-second-token' + second_token_server_post = requests_mock\ + .post(TOKEN_URL, text=json.dumps(second_token_response)) # Perform second request: - response = test_instance._do_http_method("GET", "foo") + response = test_instance._do_http_method('GET', 'foo') # Ensure we still return a plain requests Response object assert isinstance(response, Response) assert response.json() == {"foo": "bar"} # Expect a call on the second token server: assert len(second_token_server_post.request_history) == 1 # Expect the second GET request to carry the new Bearer token: - assert ( - api_get_mock.request_history[1]._request.headers["Authorization"] - == "Bearer super-secret-second-token" - ) + assert api_get_mock.request_history[1]._request\ + .headers['Authorization'] == 'Bearer super-secret-second-token' - def test_error_status_raises_error( - self, requests_mock, test_instance, token_server_post - ): - requests_mock.get(f"{BASE_URL}/foo", status_code=400) + def test_error_status_raises_error(self, requests_mock, test_instance, + token_server_post): + requests_mock.get(f'{BASE_URL}/foo', status_code=400) with pytest.raises(HTTPError): - test_instance._do_http_method("GET", "foo") + test_instance._do_http_method('GET', 'foo') def test_token_refresh_failure_raises_error( - self, requests_mock, test_instance, token_server_post - ): + self, requests_mock, test_instance, token_server_post): """ Failure to fetch a token can raise a number of errors including: - requests.exceptions.HTTPError for invalid access_token @@ -162,57 +150,46 @@ def test_token_refresh_failure_raises_error( a new valid token in response to token expiration. This test asserts that the client will not allow more than successive 3 retries. 
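
        (Editor's note: i.e. no more than 3 successive retries -- one
        initial token fetch plus three refresh attempts, four token-server
        calls in total, as asserted below.)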
""" - requests_mock.get(f"{BASE_URL}/foo", json={"foo": "bar"}) + requests_mock.get(f'{BASE_URL}/foo', json={'foo': 'bar'}) token_response = dict(_TOKEN_RESPONSE) - token_response["expires_in"] = 0 - token_server_post = requests_mock.post( - TOKEN_URL, text=json.dumps(token_response) - ) + token_response['expires_in'] = 0 + token_server_post = requests_mock\ + .post(TOKEN_URL, text=json.dumps(token_response)) with pytest.raises(Oauth2ApiClientError): - test_instance._do_http_method("GET", "foo") + test_instance._do_http_method('GET', 'foo') # Expect 1 initial token fetch, plus 3 retries: assert len(token_server_post.request_history) == 4 - def test_bad_response_no_retries(self, requests_mock, test_instance, mocker): - mocker.patch.object( - test_instance, "_do_http_method", return_value=MockEmptyResponse(empty=True) - ) - get_spy = mocker.spy(test_instance, "get") - resp = test_instance.get("spaghetti") + def test_bad_response_no_retries(self, requests_mock, test_instance, + mocker): + mocker.patch.object(test_instance, '_do_http_method', + return_value=MockEmptyResponse(empty=True)) + get_spy = mocker.spy(test_instance, 'get') + resp = test_instance.get('spaghetti') assert get_spy.call_count == 1 assert resp.status_code == 500 - assert resp.message == "Oauth2 Client: Bad response from OauthClient" - - def test_http_retry_fail(self, requests_mock, test_instance_with_retries, mocker): - mocker.patch.object( - test_instance_with_retries, - "_do_http_method", - return_value=MockEmptyResponse(empty=True), - ) - get_spy = mocker.spy(test_instance_with_retries, "get") - resp = test_instance_with_retries.get("spaghetti") + assert resp.message == 'Oauth2 Client: Bad response from OauthClient' + + def test_http_retry_fail(self, requests_mock, test_instance_with_retries, + mocker): + mocker.patch.object(test_instance_with_retries, '_do_http_method', + return_value=MockEmptyResponse(empty=True)) + get_spy = mocker.spy(test_instance_with_retries, 'get') + resp = test_instance_with_retries.get('spaghetti') assert get_spy.call_count == 3 assert resp.status_code == 500 - assert ( - resp.message - == "Oauth2 Client: Request failed after 3 \ - empty responses received from Oauth2 Client" - ) - - def test_http_retry_success( - self, requests_mock, test_instance_with_retries, mocker - ): - mocker.patch.object( - test_instance_with_retries, - "_do_http_method", - side_effect=[ - MockEmptyResponse(empty=True), - MockEmptyResponse(empty=False, status_code=200), - ], - ) - get_spy = mocker.spy(test_instance_with_retries, "get") - resp = test_instance_with_retries.get("spaghetti") + assert resp.message == 'Oauth2 Client: Request failed after 3 \ + empty responses received from Oauth2 Client' + + def test_http_retry_success(self, requests_mock, + test_instance_with_retries, mocker): + mocker.patch.object(test_instance_with_retries, '_do_http_method', + side_effect=[MockEmptyResponse(empty=True), + MockEmptyResponse(empty=False, + status_code=200)]) + get_spy = mocker.spy(test_instance_with_retries, 'get') + resp = test_instance_with_retries.get('spaghetti') assert get_spy.call_count == 2 - assert resp.json() == "success" + assert resp.json() == 'success' diff --git a/tests/test_obfuscation_helper.py b/tests/test_obfuscation_helper.py index 112785f..ed76261 100644 --- a/tests/test_obfuscation_helper.py +++ b/tests/test_obfuscation_helper.py @@ -1,20 +1,19 @@ import os -from nypl_py_utils.functions.obfuscation_helper import obfuscate, obfuscate_with_salt +from nypl_py_utils.functions.obfuscation_helper import 
(obfuscate, + obfuscate_with_salt) -_TEST_SALT_1 = "$2a$10$8AvAPrrUsmlBa50qgc683e" -_TEST_SALT_2 = "$2b$12$iuSSdD6F/nJ1GSXzesM8sO" +_TEST_SALT_1 = '$2a$10$8AvAPrrUsmlBa50qgc683e' +_TEST_SALT_2 = '$2b$12$iuSSdD6F/nJ1GSXzesM8sO' class TestObfuscationHelper: def test_obfuscation_with_environment_variable(self): - os.environ["BCRYPT_SALT"] = _TEST_SALT_1 - assert obfuscate("test_input") == "UPMawmdZfleeSg5REsZbLbAivWl97O6" - del os.environ["BCRYPT_SALT"] + os.environ['BCRYPT_SALT'] = _TEST_SALT_1 + assert obfuscate('test_input') == 'UPMawmdZfleeSg5REsZbLbAivWl97O6' + del os.environ['BCRYPT_SALT'] def test_obfuscation_with_custom_salt(self): - assert ( - obfuscate_with_salt("test_input", _TEST_SALT_2) - == "SUXLCHnsRVt4Vj1PyP9KPEqADxtUj5." - ) + assert (obfuscate_with_salt('test_input', _TEST_SALT_2) == + 'SUXLCHnsRVt4Vj1PyP9KPEqADxtUj5.') diff --git a/tests/test_postgresql_client.py b/tests/test_postgresql_client.py index 2c32827..99e5042 100644 --- a/tests/test_postgresql_client.py +++ b/tests/test_postgresql_client.py @@ -1,40 +1,38 @@ import pytest from nypl_py_utils.classes.postgresql_client import ( - PostgreSQLClient, - PostgreSQLClientError, -) + PostgreSQLClient, PostgreSQLClientError) class TestPostgreSQLClient: @pytest.fixture def mock_pg_conn(self, mocker): - return mocker.patch("psycopg.connect") + return mocker.patch('psycopg.connect') @pytest.fixture def test_instance(self): - return PostgreSQLClient( - "test_host", "test_port", "test_db_name", "test_user", "test_password" - ) + return PostgreSQLClient('test_host', 'test_port', 'test_db_name', + 'test_user', 'test_password') def test_connect(self, mock_pg_conn, test_instance): test_instance.connect() mock_pg_conn.assert_called_once_with( - "postgresql://test_user:test_password@test_host:test_port/" + "test_db_name" - ) + 'postgresql://test_user:test_password@test_host:test_port/' + + 'test_db_name') def test_execute_read_query(self, mock_pg_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() - mock_cursor.description = [("description", None, None)] + mock_cursor.description = [('description', None, None)] mock_cursor.execute.return_value = mock_cursor - mock_cursor.fetchall.return_value = [(1, 2, 3), ("a", "b", "c")] + mock_cursor.fetchall.return_value = [(1, 2, 3), ('a', 'b', 'c')] test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query("test query") == [(1, 2, 3), ("a", "b", "c")] - mock_cursor.execute.assert_called_once_with("test query", None) + assert test_instance.execute_query( + 'test query') == [(1, 2, 3), ('a', 'b', 'c')] + mock_cursor.execute.assert_called_once_with('test query', None) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() @@ -45,27 +43,28 @@ def test_execute_write_query(self, mock_pg_conn, test_instance, mocker): mock_cursor.description = None test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query("test query") is None - mock_cursor.execute.assert_called_once_with("test query", None) + assert test_instance.execute_query('test query') is None + mock_cursor.execute.assert_called_once_with('test query', None) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_write_query_with_params(self, mock_pg_conn, test_instance, mocker): + def test_execute_write_query_with_params(self, mock_pg_conn, test_instance, + mocker): test_instance.connect() mock_cursor = mocker.MagicMock() mock_cursor.description = None 
test_instance.conn.cursor.return_value = mock_cursor - assert ( - test_instance.execute_query("test query %s %s", query_params=("a", 1)) - is None - ) - mock_cursor.execute.assert_called_once_with("test query %s %s", ("a", 1)) + assert test_instance.execute_query( + 'test query %s %s', query_params=('a', 1)) is None + mock_cursor.execute.assert_called_once_with('test query %s %s', + ('a', 1)) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_query_with_exception(self, mock_pg_conn, test_instance, mocker): + def test_execute_query_with_exception( + self, mock_pg_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -73,7 +72,7 @@ def test_execute_query_with_exception(self, mock_pg_conn, test_instance, mocker) test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(PostgreSQLClientError): - test_instance.execute_query("test query") + test_instance.execute_query('test query') test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() diff --git a/tests/test_postgresql_pool_client.py b/tests/test_postgresql_pool_client.py index 5a57808..82f22b6 100644 --- a/tests/test_postgresql_pool_client.py +++ b/tests/test_postgresql_pool_client.py @@ -1,9 +1,7 @@ import pytest from nypl_py_utils.classes.postgresql_pool_client import ( - PostgreSQLPoolClient, - PostgreSQLPoolClientError, -) + PostgreSQLPoolClient, PostgreSQLPoolClientError) from psycopg import Error @@ -11,16 +9,15 @@ class TestPostgreSQLPoolClient: @pytest.fixture def test_instance(self, mocker): - mocker.patch("psycopg_pool.ConnectionPool.open") - mocker.patch("psycopg_pool.ConnectionPool.close") - return PostgreSQLPoolClient( - "test_host", "test_port", "test_db_name", "test_user", "test_password" - ) + mocker.patch('psycopg_pool.ConnectionPool.open') + mocker.patch('psycopg_pool.ConnectionPool.close') + return PostgreSQLPoolClient('test_host', 'test_port', 'test_db_name', + 'test_user', 'test_password') def test_init(self, test_instance): assert test_instance.pool.conninfo == ( - "postgresql://test_user:test_password@test_host:test_port/" + "test_db_name" - ) + 'postgresql://test_user:test_password@test_host:test_port/' + + 'test_db_name') assert test_instance.pool._opened is False assert test_instance.pool.min_size == 0 assert test_instance.pool.max_size == 1 @@ -28,24 +25,21 @@ def test_init(self, test_instance): def test_init_with_long_max_idle(self): with pytest.raises(PostgreSQLPoolClientError): PostgreSQLPoolClient( - "test_host", - "test_port", - "test_db_name", - "test_user", - "test_password", - max_idle=300.0, - ) + 'test_host', 'test_port', 'test_db_name', 'test_user', + 'test_password', max_idle=300.0) def test_connect(self, test_instance): test_instance.connect() - test_instance.pool.open.assert_called_once_with(wait=True, timeout=300.0) + test_instance.pool.open.assert_called_once_with(wait=True, + timeout=300.0) def test_connect_with_exception(self, mocker): - mocker.patch("psycopg_pool.ConnectionPool.open", side_effect=Error()) + mocker.patch('psycopg_pool.ConnectionPool.open', + side_effect=Error()) test_instance = PostgreSQLPoolClient( - "test_host", "test_port", "test_db_name", "test_user", "test_password" - ) + 'test_host', 'test_port', 'test_db_name', 'test_user', + 'test_password') with pytest.raises(PostgreSQLPoolClientError): test_instance.connect(timeout=1.0) @@ -54,18 +48,18 @@ def test_execute_read_query(self, test_instance, mocker): test_instance.connect() mock_cursor = 
mocker.MagicMock() - mock_cursor.description = [("description", None, None)] - mock_cursor.fetchall.return_value = [(1, 2, 3), ("a", "b", "c")] + mock_cursor.description = [('description', None, None)] + mock_cursor.fetchall.return_value = [(1, 2, 3), ('a', 'b', 'c')] mock_conn = mocker.MagicMock() mock_conn.execute.return_value = mock_cursor mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch( - "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context - ) + mocker.patch('psycopg_pool.ConnectionPool.connection', + return_value=mock_conn_context) - assert test_instance.execute_query("test query") == [(1, 2, 3), ("a", "b", "c")] - mock_conn.execute.assert_called_once_with("test query", None) + assert test_instance.execute_query( + 'test query') == [(1, 2, 3), ('a', 'b', 'c')] + mock_conn.execute.assert_called_once_with('test query', None) mock_cursor.fetchall.assert_called_once() def test_execute_write_query(self, test_instance, mocker): @@ -77,12 +71,11 @@ def test_execute_write_query(self, test_instance, mocker): mock_conn.execute.return_value = mock_cursor mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch( - "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context - ) + mocker.patch('psycopg_pool.ConnectionPool.connection', + return_value=mock_conn_context) - assert test_instance.execute_query("test query") is None - mock_conn.execute.assert_called_once_with("test query", None) + assert test_instance.execute_query('test query') is None + mock_conn.execute.assert_called_once_with('test query', None) def test_execute_write_query_with_params(self, test_instance, mocker): test_instance.connect() @@ -93,15 +86,13 @@ def test_execute_write_query_with_params(self, test_instance, mocker): mock_conn.execute.return_value = mock_cursor mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch( - "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context - ) + mocker.patch('psycopg_pool.ConnectionPool.connection', + return_value=mock_conn_context) - assert ( - test_instance.execute_query("test query %s %s", query_params=("a", 1)) - is None - ) - mock_conn.execute.assert_called_once_with("test query %s %s", ("a", 1)) + assert test_instance.execute_query( + 'test query %s %s', query_params=('a', 1)) is None + mock_conn.execute.assert_called_once_with('test query %s %s', + ('a', 1)) def test_execute_query_with_exception(self, test_instance, mocker): test_instance.connect() @@ -110,12 +101,11 @@ def test_execute_query_with_exception(self, test_instance, mocker): mock_conn.execute.side_effect = Exception() mock_conn_context = mocker.MagicMock() mock_conn_context.__enter__.return_value = mock_conn - mocker.patch( - "psycopg_pool.ConnectionPool.connection", return_value=mock_conn_context - ) + mocker.patch('psycopg_pool.ConnectionPool.connection', + return_value=mock_conn_context) with pytest.raises(PostgreSQLPoolClientError): - test_instance.execute_query("test query") + test_instance.execute_query('test query') def test_close_pool(self, test_instance): test_instance.connect() @@ -126,6 +116,6 @@ def test_reopen_pool(self, test_instance, mocker): test_instance.connect() test_instance.close_pool() test_instance.connect() - test_instance.pool.open.assert_has_calls( - [mocker.call(wait=True, timeout=300), mocker.call(wait=True, timeout=300)] - ) + test_instance.pool.open.assert_has_calls([ + 
mocker.call(wait=True, timeout=300), + mocker.call(wait=True, timeout=300)]) diff --git a/tests/test_redshift_client.py b/tests/test_redshift_client.py index d60b85a..7d6219d 100644 --- a/tests/test_redshift_client.py +++ b/tests/test_redshift_client.py @@ -1,56 +1,55 @@ import pytest -from nypl_py_utils.classes.redshift_client import RedshiftClient, RedshiftClientError +from nypl_py_utils.classes.redshift_client import ( + RedshiftClient, RedshiftClientError) class TestRedshiftClient: @pytest.fixture def mock_redshift_conn(self, mocker): - return mocker.patch("redshift_connector.connect") + return mocker.patch('redshift_connector.connect') @pytest.fixture def test_instance(self): - return RedshiftClient( - "test_host", "test_database", "test_user", "test_password" - ) + return RedshiftClient('test_host', 'test_database', 'test_user', + 'test_password') def test_connect(self, mock_redshift_conn, test_instance): test_instance.connect() - mock_redshift_conn.assert_called_once_with( - host="test_host", - database="test_database", - user="test_user", - password="test_password", - sslmode="verify-full", - ) + mock_redshift_conn.assert_called_once_with(host='test_host', + database='test_database', + user='test_user', + password='test_password', + sslmode='verify-full') def test_execute_query(self, mock_redshift_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() - mock_cursor.fetchall.return_value = [[1, 2, 3], ["a", "b", "c"]] + mock_cursor.fetchall.return_value = [[1, 2, 3], ['a', 'b', 'c']] test_instance.conn.cursor.return_value = mock_cursor - assert test_instance.execute_query("test query") == [[1, 2, 3], ["a", "b", "c"]] - mock_cursor.execute.assert_called_once_with("test query") + assert test_instance.execute_query( + 'test query') == [[1, 2, 3], ['a', 'b', 'c']] + mock_cursor.execute.assert_called_once_with('test query') mock_cursor.fetchall.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_dataframe_query(self, mock_redshift_conn, test_instance, mocker): + def test_execute_dataframe_query(self, mock_redshift_conn, test_instance, + mocker): test_instance.connect() mock_cursor = mocker.MagicMock() test_instance.conn.cursor.return_value = mock_cursor - test_instance.execute_query("test query", dataframe=True) - mock_cursor.execute.assert_called_once_with("test query") + test_instance.execute_query('test query', dataframe=True) + mock_cursor.execute.assert_called_once_with('test query') mock_cursor.fetch_dataframe.assert_called_once() mock_cursor.close.assert_called_once() def test_execute_query_with_exception( - self, mock_redshift_conn, test_instance, mocker - ): + self, mock_redshift_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -58,66 +57,52 @@ def test_execute_query_with_exception( test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(RedshiftClientError): - test_instance.execute_query("test query") + test_instance.execute_query('test query') test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_transaction(self, mock_redshift_conn, test_instance, mocker): + def test_execute_transaction(self, mock_redshift_conn, test_instance, + mocker): test_instance.connect() mock_cursor = mocker.MagicMock() test_instance.conn.cursor.return_value = mock_cursor - test_instance.execute_transaction( - [("query 1", None), ("query 2 %s %s", ("a", 1))] - ) - mock_cursor.execute.assert_has_calls( - [ - mocker.call("BEGIN 
TRANSACTION;"), - mocker.call("query 1", None), - mocker.call("query 2 %s %s", ("a", 1)), - mocker.call("END TRANSACTION;"), - ] - ) + test_instance.execute_transaction([('query 1', None), + ('query 2 %s %s', ('a', 1))]) + mock_cursor.execute.assert_has_calls([ + mocker.call('BEGIN TRANSACTION;'), + mocker.call('query 1', None), + mocker.call('query 2 %s %s', ('a', 1)), + mocker.call('END TRANSACTION;')]) mock_cursor.executemany.assert_not_called() test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() - def test_execute_transaction_with_many( - self, mock_redshift_conn, test_instance, mocker - ): + def test_execute_transaction_with_many(self, mock_redshift_conn, + test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() test_instance.conn.cursor.return_value = mock_cursor - test_instance.execute_transaction( - [ - ("query 1", None), - ("query 2 %s %s", (None, 1)), - ("query 3 %s %s", [(None, 10), ("b", 20)]), - ("query 4", None), - ] - ) - mock_cursor.execute.assert_has_calls( - [ - mocker.call("BEGIN TRANSACTION;"), - mocker.call("query 1", None), - mocker.call("query 2 %s %s", (None, 1)), - mocker.call("query 4", None), - mocker.call("END TRANSACTION;"), - ] - ) + test_instance.execute_transaction([ + ('query 1', None), ('query 2 %s %s', (None, 1)), + ('query 3 %s %s', [(None, 10), ('b', 20)]), ('query 4', None)]) + mock_cursor.execute.assert_has_calls([ + mocker.call('BEGIN TRANSACTION;'), + mocker.call('query 1', None), + mocker.call('query 2 %s %s', (None, 1)), + mocker.call('query 4', None), + mocker.call('END TRANSACTION;')]) mock_cursor.executemany.assert_called_once_with( - "query 3 %s %s", [(None, 10), ("b", 20)] - ) + 'query 3 %s %s', [(None, 10), ('b', 20)]) test_instance.conn.commit.assert_called_once() mock_cursor.close.assert_called_once() def test_execute_transaction_with_exception( - self, mock_redshift_conn, test_instance, mocker - ): + self, mock_redshift_conn, test_instance, mocker): test_instance.connect() mock_cursor = mocker.MagicMock() @@ -125,15 +110,13 @@ def test_execute_transaction_with_exception( test_instance.conn.cursor.return_value = mock_cursor with pytest.raises(RedshiftClientError): - test_instance.execute_transaction([("query 1", None), ("query 2", None)]) - - mock_cursor.execute.assert_has_calls( - [ - mocker.call("BEGIN TRANSACTION;"), - mocker.call("query 1", None), - mocker.call("query 2", None), - ] - ) + test_instance.execute_transaction( + [('query 1', None), ('query 2', None)]) + + mock_cursor.execute.assert_has_calls([ + mocker.call('BEGIN TRANSACTION;'), + mocker.call('query 1', None), + mocker.call('query 2', None)]) test_instance.conn.commit.assert_not_called() test_instance.conn.rollback.assert_called_once() mock_cursor.close.assert_called_once() diff --git a/tests/test_research_catalog_identifier_helper.py b/tests/test_research_catalog_identifier_helper.py index fafbc88..bf7686f 100644 --- a/tests/test_research_catalog_identifier_helper.py +++ b/tests/test_research_catalog_identifier_helper.py @@ -2,29 +2,31 @@ import json from nypl_py_utils.functions.research_catalog_identifier_helper import ( - parse_research_catalog_identifier, - research_catalog_id_prefix, - ResearchCatalogIdentifierError, -) + parse_research_catalog_identifier, research_catalog_id_prefix, + ResearchCatalogIdentifierError) _TEST_MAPPING = { - "sierra-nypl": { - "organization": "nyplOrg:0001", - "bibPrefix": "b", - "holdingPrefix": "h", - "itemPrefix": "i", - }, - "recap-pul": { - "organization": "nyplOrg:0003", - 
"bibPrefix": "pb", - "itemPrefix": "pi", - }, - "recap-cul": { - "organization": "nyplOrg:0002", - "bibPrefix": "cb", - "itemPrefix": "ci", - }, - "recap-hl": {"organization": "nyplOrg:0004", "bibPrefix": "hb", "itemPrefix": "hi"}, + 'sierra-nypl': { + 'organization': 'nyplOrg:0001', + 'bibPrefix': 'b', + 'holdingPrefix': 'h', + 'itemPrefix': 'i' + }, + 'recap-pul': { + 'organization': 'nyplOrg:0003', + 'bibPrefix': 'pb', + 'itemPrefix': 'pi' + }, + 'recap-cul': { + 'organization': 'nyplOrg:0002', + 'bibPrefix': 'cb', + 'itemPrefix': 'ci' + }, + 'recap-hl': { + 'organization': 'nyplOrg:0004', + 'bibPrefix': 'hb', + 'itemPrefix': 'hi' + } } @@ -32,51 +34,39 @@ class TestResearchCatalogIdentifierHelper: @pytest.fixture(autouse=True) def test_instance(self, requests_mock): requests_mock.get( - "https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json", # noqa - text=json.dumps(_TEST_MAPPING), - ) + 'https://raw.githubusercontent.com/NYPL/nypl-core/master/mappings/recap-discovery/nypl-source-mapping.json', # noqa + text=json.dumps(_TEST_MAPPING)) def test_parse_research_catalog_identifier_parses_valid(self): - assert parse_research_catalog_identifier("b1234") == { - "id": "1234", - "nyplSource": "sierra-nypl", - "nyplType": "bib", - } - assert parse_research_catalog_identifier("cb1234") == { - "id": "1234", - "nyplSource": "recap-cul", - "nyplType": "bib", - } - assert parse_research_catalog_identifier("pi1234") == { - "id": "1234", - "nyplSource": "recap-pul", - "nyplType": "item", - } - assert parse_research_catalog_identifier("h1234") == { - "id": "1234", - "nyplSource": "sierra-nypl", - "nyplType": "holding", - } + assert parse_research_catalog_identifier('b1234') == \ + {'id': '1234', 'nyplSource': 'sierra-nypl', 'nyplType': 'bib'} + assert parse_research_catalog_identifier('cb1234') == \ + {'id': '1234', 'nyplSource': 'recap-cul', 'nyplType': 'bib'} + assert parse_research_catalog_identifier('pi1234') == \ + {'id': '1234', 'nyplSource': 'recap-pul', 'nyplType': 'item'} + assert parse_research_catalog_identifier('h1234') == \ + {'id': '1234', 'nyplSource': 'sierra-nypl', + 'nyplType': 'holding'} def test_parse_research_catalog_identifier_fails_nonsense(self): - for invalidIdentifier in [None, 1234, "z1234", "1234"]: + for invalidIdentifier in [None, 1234, 'z1234', '1234']: with pytest.raises(ResearchCatalogIdentifierError): parse_research_catalog_identifier(invalidIdentifier) def test_research_catalog_id_prefix_parses_valid(self, mocker): - assert research_catalog_id_prefix("sierra-nypl") == "b" - assert research_catalog_id_prefix("sierra-nypl", "bib") == "b" - assert research_catalog_id_prefix("sierra-nypl", "item") == "i" - assert research_catalog_id_prefix("sierra-nypl", "holding") == "h" - assert research_catalog_id_prefix("recap-pul", "bib") == "pb" - assert research_catalog_id_prefix("recap-hl", "bib") == "hb" - assert research_catalog_id_prefix("recap-hl", "item") == "hi" - assert research_catalog_id_prefix("recap-pul", "item") == "pi" + assert research_catalog_id_prefix('sierra-nypl') == 'b' + assert research_catalog_id_prefix('sierra-nypl', 'bib') == 'b' + assert research_catalog_id_prefix('sierra-nypl', 'item') == 'i' + assert research_catalog_id_prefix('sierra-nypl', 'holding') == 'h' + assert research_catalog_id_prefix('recap-pul', 'bib') == 'pb' + assert research_catalog_id_prefix('recap-hl', 'bib') == 'hb' + assert research_catalog_id_prefix('recap-hl', 'item') == 'hi' + assert research_catalog_id_prefix('recap-pul', 'item') 
== 'pi' def test_research_catalog_id_prefix_fails_nonsense(self, mocker): - for invalidSource in ["sierra-cul", None, "recap-nypl"]: + for invalidSource in ['sierra-cul', None, 'recap-nypl']: with pytest.raises(ResearchCatalogIdentifierError): research_catalog_id_prefix(invalidSource) - for invalidType in [None, "..."]: + for invalidType in [None, '...']: with pytest.raises(ResearchCatalogIdentifierError): - research_catalog_id_prefix("sierra-nypl", invalidType) + research_catalog_id_prefix('sierra-nypl', invalidType) diff --git a/tests/test_s3_client.py b/tests/test_s3_client.py index 6c1a6e5..bbb74e0 100644 --- a/tests/test_s3_client.py +++ b/tests/test_s3_client.py @@ -3,20 +3,20 @@ from nypl_py_utils.classes.s3_client import S3Client -_TEST_STATE = {"key1": "val1", "key2": "val2", "key3": "val3"} +_TEST_STATE = {'key1': 'val1', 'key2': 'val2', 'key3': 'val3'} class TestS3Client: @pytest.fixture def test_instance(self, mocker): - mocker.patch("boto3.client") - return S3Client("test_s3_bucket", "test_s3_resource") + mocker.patch('boto3.client') + return S3Client('test_s3_bucket', 'test_s3_resource') def test_fetch_cache(self, test_instance): def mock_download(bucket, resource, stream): - assert bucket == "test_s3_bucket" - assert resource == "test_s3_resource" + assert bucket == 'test_s3_bucket' + assert resource == 'test_s3_resource' stream.write(json.dumps(_TEST_STATE).encode()) test_instance.s3_client.download_fileobj.side_effect = mock_download @@ -26,5 +26,5 @@ def test_set_cache(self, test_instance): test_instance.set_cache(_TEST_STATE) arguments = test_instance.s3_client.upload_fileobj.call_args.args assert arguments[0].getvalue() == json.dumps(_TEST_STATE).encode() - assert arguments[1] == "test_s3_bucket" - assert arguments[2] == "test_s3_resource" + assert arguments[1] == 'test_s3_bucket' + assert arguments[2] == 'test_s3_resource' From 0cd7507f29b4612137509878f44a0e12a70c09d9 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Tue, 2 Jul 2024 11:55:35 -0500 Subject: [PATCH 07/14] Autopep8 linter used --- src/nypl_py_utils/classes/avro_client.py | 28 ++++++++++++++---------- tests/test_avro_client.py | 27 +++++++++++++++-------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py index ae2c988..2cc5cc3 100644 --- a/src/nypl_py_utils/classes/avro_client.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -9,17 +9,18 @@ from nypl_py_utils.functions.log_helper import create_log from requests.exceptions import JSONDecodeError, RequestException + class AvroClient: """ Base class for Avro schema interaction. Takes as input the Platform API endpoint from which to fetch the schema in JSON format. 
""" - + def __init__(self, platform_schema_url): self.logger = create_log('avro_encoder') self.schema = avro.schema.parse( self.get_json_schema(platform_schema_url)) - + def get_json_schema(self, platform_schema_url): """ Fetches a JSON response from the input Platform API endpoint and @@ -48,7 +49,7 @@ def get_json_schema(self, platform_schema_url): raise AvroClientError( 'Retrieved schema is malformed: {errorType} {errorMessage}' .format(errorType=type(e), errorMessage=e)) from None - + class AvroEncoder(AvroClient): """ @@ -64,7 +65,7 @@ def encode_record(self, record): """ self.logger.debug( 'Encoding record using {schema} schema'.format( - schema=self.schema.name)) + schema=self.schema.name)) datum_writer = DatumWriter(self.schema) with BytesIO() as output_stream: encoder = BinaryEncoder(output_stream) @@ -110,30 +111,33 @@ class AvroDecoder(AvroClient): def decode_record(self, record, encoding="binary"): """ - Decodes a single record represented either as a byte or + Decodes a single record represented either as a byte or base64 string, using the given Avro schema. Returns a dictionary where each key is a field in the schema. """ - self.logger.info('Decoding {rec} of type {type} using {schema} schema'.format( - rec=record, type=encoding, schema=self.schema.name)) - + self.logger.info('Decoding {rec} of type {type} using {schema} schema' + .format(rec=record, type=encoding, + schema=self.schema.name)) + if encoding == "base64": return self._decode_base64(record) elif encoding == "binary": return self._decode_binary(record) else: - self.logger.error('Failed to decode record due to encoding type: {}'.format(encoding)) + self.logger.error( + 'Failed to decode record due to encoding type: {}' + .format(encoding)) raise AvroClientError( 'Invalid encoding type: {}'.format(encoding)) - + def _decode_base64(self, record): decoded_data = base64.b64decode(record).decode("utf-8") try: return json.loads(decoded_data) except Exception as e: if isinstance(decoded_data, bytes): - self._decode_binary(decoded_data) + self._decode_binary(decoded_data) else: self.logger.error('Failed to decode record: {}'.format(e)) raise AvroClientError( @@ -153,4 +157,4 @@ def _decode_binary(self, record): class AvroClientError(Exception): def __init__(self, message=None): - self.message = message \ No newline at end of file + self.message = message diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py index 7e26981..0a91869 100644 --- a/tests/test_avro_client.py +++ b/tests/test_avro_client.py @@ -1,7 +1,8 @@ import json import pytest -from nypl_py_utils.classes.avro_client import AvroDecoder, AvroEncoder, AvroClientError +from nypl_py_utils.classes.avro_client import ( + AvroDecoder, AvroEncoder, AvroClientError) from requests.exceptions import ConnectTimeout _TEST_SCHEMA = {'data': {'schema': json.dumps({ @@ -19,6 +20,7 @@ ] })}} + class TestAvroClient: @pytest.fixture @@ -33,9 +35,12 @@ def test_avro_decoder_instance(self, requests_mock): 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) return AvroDecoder('https://test_schema_url') - def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance): - assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data']['schema'] - assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data']['schema'] + def test_get_json_schema(self, test_avro_encoder_instance, + test_avro_decoder_instance): + assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data'][ + 'schema'] + assert test_avro_decoder_instance.schema == 
_TEST_SCHEMA['data'][ + 'schema'] def test_request_error(self, requests_mock): requests_mock.get('https://test_schema_url', exc=ConnectTimeout) @@ -54,18 +59,21 @@ def test_missing_key_error(self, requests_mock): with pytest.raises(AvroClientError): AvroEncoder('https://test_schema_url') - def test_encode_record(self, test_avro_encoder_instance, test_avro_decoder_instance): + def test_encode_record(self, test_avro_encoder_instance, + test_avro_decoder_instance): TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'} encoded_record = test_avro_encoder_instance.encode_record(TEST_RECORD) assert type(encoded_record) is bytes - assert test_avro_decoder_instance.decode_record(encoded_record) == TEST_RECORD + assert test_avro_decoder_instance.decode_record( + encoded_record) == TEST_RECORD def test_encode_record_error(self, test_avro_encoder_instance): TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'} with pytest.raises(AvroClientError): test_avro_encoder_instance.encode_record(TEST_RECORD) - def test_encode_batch(self, test_avro_encoder_instance, test_avro_decoder_instance): + def test_encode_batch(self, test_avro_encoder_instance, + test_avro_decoder_instance): TEST_BATCH = [ {'patron_id': 123, 'library_branch': 'aa'}, {'patron_id': 456, 'library_branch': None}, @@ -89,10 +97,11 @@ def test_decode_record_binary(self, test_avro_decoder_instance): TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' assert test_avro_decoder_instance.decode_record( TEST_ENCODED_RECORD) == TEST_DECODED_RECORD - + def test_decode_record_b64(self, test_avro_decoder_instance): TEST_DECODED_RECORD = {"patron_id'": 123, "library_branch": "aa"} - TEST_ENCODED_RECORD = "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==" + TEST_ENCODED_RECORD = ( + "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==") assert test_avro_decoder_instance.decode_record( TEST_ENCODED_RECORD, "base64") == TEST_DECODED_RECORD From 98a5b63863302377fcafe7e934c5264d7175aaf1 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Mon, 8 Jul 2024 13:18:08 -0400 Subject: [PATCH 08/14] Added decode_batch and additional tests, also addressed minor complaints --- CHANGELOG.md | 2 +- pyproject.toml | 2 +- src/nypl_py_utils/classes/avro_client.py | 49 ++++++++++-------------- tests/test_avro_client.py | 32 +++++++++++----- 4 files changed, 45 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f7b9675..1c7a5ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # Changelog -## v1.1.6 6/26/24 +## v1.2.0 7/8/24 - Generalized Avro functions and separated encoding/decoding behavior. ## v1.1.5 6/6/24 diff --git a/pyproject.toml b/pyproject.toml index 534e902..0672e8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "nypl_py_utils" -version = "1.1.5" +version = "1.2.0" authors = [ { name="Aaron Friedman", email="aaronfriedman@nypl.org" }, ] diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py index 2cc5cc3..140f482 100644 --- a/src/nypl_py_utils/classes/avro_client.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -1,6 +1,5 @@ import avro.schema import base64 -import json import requests from avro.errors import AvroException @@ -109,39 +108,18 @@ class AvroDecoder(AvroClient): Platform API endpoint from which to fetch the schema in JSON format. 
""" - def decode_record(self, record, encoding="binary"): + def decode_record(self, record): """ Decodes a single record represented either as a byte or base64 string, using the given Avro schema. Returns a dictionary where each key is a field in the schema. """ - self.logger.info('Decoding {rec} of type {type} using {schema} schema' - .format(rec=record, type=encoding, - schema=self.schema.name)) - - if encoding == "base64": - return self._decode_base64(record) - elif encoding == "binary": - return self._decode_binary(record) - else: - self.logger.error( - 'Failed to decode record due to encoding type: {}' - .format(encoding)) - raise AvroClientError( - 'Invalid encoding type: {}'.format(encoding)) - - def _decode_base64(self, record): - decoded_data = base64.b64decode(record).decode("utf-8") - try: - return json.loads(decoded_data) - except Exception as e: - if isinstance(decoded_data, bytes): - self._decode_binary(decoded_data) - else: - self.logger.error('Failed to decode record: {}'.format(e)) - raise AvroClientError( - 'Failed to decode record: {}'.format(e)) from None + self.logger.info('Decoding {rec} using {schema} schema' + .format(rec=record, schema=self.schema.name)) + bytes_input = base64.b64decode(record) if ( + isinstance(record, str)) else record + return self._decode_binary(bytes_input) def _decode_binary(self, record): datum_reader = DatumReader(self.schema) @@ -154,6 +132,21 @@ def _decode_binary(self, record): raise AvroClientError( 'Failed to decode record: {}'.format(e)) from None + def decode_batch(self, record_list): + """ + Decodes a list of JSON records using the given Avro schema. + + Returns a list of strings where each string is an decoded record. + """ + self.logger.info( + 'Encoding ({num_rec}) records using {schema} schema'.format( + num_rec=len(record_list), schema=self.schema.name)) + decoded_records = [] + for record in record_list: + decoded_record = self._decode_binary(record) + decoded_records.append(decoded_record) + return decoded_records + class AvroClientError(Exception): def __init__(self, message=None): diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py index 0a91869..236b94d 100644 --- a/tests/test_avro_client.py +++ b/tests/test_avro_client.py @@ -2,7 +2,7 @@ import pytest from nypl_py_utils.classes.avro_client import ( - AvroDecoder, AvroEncoder, AvroClientError) + AvroClientError, AvroDecoder, AvroEncoder) from requests.exceptions import ConnectTimeout _TEST_SCHEMA = {'data': {'schema': json.dumps({ @@ -39,8 +39,8 @@ def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance): assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data'][ 'schema'] - assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data'][ - 'schema'] + # assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data'][ + # 'schema'] def test_request_error(self, requests_mock): requests_mock.get('https://test_schema_url', exc=ConnectTimeout) @@ -98,14 +98,26 @@ def test_decode_record_binary(self, test_avro_decoder_instance): assert test_avro_decoder_instance.decode_record( TEST_ENCODED_RECORD) == TEST_DECODED_RECORD - def test_decode_record_b64(self, test_avro_decoder_instance): - TEST_DECODED_RECORD = {"patron_id'": 123, "library_branch": "aa"} - TEST_ENCODED_RECORD = ( - "eyJwYXRyb25faWQnIjogMTIzLCAibGlicmFyeV9icmFuY2giOiAiYWEifQ==") - assert test_avro_decoder_instance.decode_record( - TEST_ENCODED_RECORD, "base64") == TEST_DECODED_RECORD - def test_decode_record_error(self, test_avro_decoder_instance): 
TEST_ENCODED_RECORD = b'bad-encoding' with pytest.raises(AvroClientError): test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD) + + def test_decode_batch(self, test_avro_decoder_instance): + TEST_ENCODED_BATCH = [ + b'\xf6\x01\x02\x04aa', + b'\x90\x07\x00', + b'\xaa\x0c\x02\x04bb'] + TEST_DECODED_BATCH = [ + {'patron_id': 123, 'library_branch': 'aa'}, + {'patron_id': 456, 'library_branch': None}, + {'patron_id': 789, 'library_branch': 'bb'}] + assert test_avro_decoder_instance.decode_batch( + TEST_ENCODED_BATCH) == TEST_DECODED_BATCH + + def test_decode_batch_error(self, test_avro_decoder_instance): + BAD_BATCH = [ + b'\xf6\x01\x02\x04aa', + b'bad-encoding'] + with pytest.raises(AvroClientError): + test_avro_decoder_instance.decode_batch(BAD_BATCH) From a90e019e5cc2823af9ec8c67146c760e7a04c71f Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Tue, 9 Jul 2024 11:39:23 -0400 Subject: [PATCH 09/14] Resolved grammar errors plus removed decode binary func --- src/nypl_py_utils/classes/avro_client.py | 76 ++++++++++++++---------- tests/test_avro_client.py | 4 +- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py index 140f482..c28c3ff 100644 --- a/src/nypl_py_utils/classes/avro_client.py +++ b/src/nypl_py_utils/classes/avro_client.py @@ -16,7 +16,7 @@ class AvroClient: """ def __init__(self, platform_schema_url): - self.logger = create_log('avro_encoder') + self.logger = create_log("avro_encoder") self.schema = avro.schema.parse( self.get_json_schema(platform_schema_url)) @@ -25,29 +25,35 @@ def get_json_schema(self, platform_schema_url): Fetches a JSON response from the input Platform API endpoint and interprets it as an Avro schema. """ - self.logger.info('Fetching Avro schema from {}'.format( - platform_schema_url)) + self.logger.info( + "Fetching Avro schema from {}".format(platform_schema_url)) try: response = requests.get(platform_schema_url) response.raise_for_status() except RequestException as e: self.logger.error( - 'Failed to retrieve schema from {url}: {error}'.format( - url=platform_schema_url, error=e)) + "Failed to retrieve schema from {url}: {error}".format( + url=platform_schema_url, error=e + ) + ) raise AvroClientError( - 'Failed to retrieve schema from {url}: {error}'.format( - url=platform_schema_url, error=e)) from None + "Failed to retrieve schema from {url}: {error}".format( + url=platform_schema_url, error=e + ) + ) from None try: json_response = response.json() - return json_response['data']['schema'] + return json_response["data"]["schema"] except (JSONDecodeError, KeyError) as e: self.logger.error( - 'Retrieved schema is malformed: {errorType} {errorMessage}' - .format(errorType=type(e), errorMessage=e)) + "Retrieved schema is malformed: {errorType} {errorMessage}" + .format(errorType=type(e), errorMessage=e) + ) raise AvroClientError( - 'Retrieved schema is malformed: {errorType} {errorMessage}' - .format(errorType=type(e), errorMessage=e)) from None + "Retrieved schema is malformed: {errorType} {errorMessage}" + .format(errorType=type(e), errorMessage=e) + ) from None class AvroEncoder(AvroClient): @@ -63,8 +69,9 @@ def encode_record(self, record): Returns the encoded record as a byte string. 
""" self.logger.debug( - 'Encoding record using {schema} schema'.format( - schema=self.schema.name)) + "Encoding record using {schema} schema".format( + schema=self.schema.name) + ) datum_writer = DatumWriter(self.schema) with BytesIO() as output_stream: encoder = BinaryEncoder(output_stream) @@ -72,9 +79,9 @@ def encode_record(self, record): datum_writer.write(record, encoder) return output_stream.getvalue() except AvroException as e: - self.logger.error('Failed to encode record: {}'.format(e)) + self.logger.error("Failed to encode record: {}".format(e)) raise AvroClientError( - 'Failed to encode record: {}'.format(e)) from None + "Failed to encode record: {}".format(e)) from None def encode_batch(self, record_list): """ @@ -83,8 +90,10 @@ def encode_batch(self, record_list): Returns a list of byte strings where each string is an encoded record. """ self.logger.info( - 'Encoding ({num_rec}) records using {schema} schema'.format( - num_rec=len(record_list), schema=self.schema.name)) + "Encoding ({num_rec}) records using {schema} schema".format( + num_rec=len(record_list), schema=self.schema.name + ) + ) encoded_records = [] datum_writer = DatumWriter(self.schema) with BytesIO() as output_stream: @@ -96,9 +105,10 @@ def encode_batch(self, record_list): output_stream.seek(0) output_stream.truncate(0) except AvroException as e: - self.logger.error('Failed to encode record: {}'.format(e)) + self.logger.error("Failed to encode record: {}".format(e)) raise AvroClientError( - 'Failed to encode record: {}'.format(e)) from None + "Failed to encode record: {}".format(e) + ) from None return encoded_records @@ -115,35 +125,37 @@ def decode_record(self, record): Returns a dictionary where each key is a field in the schema. """ - self.logger.info('Decoding {rec} using {schema} schema' - .format(rec=record, schema=self.schema.name)) + self.logger.info( + "Decoding {rec} using {schema} schema".format( + rec=record, schema=self.schema.name + ) + ) bytes_input = base64.b64decode(record) if ( isinstance(record, str)) else record - return self._decode_binary(bytes_input) - - def _decode_binary(self, record): datum_reader = DatumReader(self.schema) - with BytesIO(record) as input_stream: + with BytesIO(bytes_input) as input_stream: decoder = BinaryDecoder(input_stream) try: return datum_reader.read(decoder) except Exception as e: - self.logger.error('Failed to decode record: {}'.format(e)) + self.logger.error("Failed to decode record: {}".format(e)) raise AvroClientError( - 'Failed to decode record: {}'.format(e)) from None + "Failed to decode record: {}".format(e)) from None def decode_batch(self, record_list): """ Decodes a list of JSON records using the given Avro schema. - Returns a list of strings where each string is an decoded record. + Returns a list of strings where each string is a decoded record. 
""" self.logger.info( - 'Encoding ({num_rec}) records using {schema} schema'.format( - num_rec=len(record_list), schema=self.schema.name)) + "Decoding ({num_rec}) records using {schema} schema".format( + num_rec=len(record_list), schema=self.schema.name + ) + ) decoded_records = [] for record in record_list: - decoded_record = self._decode_binary(record) + decoded_record = self.decode_record(record) decoded_records.append(decoded_record) return decoded_records diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py index 236b94d..9d78c67 100644 --- a/tests/test_avro_client.py +++ b/tests/test_avro_client.py @@ -39,8 +39,8 @@ def test_get_json_schema(self, test_avro_encoder_instance, test_avro_decoder_instance): assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data'][ 'schema'] - # assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data'][ - # 'schema'] + assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data'][ + 'schema'] def test_request_error(self, requests_mock): requests_mock.get('https://test_schema_url', exc=ConnectTimeout) From 1f7eb8536e5547fa7c4556f9ff13f5058e9735b6 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Tue, 9 Jul 2024 14:49:16 -0400 Subject: [PATCH 10/14] Changed test method name --- tests/test_avro_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py index 9d78c67..1a0b527 100644 --- a/tests/test_avro_client.py +++ b/tests/test_avro_client.py @@ -32,7 +32,9 @@ def test_avro_encoder_instance(self, requests_mock): @pytest.fixture def test_avro_decoder_instance(self, requests_mock): requests_mock.get( - 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + 'https://test_schema_url', text=json.dumps(LOCATION_NEW)) + # requests_mock.get( + # 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) return AvroDecoder('https://test_schema_url') def test_get_json_schema(self, test_avro_encoder_instance, @@ -92,7 +94,7 @@ def test_encode_batch_error(self, test_avro_encoder_instance): with pytest.raises(AvroClientError): test_avro_encoder_instance.encode_batch(BAD_BATCH) - def test_decode_record_binary(self, test_avro_decoder_instance): + def test_decode_record(self, test_avro_decoder_instance): TEST_DECODED_RECORD = {"patron_id": 123, "library_branch": "aa"} TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa' assert test_avro_decoder_instance.decode_record( From 03e11a4357460b52a3239c4564a4532ebaba9656 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Tue, 9 Jul 2024 14:53:26 -0400 Subject: [PATCH 11/14] Oops --- tests/test_avro_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py index 1a0b527..c4429c6 100644 --- a/tests/test_avro_client.py +++ b/tests/test_avro_client.py @@ -32,9 +32,7 @@ def test_avro_encoder_instance(self, requests_mock): @pytest.fixture def test_avro_decoder_instance(self, requests_mock): requests_mock.get( - 'https://test_schema_url', text=json.dumps(LOCATION_NEW)) - # requests_mock.get( - # 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) + 'https://test_schema_url', text=json.dumps(_TEST_SCHEMA)) return AvroDecoder('https://test_schema_url') def test_get_json_schema(self, test_avro_encoder_instance, From 04e26a7a9b78d37b4cef3289459317e0bb6c4e94 Mon Sep 17 00:00:00 2001 From: fatimarahman Date: Mon, 15 Jul 2024 12:05:26 -0500 Subject: [PATCH 12/14] Removed base64 decoding stage --- src/nypl_py_utils/classes/avro_client.py | 5 +---- 1 file changed, 1 
insertion(+), 4 deletions(-)

diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py
index c28c3ff..0295d5a 100644
--- a/src/nypl_py_utils/classes/avro_client.py
+++ b/src/nypl_py_utils/classes/avro_client.py
@@ -1,5 +1,4 @@
 import avro.schema
-import base64
 import requests
 
 from avro.errors import AvroException
@@ -130,10 +129,8 @@ def decode_record(self, record):
                 rec=record, schema=self.schema.name
             )
         )
-        bytes_input = base64.b64decode(record) if (
-            isinstance(record, str)) else record
         datum_reader = DatumReader(self.schema)
-        with BytesIO(bytes_input) as input_stream:
+        with BytesIO(record) as input_stream:
             decoder = BinaryDecoder(input_stream)
             try:
                 return datum_reader.read(decoder)

From a9810e0300d0a79155e98d705a6c5a498e9491ae Mon Sep 17 00:00:00 2001
From: fatimarahman
Date: Wed, 17 Jul 2024 08:24:33 -0500
Subject: [PATCH 13/14] Changed wording

---
 CHANGELOG.md                             | 2 +-
 src/nypl_py_utils/classes/avro_client.py | 3 +--
 tests/test_avro_client.py                | 1 -
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1c7a5ca..c560cf8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,5 @@
 # Changelog
-## v1.2.0 7/8/24
+## v1.2.0 7/17/24
 - Generalized Avro functions and separated encoding/decoding behavior.
 
 ## v1.1.5 6/6/24
diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py
index 0295d5a..01ea031 100644
--- a/src/nypl_py_utils/classes/avro_client.py
+++ b/src/nypl_py_utils/classes/avro_client.py
@@ -119,8 +119,7 @@ class AvroDecoder(AvroClient):
 
     def decode_record(self, record):
         """
-        Decodes a single record represented either as a byte or
-        base64 string, using the given Avro schema.
+        Decodes a single record represented using the given Avro schema.
 
         Returns a dictionary where each key is a field in the schema.
         """
diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py
index c4429c6..af9e87c 100644
--- a/tests/test_avro_client.py
+++ b/tests/test_avro_client.py
@@ -22,7 +22,6 @@
 
 
 class TestAvroClient:
-
     @pytest.fixture
     def test_avro_encoder_instance(self, requests_mock):
         requests_mock.get(

From e06efef401a116683c5c02117d1ba6e52038faca Mon Sep 17 00:00:00 2001
From: fatimarahman
Date: Wed, 17 Jul 2024 09:17:05 -0500
Subject: [PATCH 14/14] Updated method descriptions

---
 src/nypl_py_utils/classes/avro_client.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py
index 01ea031..dd49f26 100644
--- a/src/nypl_py_utils/classes/avro_client.py
+++ b/src/nypl_py_utils/classes/avro_client.py
@@ -119,7 +119,8 @@ class AvroDecoder(AvroClient):
 
     def decode_record(self, record):
         """
-        Decodes a single record represented using the given Avro schema.
+        Decodes a single record encoded using the given Avro
+        schema. Input must be a bytes-like object.
 
         Returns a dictionary where each key is a field in the schema.
         """
@@ -140,7 +141,8 @@ def decode_record(self, record):
 
     def decode_batch(self, record_list):
         """
-        Decodes a list of JSON records using the given Avro schema.
+        Decodes a list of Avro-encoded records using the given schema. Input
+        must be a list of bytes-like objects.
 
         Returns a list of strings where each string is a decoded record.
         """
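
Usage sketch (not part of the patches above): a minimal round-trip through the
refactored classes as they stand at the end of this series. The schema URL is
hypothetical, and the endpoint is assumed to return a JSON body of the form
{"data": {"schema": ...}}, as AvroClient.get_json_schema expects.

    from nypl_py_utils.classes.avro_client import AvroDecoder, AvroEncoder

    # Hypothetical Platform API schema endpoint -- substitute a real URL.
    SCHEMA_URL = 'https://platform.example.org/current-schemas/PatronRecord'

    encoder = AvroEncoder(SCHEMA_URL)  # fetches and parses the Avro schema
    decoder = AvroDecoder(SCHEMA_URL)

    records = [{'patron_id': 123, 'library_branch': 'aa'},
               {'patron_id': 456, 'library_branch': None}]

    # encode_batch returns a list of byte strings; decode_batch takes a
    # list of bytes-like objects and returns a list of dictionaries.
    encoded = encoder.encode_batch(records)
    assert decoder.decode_batch(encoded) == records

    # Single records round-trip the same way, as raw bytes.
    encoded_one = encoder.encode_record(records[0])
    assert decoder.decode_record(encoded_one) == records[0]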