Skip to content

Commit

Permalink
Added ability for user input custom json de/serializer in base class
Browse files Browse the repository at this point in the history
  • Loading branch information
seshubaws committed Sep 12, 2023
1 parent b3d123d commit c0c3f2f
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 61 deletions.
4 changes: 4 additions & 0 deletions aws_lambda_powertools/utilities/data_masking/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
DATA_MASKING_STRING: str = "*****"
CACHE_CAPACITY: int = 100
MAX_CACHE_AGE_SECONDS: float = 300.0
MAX_MESSAGES_ENCRYPTED: int = 200
# NOTE: You can also set max messages/bytes per data key
15 changes: 13 additions & 2 deletions aws_lambda_powertools/utilities/data_masking/provider.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
from abc import ABC
import json
from abc import ABCMeta
from collections.abc import Iterable
from typing import Union

from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING


class BaseProvider(ABC):
class BaseProvider(metaclass=ABCMeta):
"""
When you try to create an instance of a subclass that does not implement the encrypt method,
you will get a NotImplementedError with a message that says the method is not implemented:
"""

def __init__(self, json_serializer=None, json_deserializer=None) -> None:
self.json_serializer = json_serializer or self.default_json_serializer
self.json_deserializer = json_deserializer or self.default_json_deserializer

def default_json_serializer(self, data):
return json.dumps(data).encode("utf-8")

def default_json_deserializer(self, data):
return json.loads(data.decode("utf-8"))

def encrypt(self, data) -> Union[bytes, str]:
raise NotImplementedError("Subclasses must implement encrypt()")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64
import json
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import botocore
from aws_encryption_sdk import (
Expand All @@ -12,6 +11,11 @@
)

from aws_lambda_powertools.shared.user_agent import register_feature_to_botocore_session
from aws_lambda_powertools.utilities.data_masking.constants import (
CACHE_CAPACITY,
MAX_CACHE_AGE_SECONDS,
MAX_MESSAGES_ENCRYPTED,
)
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider


Expand Down Expand Up @@ -41,12 +45,6 @@ def __new__(cls, *args, force_new_instance=False, **kwargs):
return cls._instances[config_key]


CACHE_CAPACITY: int = 100
MAX_CACHE_AGE_SECONDS: float = 300.0
MAX_MESSAGES_ENCRYPTED: int = 200
# NOTE: You can also set max messages/bytes per data key


