Skip to content

Commit

Permalink
Added KMS permissions to lambda handler for e2e tests
Browse files Browse the repository at this point in the history
  • Loading branch information
seshubaws committed Sep 13, 2023
1 parent c5233af commit bcc735a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
3 changes: 2 additions & 1 deletion tests/e2e/data_masking/handlers/basic_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def lambda_handler(event, context):
logger.append_keys(**append_keys)
logger.info(message)

kms_key = event.get("kms_key")
# Encrypting data for test_encryption_in_handler test
kms_key = event.get("kms_key", "")
data_masker = DataMasking(provider=AwsEncryptionSdkProvider(keys=[kms_key]))
value = [1, 2, "string", 4.5]
encrypted_data = data_masker.encrypt(value)
Expand Down
7 changes: 6 additions & 1 deletion tests/e2e/data_masking/infrastructure.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import aws_cdk.aws_kms as kms
from aws_cdk import CfnOutput, Duration
from aws_cdk import aws_iam as iam

from tests.e2e.utils.infrastructure import BaseInfrastructure


class DataMaskingStack(BaseInfrastructure):
def create_resources(self):
self.create_lambda_functions(function_props={"timeout": Duration.seconds(10)})
functions = self.create_lambda_functions(function_props={"timeout": Duration.seconds(10)})

key1 = kms.Key(self.stack, "MyKMSKey1", description="My KMS Key1")
CfnOutput(self.stack, "KMSKey1Arn", value=key1.key_arn, description="ARN of the created KMS Key1")

key2 = kms.Key(self.stack, "MyKMSKey2", description="My KMS Key2")
CfnOutput(self.stack, "KMSKey2Arn", value=key2.key_arn, description="ARN of the created KMS Key2")

functions["BasicHandler"].add_to_role_policy(
iam.PolicyStatement(effect=iam.Effect.ALLOW, actions=["kms:*"], resources=[key1.key_arn, key2.key_arn]),
)
23 changes: 12 additions & 11 deletions tests/e2e/data_masking/test_data_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_encryption(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str(value)
assert decrypted_data == value


@pytest.mark.xdist_group(name="data_masking")
Expand All @@ -64,7 +64,7 @@ def test_encryption_context(data_masker):
decrypted_data = data_masker.decrypt(encrypted_data, encryption_context=context)

# THEN the result is the original input data
assert decrypted_data == str(value)
assert decrypted_data == value


@pytest.mark.xdist_group(name="data_masking")
Expand Down Expand Up @@ -121,14 +121,14 @@ def test_encryption_provider_singleton(data_masker, kms_key1_arn, kms_key2_arn):
decrypted_data = data_masker_2.decrypt(encrypted_data)

# THEN the result is the original input data
assert decrypted_data == str(value)
assert decrypted_data == value

data_masker_3 = DataMasking(provider=AwsEncryptionSdkProvider(keys=[kms_key2_arn]))
assert data_masker_2.provider is not data_masker_3.provider


@pytest.mark.xdist_group(name="data_masking")
def test_encryption_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn):
def test_encryption_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn, kms_key1_arn):
# GIVEN an instantiation of DataMasking with the AWS encryption provider

# WHEN encrypting a value and logging it
Expand All @@ -137,7 +137,7 @@ def test_encryption_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn)
message = encrypted_data
custom_key = "order_id"
additional_keys = {custom_key: f"{uuid4()}"}
payload = json.dumps({"message": message, "append_keys": additional_keys})
payload = json.dumps({"message": message, "kms_key": kms_key1_arn, "append_keys": additional_keys})

_, execution_time = data_fetcher.get_lambda_response(lambda_arn=basic_handler_fn_arn, payload=payload)
data_fetcher.get_lambda_response(lambda_arn=basic_handler_fn_arn, payload=payload)
Expand All @@ -148,20 +148,21 @@ def test_encryption_in_logs(data_masker, basic_handler_fn, basic_handler_fn_arn)
for log in logs.get_log(key=custom_key):
encrypted_data = log.message
decrypted_data = data_masker.decrypt(encrypted_data)
assert decrypted_data == str(value)
assert decrypted_data == value


# NOTE: This test is failing currently, need to find a fix for building correct dependencies
@pytest.mark.xdist_group(name="data_masking")
def test_encryption_in_handler(basic_handler_fn_arn, kms_key1_arn):
payload = {"kms_key": kms_key1_arn, "append_keys": {"order_id": f"{uuid4()}"}}
def test_encryption_in_handler(data_masker, basic_handler_fn_arn, kms_key1_arn):
# GIVEN a lambda_handler with an instantiation the AWS encryption provider data masker

# WHEN a lambda handler for encryption is invoked
payload = {"kms_key": kms_key1_arn}

# WHEN the handler is invoked to encrypt data
handler_result, _ = data_fetcher.get_lambda_response(lambda_arn=basic_handler_fn_arn, payload=json.dumps(payload))

response = json.loads(handler_result["Payload"].read())
encrypted_data = response["encrypted_data"]
decrypted_data = data_masker.decrypt(encrypted_data)

# THEN decrypting the encrypted data from the response should result in the original value
assert decrypted_data == str([1, 2, "string", 4.5])
assert decrypted_data == [1, 2, "string", 4.5]

0 comments on commit bcc735a

Please sign in to comment.