diff --git a/.gitignore b/.gitignore
index 28aac37..bd867e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ __pycache__/
 *env/
 *.py[cod]
 *$py.class
+*.pytest_cache
\ No newline at end of file
diff --git a/.python-version b/.python-version
index fd15561..9f3d4c1 100644
--- a/.python-version
+++ b/.python-version
@@ -1 +1 @@
-3.9.16
\ No newline at end of file
+3.9.16
diff --git a/CHANGELOG.md b/CHANGELOG.md
index eb5135b..d080e37 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,7 @@
 # Changelog
+## v1.2.0 7/17/24
+- Generalized Avro functions and separated encoding/decoding behavior.
+
 ## v1.1.6 7/12/24
 - Add put functionality to Oauth2 Client
 - Update pyproject version
diff --git a/pyproject.toml b/pyproject.toml
index d15680c..0672e8b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "nypl_py_utils"
-version = "1.1.6"
+version = "1.2.0"
 authors = [
   { name="Aaron Friedman", email="aaronfriedman@nypl.org" },
 ]
@@ -23,7 +23,7 @@ dependencies = []
 "Bug Tracker" = "https://github.com/NYPL/python-utils/issues"
 
 [project.optional-dependencies]
-avro-encoder = [
+avro-client = [
     "avro>=1.11.1",
     "requests>=2.28.1"
 ]
@@ -67,7 +67,7 @@ research-catalog-identifier-helper = [
     "requests>=2.28.1"
 ]
 development = [
-    "nypl_py_utils[avro-encoder,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,postgresql-pool-client,redshift-client,s3-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]",
+    "nypl_py_utils[avro-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,postgresql-pool-client,redshift-client,s3-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]",
     "flake8>=6.0.0",
     "freezegun>=1.2.2",
     "mock>=4.0.3",
@@ -75,3 +75,11 @@ development = [
     "pytest-mock>=3.10.0",
     "requests-mock>=1.10.0"
 ]
+
+[tool.pytest.ini_options]
+minversion = "8.0"
+addopts = "-ra -q"
+pythonpath = "src"
+testpaths = [
+    "tests"
+]
\ No newline at end of file
diff --git a/src/nypl_py_utils/classes/avro_client.py b/src/nypl_py_utils/classes/avro_client.py
new file mode 100644
index 0000000..dd49f26
--- /dev/null
+++ b/src/nypl_py_utils/classes/avro_client.py
@@ -0,0 +1,163 @@
+import avro.schema
+import requests
+
+from avro.errors import AvroException
+from avro.io import BinaryDecoder, BinaryEncoder, DatumReader, DatumWriter
+from io import BytesIO
+from nypl_py_utils.functions.log_helper import create_log
+from requests.exceptions import JSONDecodeError, RequestException
+
+
+class AvroClient:
+    """
+    Base class for Avro schema interaction. Takes as input the
+    Platform API endpoint from which to fetch the schema in JSON format.
+    """
+
+    def __init__(self, platform_schema_url):
+        self.logger = create_log("avro_encoder")
+        self.schema = avro.schema.parse(
+            self.get_json_schema(platform_schema_url))
+
+    def get_json_schema(self, platform_schema_url):
+        """
+        Fetches a JSON response from the input Platform API endpoint and
+        interprets it as an Avro schema.
+        """
+        self.logger.info(
+            "Fetching Avro schema from {}".format(platform_schema_url))
+        try:
+            response = requests.get(platform_schema_url)
+            response.raise_for_status()
+        except RequestException as e:
+            self.logger.error(
+                "Failed to retrieve schema from {url}: {error}".format(
+                    url=platform_schema_url, error=e
+                )
+            )
+            raise AvroClientError(
+                "Failed to retrieve schema from {url}: {error}".format(
+                    url=platform_schema_url, error=e
+                )
+            ) from None
+
+        try:
+            json_response = response.json()
+            return json_response["data"]["schema"]
+        except (JSONDecodeError, KeyError) as e:
+            self.logger.error(
+                "Retrieved schema is malformed: {errorType} {errorMessage}"
+                .format(errorType=type(e), errorMessage=e)
+            )
+            raise AvroClientError(
+                "Retrieved schema is malformed: {errorType} {errorMessage}"
+                .format(errorType=type(e), errorMessage=e)
+            ) from None
+
+
+class AvroEncoder(AvroClient):
+    """
+    Class for encoding records using an Avro schema. Takes as input the
+    Platform API endpoint from which to fetch the schema in JSON format.
+    """
+
+    def encode_record(self, record):
+        """
+        Encodes a single JSON record using the given Avro schema.
+
+        Returns the encoded record as a byte string.
+        """
+        self.logger.debug(
+            "Encoding record using {schema} schema".format(
+                schema=self.schema.name)
+        )
+        datum_writer = DatumWriter(self.schema)
+        with BytesIO() as output_stream:
+            encoder = BinaryEncoder(output_stream)
+            try:
+                datum_writer.write(record, encoder)
+                return output_stream.getvalue()
+            except AvroException as e:
+                self.logger.error("Failed to encode record: {}".format(e))
+                raise AvroClientError(
+                    "Failed to encode record: {}".format(e)) from None
+
+    def encode_batch(self, record_list):
+        """
+        Encodes a list of JSON records using the given Avro schema.
+
+        Returns a list of byte strings where each string is an encoded record.
+        """
+        self.logger.info(
+            "Encoding ({num_rec}) records using {schema} schema".format(
+                num_rec=len(record_list), schema=self.schema.name
+            )
+        )
+        encoded_records = []
+        datum_writer = DatumWriter(self.schema)
+        with BytesIO() as output_stream:
+            encoder = BinaryEncoder(output_stream)
+            for record in record_list:
+                try:
+                    datum_writer.write(record, encoder)
+                    encoded_records.append(output_stream.getvalue())
+                    output_stream.seek(0)
+                    output_stream.truncate(0)
+                except AvroException as e:
+                    self.logger.error("Failed to encode record: {}".format(e))
+                    raise AvroClientError(
+                        "Failed to encode record: {}".format(e)
+                    ) from None
+        return encoded_records
+
+
+class AvroDecoder(AvroClient):
+    """
+    Class for decoding records using an Avro schema. Takes as input the
+    Platform API endpoint from which to fetch the schema in JSON format.
+    """
+
+    def decode_record(self, record):
+        """
+        Decodes a single record represented using the given Avro
+        schema. Input must be a bytes-like object.
+
+        Returns a dictionary where each key is a field in the schema.
+        """
+        self.logger.info(
+            "Decoding {rec} using {schema} schema".format(
+                rec=record, schema=self.schema.name
+            )
+        )
+        datum_reader = DatumReader(self.schema)
+        with BytesIO(record) as input_stream:
+            decoder = BinaryDecoder(input_stream)
+            try:
+                return datum_reader.read(decoder)
+            except Exception as e:
+                self.logger.error("Failed to decode record: {}".format(e))
+                raise AvroClientError(
+                    "Failed to decode record: {}".format(e)) from None
+
+    def decode_batch(self, record_list):
+        """
+        Decodes a list of records encoded with the given Avro schema.
+        Input must be a list of bytes-like objects.
+
+        Returns a list of decoded records as dictionaries.
+        """
+        self.logger.info(
+            "Decoding ({num_rec}) records using {schema} schema".format(
+                num_rec=len(record_list), schema=self.schema.name
+            )
+        )
+        decoded_records = []
+        for record in record_list:
+            decoded_record = self.decode_record(record)
+            decoded_records.append(decoded_record)
+        return decoded_records
+
+
+class AvroClientError(Exception):
+    def __init__(self, message=None):
+        self.message = message
diff --git a/src/nypl_py_utils/classes/avro_encoder.py b/src/nypl_py_utils/classes/avro_encoder.py
deleted file mode 100644
index 8ff5229..0000000
--- a/src/nypl_py_utils/classes/avro_encoder.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import avro.schema
-import requests
-
-from avro.errors import AvroException
-from avro.io import BinaryDecoder, BinaryEncoder, DatumReader, DatumWriter
-from io import BytesIO
-from nypl_py_utils.functions.log_helper import create_log
-from requests.exceptions import JSONDecodeError, RequestException
-
-
-class AvroEncoder:
-    """
-    Class for encoding records using an Avro schema. Takes as input the
-    Platform API endpoint from which to fetch the schema in JSON format.
-    """
-
-    def __init__(self, platform_schema_url):
-        self.logger = create_log('avro_encoder')
-        self.schema = avro.schema.parse(
-            self._get_json_schema(platform_schema_url))
-
-    def encode_record(self, record):
-        """
-        Encodes a single JSON record using the given Avro schema.
-
-        Returns the encoded record as a byte string.
-        """
-        self.logger.debug(
-            'Encoding record using {schema} schema'.format(
-                schema=self.schema.name))
-        datum_writer = DatumWriter(self.schema)
-        with BytesIO() as output_stream:
-            encoder = BinaryEncoder(output_stream)
-            try:
-                datum_writer.write(record, encoder)
-                return output_stream.getvalue()
-            except AvroException as e:
-                self.logger.error('Failed to encode record: {}'.format(e))
-                raise AvroEncoderError(
-                    'Failed to encode record: {}'.format(e)) from None
-
-    def encode_batch(self, record_list):
-        """
-        Encodes a list of JSON records using the given Avro schema.
-
-        Returns a list of byte strings where each string is an encoded record.
-        """
-        self.logger.info(
-            'Encoding ({num_rec}) records using {schema} schema'.format(
-                num_rec=len(record_list), schema=self.schema.name))
-        encoded_records = []
-        datum_writer = DatumWriter(self.schema)
-        with BytesIO() as output_stream:
-            encoder = BinaryEncoder(output_stream)
-            for record in record_list:
-                try:
-                    datum_writer.write(record, encoder)
-                    encoded_records.append(output_stream.getvalue())
-                    output_stream.seek(0)
-                    output_stream.truncate(0)
-                except AvroException as e:
-                    self.logger.error('Failed to encode record: {}'.format(e))
-                    raise AvroEncoderError(
-                        'Failed to encode record: {}'.format(e)) from None
-        return encoded_records
-
-    def decode_record(self, record):
-        """
-        Decodes a single record represented as a byte string using the given
-        Avro schema.
-
-        Returns a dictionary where each key is a field in the schema.
-        """
-        self.logger.debug('Decoding {rec} using {schema} schema'.format(
-            rec=record, schema=self.schema.name))
-        datum_reader = DatumReader(self.schema)
-        with BytesIO(record) as input_stream:
-            decoder = BinaryDecoder(input_stream)
-            try:
-                return datum_reader.read(decoder)
-            except Exception as e:
-                self.logger.error('Failed to decode record: {}'.format(e))
-                raise AvroEncoderError(
-                    'Failed to decode record: {}'.format(e)) from None
-
-    def _get_json_schema(self, platform_schema_url):
-        """
-        Fetches a JSON response from the input Platform API endpoint and
-        interprets it as an Avro schema.
-        """
-        self.logger.info('Fetching Avro schema from {}'.format(
-            platform_schema_url))
-        try:
-            response = requests.get(platform_schema_url)
-            response.raise_for_status()
-        except RequestException as e:
-            self.logger.error(
-                'Failed to retrieve schema from {url}: {error}'.format(
-                    url=platform_schema_url, error=e))
-            raise AvroEncoderError(
-                'Failed to retrieve schema from {url}: {error}'.format(
-                    url=platform_schema_url, error=e)) from None
-
-        try:
-            json_response = response.json()
-            return json_response['data']['schema']
-        except (JSONDecodeError, KeyError) as e:
-            self.logger.error(
-                'Retrieved schema is malformed: {errorType} {errorMessage}'
-                .format(errorType=type(e), errorMessage=e))
-            raise AvroEncoderError(
-                'Retrieved schema is malformed: {errorType} {errorMessage}'
-                .format(errorType=type(e), errorMessage=e)) from None
-
-
-class AvroEncoderError(Exception):
-    def __init__(self, message=None):
-        self.message = message
diff --git a/tests/test_avro_client.py b/tests/test_avro_client.py
new file mode 100644
index 0000000..af9e87c
--- /dev/null
+++ b/tests/test_avro_client.py
@@ -0,0 +1,122 @@
+import json
+import pytest
+
+from nypl_py_utils.classes.avro_client import (
+    AvroClientError, AvroDecoder, AvroEncoder)
+from requests.exceptions import ConnectTimeout
+
+_TEST_SCHEMA = {'data': {'schema': json.dumps({
+    'name': 'TestSchema',
+    'type': 'record',
+    'fields': [
+        {
+            'name': 'patron_id',
+            'type': 'int'
+        },
+        {
+            'name': 'library_branch',
+            'type': ['null', 'string']
+        }
+    ]
+})}}
+
+
+class TestAvroClient:
+    @pytest.fixture
+    def test_avro_encoder_instance(self, requests_mock):
+        requests_mock.get(
+            'https://test_schema_url', text=json.dumps(_TEST_SCHEMA))
+        return AvroEncoder('https://test_schema_url')
+
+    @pytest.fixture
+    def test_avro_decoder_instance(self, requests_mock):
+        requests_mock.get(
+            'https://test_schema_url', text=json.dumps(_TEST_SCHEMA))
+        return AvroDecoder('https://test_schema_url')
+
+    def test_get_json_schema(self, test_avro_encoder_instance,
+                             test_avro_decoder_instance):
+        assert test_avro_encoder_instance.schema == _TEST_SCHEMA['data'][
+            'schema']
+        assert test_avro_decoder_instance.schema == _TEST_SCHEMA['data'][
+            'schema']
+
+    def test_request_error(self, requests_mock):
+        requests_mock.get('https://test_schema_url', exc=ConnectTimeout)
+        with pytest.raises(AvroClientError):
+            AvroEncoder('https://test_schema_url')
+
+    def test_bad_json_error(self, requests_mock):
+        requests_mock.get(
+            'https://test_schema_url', text='bad json')
+        with pytest.raises(AvroClientError):
+            AvroEncoder('https://test_schema_url')
+
+    def test_missing_key_error(self, requests_mock):
+        requests_mock.get(
+            'https://test_schema_url', text=json.dumps({'field': 'value'}))
+        with pytest.raises(AvroClientError):
+            AvroEncoder('https://test_schema_url')
+
+    def test_encode_record(self, test_avro_encoder_instance,
+                           test_avro_decoder_instance):
+        TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'}
+        encoded_record = test_avro_encoder_instance.encode_record(TEST_RECORD)
+        assert type(encoded_record) is bytes
+        assert test_avro_decoder_instance.decode_record(
+            encoded_record) == TEST_RECORD
+
+    def test_encode_record_error(self, test_avro_encoder_instance):
+        TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'}
+        with pytest.raises(AvroClientError):
+            test_avro_encoder_instance.encode_record(TEST_RECORD)
+
+    def test_encode_batch(self, test_avro_encoder_instance,
+                          test_avro_decoder_instance):
+        TEST_BATCH = [
+            {'patron_id': 123, 'library_branch': 'aa'},
+            {'patron_id': 456, 'library_branch': None},
+            {'patron_id': 789, 'library_branch': 'bb'}]
+        encoded_records = test_avro_encoder_instance.encode_batch(TEST_BATCH)
+        assert len(encoded_records) == len(TEST_BATCH)
+        for i in range(3):
+            assert type(encoded_records[i]) is bytes
+            assert test_avro_decoder_instance.decode_record(
+                encoded_records[i]) == TEST_BATCH[i]
+
+    def test_encode_batch_error(self, test_avro_encoder_instance):
+        BAD_BATCH = [
+            {'patron_id': 123, 'library_branch': 'aa'},
+            {'patron_id': 456, 'bad_field': 'bad'}]
+        with pytest.raises(AvroClientError):
+            test_avro_encoder_instance.encode_batch(BAD_BATCH)
+
+    def test_decode_record(self, test_avro_decoder_instance):
+        TEST_DECODED_RECORD = {"patron_id": 123, "library_branch": "aa"}
+        TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa'
+        assert test_avro_decoder_instance.decode_record(
+            TEST_ENCODED_RECORD) == TEST_DECODED_RECORD
+
+    def test_decode_record_error(self, test_avro_decoder_instance):
+        TEST_ENCODED_RECORD = b'bad-encoding'
+        with pytest.raises(AvroClientError):
+            test_avro_decoder_instance.decode_record(TEST_ENCODED_RECORD)
+
+    def test_decode_batch(self, test_avro_decoder_instance):
+        TEST_ENCODED_BATCH = [
+            b'\xf6\x01\x02\x04aa',
+            b'\x90\x07\x00',
+            b'\xaa\x0c\x02\x04bb']
+        TEST_DECODED_BATCH = [
+            {'patron_id': 123, 'library_branch': 'aa'},
+            {'patron_id': 456, 'library_branch': None},
+            {'patron_id': 789, 'library_branch': 'bb'}]
+        assert test_avro_decoder_instance.decode_batch(
+            TEST_ENCODED_BATCH) == TEST_DECODED_BATCH
+
+    def test_decode_batch_error(self, test_avro_decoder_instance):
+        BAD_BATCH = [
+            b'\xf6\x01\x02\x04aa',
+            b'bad-encoding']
+        with pytest.raises(AvroClientError):
+            test_avro_decoder_instance.decode_batch(BAD_BATCH)
diff --git a/tests/test_avro_encoder.py b/tests/test_avro_encoder.py
deleted file mode 100644
index 079088a..0000000
--- a/tests/test_avro_encoder.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import json
-import pytest
-
-from nypl_py_utils.classes.avro_encoder import AvroEncoder, AvroEncoderError
-from requests.exceptions import ConnectTimeout
-
-_TEST_SCHEMA = {'data': {'schema': json.dumps({
-    'name': 'TestSchema',
-    'type': 'record',
-    'fields': [
-        {
-            'name': 'patron_id',
-            'type': 'int'
-        },
-        {
-            'name': 'library_branch',
-            'type': ['null', 'string']
-        }
-    ]
-})}}
-
-
-class TestAvroEncoder:
-
-    @pytest.fixture
-    def test_instance(self, requests_mock):
-        requests_mock.get(
-            'https://test_schema_url', text=json.dumps(_TEST_SCHEMA))
-        return AvroEncoder('https://test_schema_url')
-
-    def test_get_json_schema(self, test_instance):
-        assert test_instance.schema == _TEST_SCHEMA['data']['schema']
-
-    def test_request_error(self, requests_mock):
-        requests_mock.get('https://test_schema_url', exc=ConnectTimeout)
-        with pytest.raises(AvroEncoderError):
-            AvroEncoder('https://test_schema_url')
-
-    def test_bad_json_error(self, requests_mock):
-        requests_mock.get(
-            'https://test_schema_url', text='bad json')
-        with pytest.raises(AvroEncoderError):
-            AvroEncoder('https://test_schema_url')
-
-    def test_missing_key_error(self, requests_mock):
-        requests_mock.get(
-            'https://test_schema_url', text=json.dumps({'field': 'value'}))
-        with pytest.raises(AvroEncoderError):
-            AvroEncoder('https://test_schema_url')
-
-    def test_encode_record(self, test_instance):
-        TEST_RECORD = {'patron_id': 123, 'library_branch': 'aa'}
-        encoded_record = test_instance.encode_record(TEST_RECORD)
-        assert type(encoded_record) is bytes
-        assert test_instance.decode_record(encoded_record) == TEST_RECORD
-
-    def test_encode_record_error(self, test_instance):
-        TEST_RECORD = {'patron_id': 123, 'bad_field': 'bad'}
-        with pytest.raises(AvroEncoderError):
-            test_instance.encode_record(TEST_RECORD)
-
-    def test_encode_batch(self, test_instance):
-        TEST_BATCH = [
-            {'patron_id': 123, 'library_branch': 'aa'},
-            {'patron_id': 456, 'library_branch': None},
-            {'patron_id': 789, 'library_branch': 'bb'}]
-        encoded_records = test_instance.encode_batch(TEST_BATCH)
-        assert len(encoded_records) == len(TEST_BATCH)
-        for i in range(3):
-            assert type(encoded_records[i]) is bytes
-            assert test_instance.decode_record(
-                encoded_records[i]) == TEST_BATCH[i]
-
-    def test_encode_batch_error(self, test_instance):
-        BAD_BATCH = [
-            {'patron_id': 123, 'library_branch': 'aa'},
-            {'patron_id': 456, 'bad_field': 'bad'}]
-        with pytest.raises(AvroEncoderError):
-            test_instance.encode_batch(BAD_BATCH)
-
-    def test_decode_record(self, test_instance):
-        TEST_DECODED_RECORD = {'patron_id': 123, 'library_branch': 'aa'}
-        TEST_ENCODED_RECORD = b'\xf6\x01\x02\x04aa'
-        assert test_instance.decode_record(
-            TEST_ENCODED_RECORD) == TEST_DECODED_RECORD
-
-    def test_decode_record_error(self, test_instance):
-        TEST_ENCODED_RECORD = b'bad-encoding'
-        with pytest.raises(AvroEncoderError):
-            test_instance.decode_record(TEST_ENCODED_RECORD)
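
Usage sketch (not part of the patch): a minimal round trip with the new AvroEncoder and AvroDecoder classes. The schema URL and the mocked schema response below are illustrative assumptions; in real use the clients are pointed at a Platform API schema endpoint.

# Hypothetical usage sketch -- the schema URL and mocked response are
# assumptions for illustration, not part of this patch.
import json
from unittest.mock import MagicMock, patch

from nypl_py_utils.classes.avro_client import (
    AvroClientError, AvroDecoder, AvroEncoder)

SCHEMA_URL = 'https://example-platform-api/current-schemas/TestSchema'
SCHEMA_RESPONSE = {'data': {'schema': json.dumps({
    'name': 'TestSchema',
    'type': 'record',
    'fields': [
        {'name': 'patron_id', 'type': 'int'},
        {'name': 'library_branch', 'type': ['null', 'string']}
    ]
})}}

# Serve the schema locally so the sketch runs without a live endpoint.
mock_response = MagicMock()
mock_response.json.return_value = SCHEMA_RESPONSE

with patch('requests.get', return_value=mock_response):
    encoder = AvroEncoder(SCHEMA_URL)
    decoder = AvroDecoder(SCHEMA_URL)

record = {'patron_id': 123, 'library_branch': 'aa'}
encoded = encoder.encode_record(record)  # bytes
assert decoder.decode_record(encoded) == record

# Schema violations surface as AvroClientError from either class.
try:
    encoder.encode_record({'patron_id': 'not-an-int'})
except AvroClientError as e:
    print('Encoding failed:', e.message)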