class AwsEncryptionSdkProvider(BaseProvider, Singleton):
"""
The AwsEncryptionSdkProvider is to be used as a Provider for the Datamasking class.
Expand All @@ -67,10 +65,13 @@ def __init__(
self,
keys: List[str],
client: Optional[EncryptionSDKClient] = None,
local_cache_capacity: Optional[int] = CACHE_CAPACITY,
max_cache_age_seconds: Optional[float] = MAX_CACHE_AGE_SECONDS,
max_messages_encrypted: Optional[int] = MAX_MESSAGES_ENCRYPTED,
local_cache_capacity: int = CACHE_CAPACITY,
max_cache_age_seconds: float = MAX_CACHE_AGE_SECONDS,
max_messages_encrypted: int = MAX_MESSAGES_ENCRYPTED,
json_serializer: Optional[Callable[[Dict], str]] = None,
json_deserializer: Optional[Callable[[Union[Dict, str, bool, int, float]], str]] = None,
):
super().__init__(json_serializer=json_serializer, json_deserializer=json_deserializer)
self.client = client or EncryptionSDKClient()
self.keys = keys
self.cache = LocalCryptoMaterialsCache(local_cache_capacity)
Expand All @@ -82,14 +83,6 @@ def __init__(
max_messages_encrypted=max_messages_encrypted,
)

def _serialize(self, data: Any) -> bytes:
json_data = json.dumps(data)
return json_data.encode("utf-8")

def _deserialize(self, data: bytes) -> Any:
json_data = data.decode("utf-8")
return json.loads(json_data)

def encrypt(self, data: Union[bytes, str], **provider_options) -> bytes:
"""
Encrypt data using the AwsEncryptionSdkProvider.
Expand All @@ -106,7 +99,7 @@ def encrypt(self, data: Union[bytes, str], **provider_options) -> bytes:
ciphertext : str
The encrypted data, as a base64-encoded string.
"""
data = self._serialize(data)
data = self.json_serializer(data)
ciphertext, _ = self.client.encrypt(source=data, materials_manager=self.cache_cmm, **provider_options)
ciphertext = base64.b64encode(ciphertext).decode()
return ciphertext
Expand Down Expand Up @@ -141,5 +134,5 @@ def decrypt(self, data: str, **provider_options) -> Any:
if decryptor_header.encryption_context.get(key) != value:
raise ContextMismatchError(key)

ciphertext = self._deserialize(ciphertext)
ciphertext = self.json_deserializer(ciphertext)
return ciphertext
22 changes: 1 addition & 21 deletions tests/functional/data_masking/test_aws_encryption_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_mask_dict(data_masker):


def test_mask_dict_with_fields(data_masker):
# GIVEN the data type is a dictionary
# GIVEN a dict data type
data = {
"a": {
"1": {"None": "hello", "four": "world"},
Expand Down Expand Up @@ -208,7 +208,6 @@ def test_encrypt_list(data_masker):

def test_encrypt_dict(data_masker):
# GIVEN an dict data type

data = {
"a": {
"1": {"None": "hello", "four": "world"},
Expand Down Expand Up @@ -258,22 +257,3 @@ def test_encrypt_json_dict_with_fields(data_masker):

# THEN the result is only the specified fields are masked
assert decrypted_data == json.loads(data)


def test_encrypt_json_blob_with_fields(data_masker):
# GIVEN the data type is a json representation of a dictionary
data = json.dumps(
{
"a": {
"1": {"None": "hello", "four": "world"},
"b": {"3": {"4": "goodbye", "e": "world"}},
},
},
)

# WHEN encrypting and then decrypting the encrypted data
encrypted_data = data_masker.encrypt(data, fields=["a.1.None", "a.b.3.4"])
decrypted_data = data_masker.decrypt(encrypted_data, fields=["a.1.None", "a.b.3.4"])

# THEN the result is only the specified fields are masked
assert decrypted_data == json.loads(data)
28 changes: 11 additions & 17 deletions tests/unit/data_masking/test_data_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_mask_dict(data_masker):


def test_mask_dict_with_fields(data_masker):
# GIVEN the data type is a dictionary
# GIVEN a dict data type
data = {
"a": {
"1": {"None": "hello", "four": "world"},
Expand Down Expand Up @@ -135,35 +135,31 @@ def test_encrypt_not_implemented(data_masker):
# GIVEN DataMasking is not initialized with a Provider

# WHEN attempting to call the encrypt method on the data

# THEN the result is a NotImplementedError
with pytest.raises(NotImplementedError):
# THEN the result is a NotImplementedError
data_masker.encrypt("hello world")


def test_decrypt_not_implemented(data_masker):
# GIVEN DataMasking is not initialized with a Provider

# WHEN attempting to call the decrypt method on the data

# THEN the result is a NotImplementedError
with pytest.raises(NotImplementedError):
# THEN the result is a NotImplementedError
data_masker.decrypt("hello world")


def test_parsing_unsupported_data_type(data_masker):
# GIVEN an initialization of the DataMasking class

# WHEN attempting to pass in a list of fields with input data that is not a dict

# THEN the result is a TypeError
with pytest.raises(TypeError):
# THEN the result is a TypeError
data_masker.mask(42, ["this.field"])


def test_parsing_nonexistent_fields(data_masker):
# GIVEN an initialization of the DataMasking class

# GIVEN a dict data type
data = {
"3": {
"1": {"None": "hello", "four": "world"},
Expand All @@ -172,15 +168,13 @@ def test_parsing_nonexistent_fields(data_masker):
}

# WHEN attempting to pass in fields that do not exist in the input data

# THEN the result is a KeyError
with pytest.raises(KeyError):
# THEN the result is a KeyError
data_masker.mask(data, ["3.1.True"])


def test_parsing_nonstring_fields(data_masker):
# GIVEN an initialization of the DataMasking class

# GIVEN a dict data type
data = {
"3": {
"1": {"None": "hello", "four": "world"},
Expand All @@ -196,16 +190,16 @@ def test_parsing_nonstring_fields(data_masker):


def test_parsing_nonstring_keys_and_fields(data_masker):
# GIVEN an initialization of the DataMasking class

# WHEN the input data is a dictionary with integer keys
# GIVEN a dict data type with integer keys
data = {
3: {
"1": {"None": "hello", "four": "world"},
4: {"33": {"5": "goodbye", "e": "world"}},
},
}

# WHEN masked with a list of fields that are integer keys
masked = data_masker.mask(data, fields=[3.4])

# THEN the result is the value of the nested field should be masked as normal
# THEN the result is the value of the nested field should be masked
assert masked == {"3": {"1": {"None": "hello", "four": "world"}, "4": DATA_MASKING_STRING}}

0 comments on commit c0c3f2f

Please sign in to comment.