From 7b06fa75a4e2d89fd33f50d324ca2786c095e92d Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Thu, 13 Jan 2022 14:28:40 +0100 Subject: [PATCH] [BEAM-8807] Add integration test for SnsIO.write (Sdk v1 & v2) --- .../io/aws/dynamodb/AwsClientsProvider.java | 8 +- .../aws/dynamodb/BasicDynamoDBProvider.java | 15 -- .../sdk/io/aws/sns/AwsClientsProvider.java | 8 +- .../beam/sdk/io/aws/sns/BasicSnsProvider.java | 16 +- .../apache/beam/sdk/io/aws/ITEnvironment.java | 31 +--- .../dynamodb/StaticAwsClientsProvider.java | 6 - .../beam/sdk/io/aws/s3/S3FileSystemIT.java | 12 +- .../apache/beam/sdk/io/aws/sns/SnsIOIT.java | 159 ++++++++++++++++++ .../apache/beam/sdk/io/aws/sns/SnsIOTest.java | 6 - .../beam/sdk/io/aws2/ITEnvironment.java | 31 +--- .../apache/beam/sdk/io/aws2/sns/SnsIOIT.java | 158 +++++++++++++++++ 11 files changed, 359 insertions(+), 91 deletions(-) create mode 100644 sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOIT.java create mode 100644 sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOIT.java diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AwsClientsProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AwsClientsProvider.java index 8d1c267b404b..e98e633c2ca4 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AwsClientsProvider.java +++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/AwsClientsProvider.java @@ -28,7 +28,13 @@ * ensure it can be sent to worker machines. */ public interface AwsClientsProvider extends Serializable { - AmazonCloudWatch getCloudWatchClient(); + + /** @deprecated DynamoDBIO doesn't require a CloudWatch client */ + @Deprecated + @SuppressWarnings("return.type.incompatible") + default AmazonCloudWatch getCloudWatchClient() { + return null; + } AmazonDynamoDB createDynamoDB(); } diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/BasicDynamoDBProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/BasicDynamoDBProvider.java index bc345618b766..b868b9e43cef 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/BasicDynamoDBProvider.java +++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/dynamodb/BasicDynamoDBProvider.java @@ -24,8 +24,6 @@ import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; import com.amazonaws.regions.Regions; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder; import org.checkerframework.checker.nullness.qual.Nullable; @@ -52,19 +50,6 @@ private AWSCredentialsProvider getCredentialsProvider() { return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)); } - @Override - public AmazonCloudWatch getCloudWatchClient() { - AmazonCloudWatchClientBuilder clientBuilder = - AmazonCloudWatchClientBuilder.standard().withCredentials(getCredentialsProvider()); - if (serviceEndpoint == null) { - clientBuilder.withRegion(region); - } else { - clientBuilder.withEndpointConfiguration( - new EndpointConfiguration(serviceEndpoint, region.getName())); - } - return clientBuilder.build(); - } - @Override public AmazonDynamoDB createDynamoDB() { AmazonDynamoDBClientBuilder clientBuilder = diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/AwsClientsProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/AwsClientsProvider.java index dd11daa64315..6582b510b089 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/AwsClientsProvider.java +++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/AwsClientsProvider.java @@ -28,7 +28,13 @@ * ensure it can be sent to worker machines. */ public interface AwsClientsProvider extends Serializable { - AmazonCloudWatch getCloudWatchClient(); + + /** @deprecated SnsIO doesn't require a CloudWatch client */ + @Deprecated + @SuppressWarnings("return.type.incompatible") + default AmazonCloudWatch getCloudWatchClient() { + return null; + } AmazonSNS createSnsPublisher(); } diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/BasicSnsProvider.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/BasicSnsProvider.java index 8f380bd4f993..10599ddcadb9 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/BasicSnsProvider.java +++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/BasicSnsProvider.java @@ -24,8 +24,6 @@ import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.client.builder.AwsClientBuilder; import com.amazonaws.regions.Regions; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; -import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder; import com.amazonaws.services.sns.AmazonSNS; import com.amazonaws.services.sns.AmazonSNSClientBuilder; import org.checkerframework.checker.nullness.qual.Nullable; @@ -54,9 +52,9 @@ private AWSCredentialsProvider getCredentialsProvider() { } @Override - public AmazonCloudWatch getCloudWatchClient() { - AmazonCloudWatchClientBuilder clientBuilder = - AmazonCloudWatchClientBuilder.standard().withCredentials(getCredentialsProvider()); + public AmazonSNS createSnsPublisher() { + AmazonSNSClientBuilder clientBuilder = + AmazonSNSClientBuilder.standard().withCredentials(getCredentialsProvider()); if (serviceEndpoint == null) { clientBuilder.withRegion(region); } else { @@ -65,12 +63,4 @@ public AmazonCloudWatch getCloudWatchClient() { } return clientBuilder.build(); } - - @Override - public AmazonSNS createSnsPublisher() { - return AmazonSNSClientBuilder.standard() - .withCredentials(getCredentialsProvider()) - .withRegion(region) - .build(); - } } diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/ITEnvironment.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/ITEnvironment.java index 9e95d72e22e1..e05e7244db6d 100644 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/ITEnvironment.java +++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/ITEnvironment.java @@ -18,10 +18,10 @@ package org.apache.beam.sdk.io.aws; import static org.apache.beam.sdk.testing.TestPipeline.testingPipelineOptions; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; import com.amazonaws.client.builder.AwsClientBuilder; import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; -import java.util.function.Consumer; import org.apache.beam.sdk.io.aws.options.AwsOptions; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; @@ -29,6 +29,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.commons.lang3.StringUtils; import org.junit.rules.ExternalResource; import org.slf4j.LoggerFactory; import org.testcontainers.containers.localstack.LocalStackContainer; @@ -76,43 +77,28 @@ public interface ITOptions extends AwsOptions, TestPipelineOptions { void setLocalstackLogLevel(String level); } - private final Service service; private final OptionsT options; private final LocalStackContainer localstack; - public ITEnvironment(Service service, Class optionsClass) { - this(service, optionsClass, o -> {}); - } - - public ITEnvironment( - Service service, Class optionsClass, Consumer optionsMutator) { - this(service, optionsClass, optionsMutator, new String[0]); - } - public ITEnvironment(Service service, Class optionsClass, String... env) { - this(service, optionsClass, o -> {}, env); + this(new Service[] {service}, optionsClass, env); } - public ITEnvironment( - Service service, - Class optionsClass, - Consumer optionsMutator, - String... env) { - this.service = service; + public ITEnvironment(Service[] services, Class optionsClass, String... env) { localstack = new LocalStackContainer(DockerImageName.parse(LOCALSTACK).withTag(LOCALSTACK_VERSION)) - .withServices(service) + .withServices(services) .withStartupAttempts(3); PipelineOptionsFactory.register(optionsClass); options = testingPipelineOptions().as(optionsClass); - optionsMutator.accept(options); localstack.setEnv(ImmutableList.copyOf(env)); if (options.getLocalstackLogLevel() != null) { localstack .withEnv("LS_LOG", options.getLocalstackLogLevel()) - .withLogConsumer(new Slf4jLogConsumer(LoggerFactory.getLogger(service.name()))); + .withLogConsumer( + new Slf4jLogConsumer(LoggerFactory.getLogger(StringUtils.join(services)))); } } @@ -150,7 +136,8 @@ protected void after() { /** Necessary setup for localstack environment. */ private void startLocalstack() { localstack.start(); - options.setAwsServiceEndpoint(localstack.getEndpointOverride(service).toString()); + options.setAwsServiceEndpoint( + localstack.getEndpointOverride(S3).toString()); // service irrelevant options.setAwsRegion(localstack.getRegion()); options.setAwsCredentialsProvider(localstack.getDefaultCredentialsProvider()); } diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/StaticAwsClientsProvider.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/StaticAwsClientsProvider.java index ed2c080253c2..d3f676cf1096 100644 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/StaticAwsClientsProvider.java +++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/dynamodb/StaticAwsClientsProvider.java @@ -19,7 +19,6 @@ import static java.util.Collections.synchronizedMap; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; import java.util.HashMap; import java.util.Map; @@ -42,11 +41,6 @@ static AwsClientsProvider of(AmazonDynamoDB client) { return provider; } - @Override - public AmazonCloudWatch getCloudWatchClient() { - return null; // never used - } - @Override public AmazonDynamoDB createDynamoDB() { return clients.get(id); diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemIT.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemIT.java index e1518b2a57fd..112ab95463b4 100644 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemIT.java +++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemIT.java @@ -67,7 +67,13 @@ public interface S3ITOptions extends ITEnvironment.ITOptions, S3Options {} @ClassRule public static ITEnvironment env = - new ITEnvironment<>(S3, S3ITOptions.class, S3ClientFixFix::set); + new ITEnvironment(S3, S3ITOptions.class) { + @Override + protected void before() { + super.before(); + options().setS3ClientFactoryClass(S3ClientFixFix.class); + } + }; @Rule public TestPipeline pipelineWrite = env.createTestPipeline(); @Rule public TestPipeline pipelineRead = env.createTestPipeline(); @@ -111,10 +117,6 @@ protected void before() { // Fix duplicated Content-Length header due to case-sensitive handling of header names // https://github.com/aws/aws-sdk-java/issues/2503 private static class S3ClientFixFix extends DefaultS3ClientBuilderFactory { - private static void set(S3Options s3Options) { - s3Options.setS3ClientFactoryClass(S3ClientFixFix.class); - } - @Override public AmazonS3ClientBuilder createBuilder(S3Options s3Options) { return super.createBuilder(s3Options) diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOIT.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOIT.java new file mode 100644 index 000000000000..c19aada628fa --- /dev/null +++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOIT.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws.sns; + +import static org.apache.beam.sdk.io.common.IOITHelper.executeWithRetry; +import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount; +import static org.apache.beam.sdk.values.TypeDescriptors.strings; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SNS; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SQS; + +import com.amazonaws.regions.Regions; +import com.amazonaws.services.sns.AmazonSNS; +import com.amazonaws.services.sns.AmazonSNSClientBuilder; +import com.amazonaws.services.sns.model.PublishRequest; +import com.amazonaws.services.sqs.AmazonSQS; +import com.amazonaws.services.sqs.AmazonSQSClientBuilder; +import com.amazonaws.services.sqs.model.Message; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.Serializable; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.aws.ITEnvironment; +import org.apache.beam.sdk.io.aws.sqs.SqsIO; +import org.apache.beam.sdk.io.common.HashingFn; +import org.apache.beam.sdk.io.common.TestRow; +import org.apache.beam.sdk.io.common.TestRow.DeterministicallyConstructTestRowFn; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExternalResource; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.testcontainers.containers.localstack.LocalStackContainer.Service; + +@RunWith(JUnit4.class) +public class SnsIOIT { + public interface ITOptions extends ITEnvironment.ITOptions {} + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final TypeDescriptor publishRequests = + TypeDescriptor.of(PublishRequest.class); + + @ClassRule + public static ITEnvironment env = + new ITEnvironment<>(new Service[] {SQS, SNS}, ITOptions.class, "SQS_PROVIDER=elasticmq"); + + @Rule public Timeout globalTimeout = Timeout.seconds(600); + + @Rule public TestPipeline pipelineWrite = env.createTestPipeline(); + @Rule public TestPipeline pipelineRead = env.createTestPipeline(); + @Rule public AwsResources resources = new AwsResources(); + + @Test + public void testWriteThenRead() { + ITOptions opts = env.options(); + int rows = opts.getNumberOfRows(); + + // Write test dataset to SNS + + pipelineWrite + .apply("Generate Sequence", GenerateSequence.from(0).to(rows)) + .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn())) + .apply("SNS request", MapElements.into(publishRequests).via(resources::publishRequest)) + .apply( + "Write to SNS", + SnsIO.write() + .withTopicName(resources.snsTopic) + .withResultOutputTag(new TupleTag<>()) + .withAWSClientsProvider( + opts.getAwsCredentialsProvider().getCredentials().getAWSAccessKeyId(), + opts.getAwsCredentialsProvider().getCredentials().getAWSSecretKey(), + Regions.fromName(opts.getAwsRegion()), + opts.getAwsServiceEndpoint())); + + // Read test dataset from SQS. + PCollection output = + pipelineRead + .apply( + "Read from SQS", + SqsIO.read().withQueueUrl(resources.sqsQueue).withMaxNumRecords(rows)) + .apply("Extract message", MapElements.into(strings()).via(SnsIOIT::extractMessage)); + + PAssert.thatSingleton(output.apply("Count All", Count.globally())).isEqualTo((long) rows); + + PAssert.that(output.apply(Combine.globally(new HashingFn()).withoutDefaults())) + .containsInAnyOrder(getExpectedHashForRowCount(rows)); + + pipelineWrite.run(); + pipelineRead.run(); + } + + private static String extractMessage(Message msg) { + try { + return MAPPER.readTree(msg.getBody()).get("Message").asText(); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private static class AwsResources extends ExternalResource implements Serializable { + private transient AmazonSQS sqs = env.buildClient(AmazonSQSClientBuilder.standard()); + private transient AmazonSNS sns = env.buildClient(AmazonSNSClientBuilder.standard()); + + private String sqsQueue; + private String snsTopic; + private String sns2Sqs; + + PublishRequest publishRequest(TestRow r) { + return new PublishRequest(snsTopic, r.name()); + } + + @Override + protected void before() throws Throwable { + snsTopic = sns.createTopic("beam-snsio-it").getTopicArn(); + // add SQS subscription so we can read the messages again + sqsQueue = sqs.createQueue("beam-snsio-it").getQueueUrl(); + sns2Sqs = sns.subscribe(snsTopic, "sqs", sqsQueue).getSubscriptionArn(); + } + + @Override + protected void after() { + try { + executeWithRetry(() -> sns.unsubscribe(sns2Sqs)); + executeWithRetry(() -> sns.deleteTopic(snsTopic)); + executeWithRetry(() -> sqs.deleteQueue(sqsQueue)); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + sns.shutdown(); + sqs.shutdown(); + } + } + } +} diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java index bbd5a47c4a18..2a50f008d924 100644 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java +++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java @@ -22,7 +22,6 @@ import static org.joda.time.Duration.standardSeconds; import com.amazonaws.http.SdkHttpMetadata; -import com.amazonaws.services.cloudwatch.AmazonCloudWatch; import com.amazonaws.services.sns.AmazonSNS; import com.amazonaws.services.sns.model.GetTopicAttributesResult; import com.amazonaws.services.sns.model.InternalErrorException; @@ -79,11 +78,6 @@ public Provider(AmazonSNS pub) { publisher = pub; } - @Override - public AmazonCloudWatch getCloudWatchClient() { - return Mockito.mock(AmazonCloudWatch.class); - } - @Override public AmazonSNS createSnsPublisher() { return publisher; diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/ITEnvironment.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/ITEnvironment.java index 7348bf91164d..b88151ac2b3b 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/ITEnvironment.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/ITEnvironment.java @@ -18,9 +18,9 @@ package org.apache.beam.sdk.io.aws2; import static org.apache.beam.sdk.testing.TestPipeline.testingPipelineOptions; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; import java.net.URI; -import java.util.function.Consumer; import org.apache.beam.sdk.io.aws2.options.AwsOptions; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; @@ -28,6 +28,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.commons.lang3.StringUtils; import org.junit.rules.ExternalResource; import org.slf4j.LoggerFactory; import org.testcontainers.containers.localstack.LocalStackContainer; @@ -79,43 +80,29 @@ public interface ITOptions extends AwsOptions, TestPipelineOptions { void setLocalstackLogLevel(String level); } - private final Service service; private final OptionsT options; private final LocalStackContainer localstack; - public ITEnvironment(Service service, Class optionsClass) { - this(service, optionsClass, o -> {}); - } - - public ITEnvironment( - Service service, Class optionsClass, Consumer optionsMutator) { - this(service, optionsClass, optionsMutator, new String[0]); - } - public ITEnvironment(Service service, Class optionsClass, String... env) { - this(service, optionsClass, o -> {}, env); + this(new Service[] {service}, optionsClass, env); } - public ITEnvironment( - Service service, - Class optionsClass, - Consumer optionsMutator, - String... env) { - this.service = service; + public ITEnvironment(Service[] services, Class optionsClass, String... env) { localstack = new LocalStackContainer(DockerImageName.parse(LOCALSTACK).withTag(LOCALSTACK_VERSION)) - .withServices(service) + .withServices(services) .withStartupAttempts(3); PipelineOptionsFactory.register(optionsClass); options = testingPipelineOptions().as(optionsClass); - optionsMutator.accept(options); localstack.setEnv(ImmutableList.copyOf(env)); + if (options.getLocalstackLogLevel() != null) { localstack .withEnv("LS_LOG", options.getLocalstackLogLevel()) - .withLogConsumer(new Slf4jLogConsumer(LoggerFactory.getLogger(service.name()))); + .withLogConsumer( + new Slf4jLogConsumer(LoggerFactory.getLogger(StringUtils.join(services)))); } } @@ -153,7 +140,7 @@ protected void after() { /** Necessary setup for localstack environment. */ private void startLocalstack() { localstack.start(); - options.setEndpoint(localstack.getEndpointOverride(service).toString()); + options.setEndpoint(localstack.getEndpointOverride(S3).toString()); // service irrelevant options.setAwsRegion(localstack.getRegion()); options.setAwsCredentialsProvider( StaticCredentialsProvider.create( diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOIT.java new file mode 100644 index 000000000000..b08fd55fb34c --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOIT.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sns; + +import static org.apache.beam.sdk.io.common.IOITHelper.executeWithRetry; +import static org.apache.beam.sdk.io.common.TestRow.getExpectedHashForRowCount; +import static org.apache.beam.sdk.values.TypeDescriptors.strings; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SNS; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.SQS; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.Serializable; +import java.net.URI; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.aws2.ITEnvironment; +import org.apache.beam.sdk.io.aws2.sqs.SqsIO; +import org.apache.beam.sdk.io.aws2.sqs.SqsMessage; +import org.apache.beam.sdk.io.common.HashingFn; +import org.apache.beam.sdk.io.common.TestRow; +import org.apache.beam.sdk.io.common.TestRow.DeterministicallyConstructTestRowFn; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExternalResource; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.testcontainers.containers.localstack.LocalStackContainer.Service; +import software.amazon.awssdk.services.sns.SnsClient; +import software.amazon.awssdk.services.sns.model.PublishRequest; +import software.amazon.awssdk.services.sqs.SqsClient; + +@RunWith(JUnit4.class) +public class SnsIOIT { + public interface ITOptions extends ITEnvironment.ITOptions {} + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @ClassRule + public static ITEnvironment env = + new ITEnvironment<>(new Service[] {SQS, SNS}, ITOptions.class, "SQS_PROVIDER=elasticmq"); + + @Rule public Timeout globalTimeout = Timeout.seconds(600); + + @Rule public TestPipeline pipelineWrite = env.createTestPipeline(); + @Rule public TestPipeline pipelineRead = env.createTestPipeline(); + @Rule public AwsResources resources = new AwsResources(); + + @Test + public void testWriteThenRead() { + ITOptions opts = env.options(); + int rows = opts.getNumberOfRows(); + + // Write test dataset to SNS + pipelineWrite + .apply("Generate Sequence", GenerateSequence.from(0).to(rows)) + .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn())) + .apply( + "Write to SNS", + SnsIO.write() + .withTopicArn(resources.snsTopic) + .withPublishRequestBuilder(r -> PublishRequest.builder().message(r.name())) + .withSnsClientProvider( + opts.getAwsCredentialsProvider(), + opts.getAwsRegion(), + toURI(opts.getEndpoint()))); + + // Read test dataset from SQS. + PCollection output = + pipelineRead + .apply( + "Read from SQS", + SqsIO.read() + .withQueueUrl(resources.sqsQueue) + .withMaxNumRecords(rows) + .withSqsClientProvider( + opts.getAwsCredentialsProvider(), + opts.getAwsRegion(), + toURI(opts.getEndpoint()))) + .apply("Extract message", MapElements.into(strings()).via(SnsIOIT::extractMessage)); + + PAssert.thatSingleton(output.apply("Count All", Count.globally())).isEqualTo((long) rows); + + PAssert.that(output.apply(Combine.globally(new HashingFn()).withoutDefaults())) + .containsInAnyOrder(getExpectedHashForRowCount(rows)); + + pipelineWrite.run(); + pipelineRead.run(); + } + + private static String extractMessage(SqsMessage msg) { + try { + return MAPPER.readTree(msg.getBody()).get("Message").asText(); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private static class AwsResources extends ExternalResource implements Serializable { + private transient SqsClient sqs = env.buildClient(SqsClient.builder()); + private transient SnsClient sns = env.buildClient(SnsClient.builder()); + + private String sqsQueue; + private String snsTopic; + private String sns2Sqs; + + @Override + protected void before() throws Throwable { + snsTopic = sns.createTopic(b -> b.name("beam-snsio-it")).topicArn(); + // add SQS subscription so we can read the messages again + sqsQueue = sqs.createQueue(b -> b.queueName("beam-snsio-it")).queueUrl(); + sns2Sqs = + sns.subscribe(b -> b.topicArn(snsTopic).endpoint(sqsQueue).protocol("sqs")) + .subscriptionArn(); + } + + @Override + protected void after() { + try { + executeWithRetry(() -> sns.unsubscribe(b -> b.subscriptionArn(sns2Sqs))); + executeWithRetry(() -> sns.deleteTopic(b -> b.topicArn(snsTopic))); + executeWithRetry(() -> sqs.deleteQueue(b -> b.queueUrl(sqsQueue))); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + sns.close(); + sqs.close(); + } + } + } + + private static URI toURI(String uri) { + return uri == null ? null : URI.create(uri); + } +}