Skip to content

Commit

Permalink
Updated json serializer for aws encrypt sdk to return original data type
Browse files Browse the repository at this point in the history
  • Loading branch information
seshubaws committed Sep 11, 2023
1 parent 371ea05 commit b3d123d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import json
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import botocore
from aws_encryption_sdk import (
Expand All @@ -21,21 +22,21 @@ def __init__(self, key):


class Singleton:
_instances: Dict[Any, "AwsEncryptionSdkProvider"] = {}
_instances: Dict[Tuple, "AwsEncryptionSdkProvider"] = {}

def __new__(cls, *args, **kwargs):
def __new__(cls, *args, force_new_instance=False, **kwargs):
# Generate a unique key based on the configuration.
# Create a tuple by iterating through the values in kwargs, sorting them,
# and then adding them to the tuple.
config_key = tuple()
config_key = ()
for value in kwargs.values():
if isinstance(value, Iterable):
for val in sorted(value):
config_key += (val,)
else:
config_key += (value,)

if config_key not in cls._instances:
if force_new_instance or config_key not in cls._instances:
cls._instances[config_key] = super(Singleton, cls).__new__(cls, *args)
return cls._instances[config_key]

Expand Down Expand Up @@ -82,12 +83,14 @@ def __init__(
)

def _serialize(self, data: Any) -> bytes:
return bytes(str(data), "utf-8")
json_data = json.dumps(data)
return json_data.encode("utf-8")

def _deserialize(self, data: bytes) -> str:
return data.decode("utf-8")
def _deserialize(self, data: bytes) -> Any:
json_data = data.decode("utf-8")
return json.loads(json_data)

def encrypt(self, data: Union[bytes, str], **provider_options) -> str:
def encrypt(self, data: Union[bytes, str], **provider_options) -> bytes:
"""
Encrypt data using the AwsEncryptionSdkProvider.
Expand All @@ -108,7 +111,7 @@ def encrypt(self, data: Union[bytes, str], **provider_options) -> str:
ciphertext = base64.b64encode(ciphertext).decode()
return ciphertext

def decrypt(self, data: str, **provider_options) -> str:
def decrypt(self, data: str, **provider_options) -> Any:
"""
Decrypt data using AwsEncryptionSdkProvider.
Expand Down
14 changes: 7 additions & 7 deletions tests/functional/data_masking/test_aws_encryption_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_encrypt_int(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str(-1)
assert decrypted_data == -1


def test_encrypt_float(data_masker):
Expand All @@ -159,7 +159,7 @@ def test_encrypt_float(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str(-1.11)
assert decrypted_data == -1.11


def test_encrypt_bool(data_masker):
Expand All @@ -170,7 +170,7 @@ def test_encrypt_bool(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str(True)
assert decrypted_data is True


def test_encrypt_none(data_masker):
Expand All @@ -181,7 +181,7 @@ def test_encrypt_none(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str(None)
assert decrypted_data is None


def test_encrypt_str(data_masker):
Expand All @@ -192,7 +192,7 @@ def test_encrypt_str(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str("this is a string")
assert decrypted_data == "this is a string"


def test_encrypt_list(data_masker):
Expand All @@ -203,7 +203,7 @@ def test_encrypt_list(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str([1, 2, "a string", 3.4])
assert decrypted_data == [1, 2, "a string", 3.4]


def test_encrypt_dict(data_masker):
Expand All @@ -221,7 +221,7 @@ def test_encrypt_dict(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str(data)
assert decrypted_data == data


def test_encrypt_dict_with_fields(data_masker):
Expand Down

0 comments on commit b3d123d

Please sign in to comment.