diff --git a/sdks/java/io/rrio/build.gradle b/sdks/java/io/rrio/build.gradle index 6963fcb23ddf..52119c91b47e 100644 --- a/sdks/java/io/rrio/build.gradle +++ b/sdks/java/io/rrio/build.gradle @@ -30,6 +30,8 @@ dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") implementation library.java.joda_time implementation library.java.vendored_guava_32_1_2_jre + implementation library.java.jackson_core + implementation library.java.jackson_databind implementation "redis.clients:jedis:$jedisVersion" testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/ApiIOError.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/ApiIOError.java index 5936c5dd84b0..cfff3bd89414 100644 --- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/ApiIOError.java +++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/ApiIOError.java @@ -17,11 +17,16 @@ */ package org.apache.beam.io.requestresponse; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.auto.value.AutoValue; +import java.util.Optional; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.Instant; /** {@link ApiIOError} is a data class for storing details about an error. */ @@ -30,12 +35,31 @@ @AutoValue public abstract class ApiIOError { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** + * Instantiate an {@link ApiIOError} from an {@link ErrorT} {@link T} element. The {@link T} + * element is converted to a JSON string. + */ + static ApiIOError of(@NonNull ErrorT e, @NonNull T element) + throws JsonProcessingException { + + String json = OBJECT_MAPPER.writeValueAsString(element); + + return ApiIOError.builder() + .setRequestAsJsonString(json) + .setMessage(Optional.ofNullable(e.getMessage()).orElse("")) + .setObservedTimestamp(Instant.now()) + .setStackTrace(Throwables.getStackTraceAsString(e)) + .build(); + } + static Builder builder() { return new AutoValue_ApiIOError.Builder(); } - /** The encoded UTF-8 string representation of the related processed element. */ - public abstract String getEncodedElementAsUtfString(); + /** The JSON string representation of the request associated with the error. */ + public abstract String getRequestAsJsonString(); /** The observed timestamp of the error. */ public abstract Instant getObservedTimestamp(); @@ -49,13 +73,13 @@ static Builder builder() { @AutoValue.Builder abstract static class Builder { - public abstract Builder setEncodedElementAsUtfString(String value); + abstract Builder setRequestAsJsonString(String value); - public abstract Builder setObservedTimestamp(Instant value); + abstract Builder setObservedTimestamp(Instant value); - public abstract Builder setMessage(String value); + abstract Builder setMessage(String value); - public abstract Builder setStackTrace(String value); + abstract Builder setStackTrace(String value); abstract ApiIOError build(); } diff --git a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java index 4f854ea69c7e..52181af534ed 100644 --- a/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java +++ b/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse/Call.java @@ -17,55 +17,250 @@ */ package org.apache.beam.io.requestresponse; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.fasterxml.jackson.core.JsonProcessingException; import com.google.auto.value.AutoValue; +import java.io.Serializable; import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import org.apache.beam.io.requestresponse.Call.Result; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.joda.time.Duration; /** * {@link Call} transforms a {@link RequestT} {@link PCollection} into a {@link ResponseT} {@link * PCollection} and {@link ApiIOError} {@link PCollection}, both wrapped in a {@link Result}. */ -class Call extends PTransform, Result> { +class Call + extends PTransform<@NonNull PCollection, @NonNull Result> { + + /** + * The default {@link Duration} to wait until completion of user code. A {@link + * UserCodeTimeoutException} is thrown when {@link Caller#call}, {@link SetupTeardown#setup}, or + * {@link SetupTeardown#teardown} exceed this timeout. + */ + static final Duration DEFAULT_TIMEOUT = Duration.standardSeconds(30L); + + /** + * Instantiates a {@link Call} {@link PTransform} with the required {@link Caller} and {@link + * ResponseT} {@link Coder}. Checks for the {@link Caller}'s {@link + * SerializableUtils#ensureSerializable} serializable errors. + */ + static Call of( + Caller caller, Coder responseTCoder) { + caller = SerializableUtils.ensureSerializable(caller); + return new Call<>( + Configuration.builder() + .setCaller(caller) + .setResponseCoder(responseTCoder) + .build()); + } + + /** + * Instantiates a {@link Call} {@link PTransform} with an implementation of both the {@link + * Caller} and {@link SetupTeardown} in one class and the required {@link ResponseT} {@link + * Coder}. Checks for {@link SerializableUtils#ensureSerializable} to report serializable errors. + */ + static < + RequestT, + ResponseT, + CallerSetupTeardownT extends Caller & SetupTeardown> + Call ofCallerAndSetupTeardown( + CallerSetupTeardownT implementsCallerAndSetupTeardown, Coder responseTCoder) { + implementsCallerAndSetupTeardown = + SerializableUtils.ensureSerializable(implementsCallerAndSetupTeardown); + return new Call<>( + Configuration.builder() + .setCaller(implementsCallerAndSetupTeardown) + .setResponseCoder(responseTCoder) + .setSetupTeardown(implementsCallerAndSetupTeardown) + .build()); + } private static final TupleTag FAILURE_TAG = new TupleTag() {}; - // TODO(damondouglas): remove suppress warnings when configuration utilized in future PR. - @SuppressWarnings({"unused"}) private final Configuration configuration; private Call(Configuration configuration) { this.configuration = configuration; } + /** + * Sets the {@link SetupTeardown} to the {@link Call} {@link PTransform} instance. Checks for + * {@link SerializableUtils#ensureSerializable} serializable errors. + */ + Call withSetupTeardown(SetupTeardown setupTeardown) { + setupTeardown = SerializableUtils.ensureSerializable(setupTeardown); + return new Call<>(configuration.toBuilder().setSetupTeardown(setupTeardown).build()); + } + + /** + * Overrides the default {@link #DEFAULT_TIMEOUT}. A {@link UserCodeTimeoutException} is thrown + * when {@link Caller#call}, {@link SetupTeardown#setup}, or {@link SetupTeardown#teardown} exceed + * the timeout. + */ + Call withTimeout(Duration timeout) { + return new Call<>(configuration.toBuilder().setTimeout(timeout).build()); + } + + @Override + public @NonNull Result expand(PCollection input) { + TupleTag responseTag = new TupleTag() {}; + + PCollectionTuple pct = + input.apply( + CallFn.class.getSimpleName(), + ParDo.of(new CallFn<>(responseTag, configuration)) + .withOutputTags(responseTag, TupleTagList.of(FAILURE_TAG))); + + return Result.of(configuration.getResponseCoder(), responseTag, pct); + } + + private static class CallFn extends DoFn { + private final TupleTag responseTag; + private final CallerWithTimeout caller; + private final SetupTeardownWithTimeout setupTeardown; + + private transient @MonotonicNonNull ExecutorService executor; + + private CallFn( + TupleTag responseTag, Configuration configuration) { + this.responseTag = responseTag; + this.caller = new CallerWithTimeout<>(configuration.getTimeout(), configuration.getCaller()); + this.setupTeardown = + new SetupTeardownWithTimeout( + configuration.getTimeout(), configuration.getSetupTeardown()); + } + + /** + * Invokes {@link SetupTeardown#setup} forwarding its {@link UserCodeExecutionException}, if + * thrown. + */ + @Setup + public void setup() throws UserCodeExecutionException { + this.executor = Executors.newSingleThreadExecutor(); + this.caller.setExecutor(executor); + this.setupTeardown.setExecutor(executor); + + // TODO(damondouglas): Incorporate repeater when https://github.com/apache/beam/issues/28926 + // resolves. + this.setupTeardown.setup(); + } + + /** + * Invokes {@link SetupTeardown#teardown} forwarding its {@link UserCodeExecutionException}, if + * thrown. + */ + @Teardown + public void teardown() throws UserCodeExecutionException { + // TODO(damondouglas): Incorporate repeater when https://github.com/apache/beam/issues/28926 + // resolves. + this.setupTeardown.teardown(); + checkStateNotNull(executor).shutdown(); + try { + boolean ignored = executor.awaitTermination(3L, TimeUnit.SECONDS); + } catch (InterruptedException ignored) { + } + } + + @ProcessElement + public void process(@Element @NonNull RequestT request, MultiOutputReceiver receiver) + throws JsonProcessingException { + try { + // TODO(damondouglas): https://github.com/apache/beam/issues/29248 + ResponseT response = this.caller.call(request); + receiver.get(responseTag).output(response); + } catch (UserCodeExecutionException e) { + receiver.get(FAILURE_TAG).output(ApiIOError.of(e, request)); + } + } + } + /** Configuration details for {@link Call}. */ @AutoValue - abstract static class Configuration { + abstract static class Configuration implements Serializable { static Builder builder() { return new AutoValue_Call_Configuration.Builder<>(); } + /** The user custom code that converts a {@link RequestT} into a {@link ResponseT}. */ + abstract Caller getCaller(); + + /** The user custom code that implements setup and teardown methods. */ + abstract SetupTeardown getSetupTeardown(); + + /** + * The expected timeout of all user custom code. If user custom code exceeds this timeout, then + * a {@link UserCodeTimeoutException} is thrown. User custom code may throw this exception prior + * to the configured timeout value on their own. + */ + abstract Duration getTimeout(); + + /** + * The {@link Coder} for the {@link ResponseT}. Note that the {@link RequestT}'s {@link Coder} + * is derived from the input {@link PCollection} but can't be determined for the {@link + * ResponseT} and therefore requires explicit setting in the {@link Configuration}. + */ + abstract Coder getResponseCoder(); + abstract Builder toBuilder(); @AutoValue.Builder abstract static class Builder { - abstract Configuration build(); - } - } + /** See {@link #getCaller()}. */ + abstract Builder setCaller(Caller value); - @Override - public Result expand(PCollection input) { - return Result.of(new TupleTag() {}, PCollectionTuple.empty(input.getPipeline())); + /** See {@link #getSetupTeardown()}. */ + abstract Builder setSetupTeardown(SetupTeardown value); + + abstract Optional getSetupTeardown(); + + /** See {@link #getTimeout()}. */ + abstract Builder setTimeout(Duration value); + + abstract Optional getTimeout(); + + abstract Builder setResponseCoder(Coder value); + + abstract Configuration autoBuild(); + + final Configuration build() { + if (!getSetupTeardown().isPresent()) { + setSetupTeardown(new NoopSetupTeardown()); + } + + if (!getTimeout().isPresent()) { + setTimeout(DEFAULT_TIMEOUT); + } + + return autoBuild(); + } + } } /** @@ -73,8 +268,9 @@ public Result expand(PCollection input) { */ static class Result implements POutput { - static Result of(TupleTag responseTag, PCollectionTuple pct) { - return new Result<>(responseTag, pct); + static Result of( + Coder responseTCoder, TupleTag responseTag, PCollectionTuple pct) { + return new Result<>(responseTCoder, responseTag, pct); } private final Pipeline pipeline; @@ -82,10 +278,11 @@ static Result of(TupleTag responseTag, PCollec private final PCollection responses; private final PCollection failures; - private Result(TupleTag responseTag, PCollectionTuple pct) { + private Result( + Coder responseTCoder, TupleTag responseTag, PCollectionTuple pct) { this.pipeline = pct.getPipeline(); this.responseTag = responseTag; - this.responses = pct.get(responseTag); + this.responses = pct.get(responseTag).setCoder(responseTCoder); this.failures = pct.get(FAILURE_TAG); } @@ -98,12 +295,12 @@ public PCollection getFailures() { } @Override - public Pipeline getPipeline() { + public @NonNull Pipeline getPipeline() { return this.pipeline; } @Override - public Map, PValue> expand() { + public @NonNull Map, PValue> expand() { return ImmutableMap.of( responseTag, responses, FAILURE_TAG, failures); @@ -111,6 +308,112 @@ public Map, PValue> expand() { @Override public void finishSpecifyingOutput( - String transformName, PInput input, PTransform transform) {} + @NonNull String transformName, + @NonNull PInput input, + @NonNull PTransform transform) {} + } + + private static class NoopSetupTeardown implements SetupTeardown { + + @Override + public void setup() throws UserCodeExecutionException { + // Noop + } + + @Override + public void teardown() throws UserCodeExecutionException { + // Noop + } + } + + private static class CallerWithTimeout + implements Caller { + private final Duration timeout; + private final Caller caller; + private @MonotonicNonNull ExecutorService executor; + + private CallerWithTimeout(Duration timeout, Caller caller) { + this.timeout = timeout; + this.caller = caller; + } + + private void setExecutor(ExecutorService executor) { + this.executor = executor; + } + + @Override + public ResponseT call(RequestT request) throws UserCodeExecutionException { + Future future = checkStateNotNull(executor).submit(() -> caller.call(request)); + try { + return future.get(timeout.getMillis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException | InterruptedException e) { + throw new UserCodeTimeoutException(e); + } catch (ExecutionException e) { + parseAndThrow(future, e); + } + throw new UserCodeExecutionException("could not complete request"); + } + } + + private static class SetupTeardownWithTimeout implements SetupTeardown { + private final Duration timeout; + private final SetupTeardown setupTeardown; + private @MonotonicNonNull ExecutorService executor; + + SetupTeardownWithTimeout(Duration timeout, SetupTeardown setupTeardown) { + this.timeout = timeout; + this.setupTeardown = setupTeardown; + } + + private void setExecutor(ExecutorService executor) { + this.executor = executor; + } + + @Override + public void setup() throws UserCodeExecutionException { + Callable callable = + () -> { + setupTeardown.setup(); + return null; + }; + + executeAsync(callable); + } + + @Override + public void teardown() throws UserCodeExecutionException { + Callable callable = + () -> { + setupTeardown.teardown(); + return null; + }; + + executeAsync(callable); + } + + private void executeAsync(Callable callable) throws UserCodeExecutionException { + Future future = checkStateNotNull(executor).submit(callable); + try { + future.get(timeout.getMillis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException | InterruptedException e) { + future.cancel(true); + throw new UserCodeTimeoutException(e); + } catch (ExecutionException e) { + parseAndThrow(future, e); + } + } + } + + private static void parseAndThrow(Future future, ExecutionException e) + throws UserCodeExecutionException { + future.cancel(true); + if (e.getCause() == null) { + throw new UserCodeExecutionException(e); + } + Throwable cause = checkStateNotNull(e.getCause()); + if (cause instanceof UserCodeQuotaException) { + throw new UserCodeQuotaException(cause); + } + throw new UserCodeExecutionException(cause); } } diff --git a/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java new file mode 100644 index 000000000000..18574b00978d --- /dev/null +++ b/sdks/java/io/rrio/src/test/java/org/apache/beam/io/requestresponse/CallTest.java @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.io.requestresponse; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.sdk.values.TypeDescriptors.strings; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.Serializable; +import org.apache.beam.io.requestresponse.Call.Result; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Objects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.UncheckedExecutionException; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.joda.time.Duration; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link Call}. */ +@RunWith(JUnit4.class) +public class CallTest { + @Rule public TestPipeline pipeline = TestPipeline.create(); + + private static final SerializableCoder<@NonNull Response> RESPONSE_CODER = + SerializableCoder.of(Response.class); + + @Test + public void givenCallerNotSerializable_throwsError() { + assertThrows( + IllegalArgumentException.class, () -> Call.of(new UnSerializableCaller(), RESPONSE_CODER)); + } + + @Test + public void givenSetupTeardownNotSerializable_throwsError() { + assertThrows( + IllegalArgumentException.class, + () -> + Call.ofCallerAndSetupTeardown( + new UnSerializableCallerWithSetupTeardown(), RESPONSE_CODER)); + } + + @Test + public void givenCallerThrowsUserCodeExecutionException_emitsIntoFailurePCollection() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new CallerThrowsUserCodeExecutionException(), RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeQuotaException.class)).isEqualTo(0L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void givenCallerThrowsQuotaException_emitsIntoFailurePCollection() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new CallerInvokesQuotaException(), RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeQuotaException.class)).isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(0L); + + pipeline.run(); + } + + @Test + public void givenCallerTimeout_emitsFailurePCollection() { + Duration timeout = Duration.standardSeconds(1L); + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new CallerExceedsTimeout(timeout), RESPONSE_CODER).withTimeout(timeout)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(0L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeQuotaException.class)).isEqualTo(0L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(1L); + + pipeline.run(); + } + + @Test + public void givenCallerThrowsTimeoutException_emitsFailurePCollection() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new CallerThrowsTimeout(), RESPONSE_CODER)); + + PCollection failures = result.getFailures(); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeExecutionException.class)) + .isEqualTo(1L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeQuotaException.class)).isEqualTo(0L); + PAssert.thatSingleton(countStackTracesOf(failures, UserCodeTimeoutException.class)) + .isEqualTo(1L); + + pipeline.run(); + } + + @Test + public void givenSetupThrowsUserCodeExecutionException_throwsError() { + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withSetupTeardown(new SetupThrowsUserCodeExecutionException())); + + assertPipelineThrows(UserCodeExecutionException.class, pipeline); + } + + @Test + public void givenSetupThrowsQuotaException_throwsError() { + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withSetupTeardown(new SetupThrowsUserCodeQuotaException())); + + assertPipelineThrows(UserCodeQuotaException.class, pipeline); + } + + @Test + public void givenSetupTimeout_throwsError() { + Duration timeout = Duration.standardSeconds(1L); + + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withSetupTeardown(new SetupExceedsTimeout(timeout)) + .withTimeout(timeout)); + + assertPipelineThrows(UserCodeTimeoutException.class, pipeline); + } + + @Test + public void givenSetupThrowsTimeoutException_throwsError() { + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withSetupTeardown(new SetupThrowsUserCodeTimeoutException())); + + assertPipelineThrows(UserCodeTimeoutException.class, pipeline); + } + + @Test + public void givenTeardownThrowsUserCodeExecutionException_throwsError() { + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withSetupTeardown(new TeardownThrowsUserCodeExecutionException())); + + // Exceptions thrown during teardown do not populate with the cause + assertThrows(IllegalStateException.class, () -> pipeline.run()); + } + + @Test + public void givenTeardownThrowsQuotaException_throwsError() { + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withSetupTeardown(new TeardownThrowsUserCodeQuotaException())); + + // Exceptions thrown during teardown do not populate with the cause + assertThrows(IllegalStateException.class, () -> pipeline.run()); + } + + @Test + public void givenTeardownTimeout_throwsError() { + Duration timeout = Duration.standardSeconds(1L); + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withTimeout(timeout) + .withSetupTeardown(new TeardownExceedsTimeout(timeout))); + + // Exceptions thrown during teardown do not populate with the cause + assertThrows(IllegalStateException.class, () -> pipeline.run()); + } + + @Test + public void givenTeardownThrowsTimeoutException_throwsError() { + pipeline + .apply(Create.of(new Request(""))) + .apply( + Call.of(new ValidCaller(), RESPONSE_CODER) + .withSetupTeardown(new TeardownThrowsUserCodeTimeoutException())); + + // Exceptions thrown during teardown do not populate with the cause + assertThrows(IllegalStateException.class, () -> pipeline.run()); + } + + @Test + public void givenValidCaller_emitValidResponse() { + Result result = + pipeline + .apply(Create.of(new Request("a"))) + .apply(Call.of(new ValidCaller(), RESPONSE_CODER)); + + PAssert.thatSingleton(result.getFailures().apply(Count.globally())).isEqualTo(0L); + PAssert.that(result.getResponses()).containsInAnyOrder(new Response("a")); + + pipeline.run(); + } + + private static class ValidCaller implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + return new Response(request.id); + } + } + + private static class UnSerializableCaller implements Caller { + + @SuppressWarnings({"unused"}) + private final UnSerializable nestedThing = new UnSerializable(); + + @Override + public Response call(Request request) throws UserCodeExecutionException { + return new Response(request.id); + } + } + + private static class UnSerializableCallerWithSetupTeardown extends UnSerializableCaller + implements SetupTeardown { + + @Override + public void setup() throws UserCodeExecutionException {} + + @Override + public void teardown() throws UserCodeExecutionException {} + } + + private static class UnSerializable {} + + private static class Request implements Serializable { + + final String id; + + Request(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Request request = (Request) o; + return Objects.equal(id, request.id); + } + + @Override + public int hashCode() { + return Objects.hashCode(id); + } + } + + private static class Response implements Serializable { + final String id; + + Response(String id) { + this.id = id; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Response response = (Response) o; + return Objects.equal(id, response.id); + } + + @Override + public int hashCode() { + return Objects.hashCode(id); + } + } + + private static class CallerExceedsTimeout implements Caller { + private final Duration timeout; + + CallerExceedsTimeout(Duration timeout) { + this.timeout = timeout.plus(Duration.standardSeconds(1L)); + } + + @Override + public Response call(Request request) throws UserCodeExecutionException { + sleep(timeout); + return new Response(request.id); + } + } + + private static class CallerThrowsUserCodeExecutionException implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeExecutionException(request.id); + } + } + + private static class CallerThrowsTimeout implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeTimeoutException(""); + } + } + + private static class CallerInvokesQuotaException implements Caller { + + @Override + public Response call(Request request) throws UserCodeExecutionException { + throw new UserCodeQuotaException(request.id); + } + } + + private static class SetupExceedsTimeout implements SetupTeardown { + + private final Duration timeout; + + private SetupExceedsTimeout(Duration timeout) { + this.timeout = timeout.plus(Duration.standardSeconds(1L)); + } + + @Override + public void setup() throws UserCodeExecutionException { + sleep(timeout); + } + + @Override + public void teardown() throws UserCodeExecutionException {} + } + + private static class SetupThrowsUserCodeExecutionException implements SetupTeardown { + @Override + public void setup() throws UserCodeExecutionException { + throw new UserCodeExecutionException("error message"); + } + + @Override + public void teardown() throws UserCodeExecutionException {} + } + + private static class SetupThrowsUserCodeQuotaException implements SetupTeardown { + @Override + public void setup() throws UserCodeExecutionException { + throw new UserCodeQuotaException(""); + } + + @Override + public void teardown() throws UserCodeExecutionException {} + } + + private static class SetupThrowsUserCodeTimeoutException implements SetupTeardown { + @Override + public void setup() throws UserCodeExecutionException { + throw new UserCodeTimeoutException(""); + } + + @Override + public void teardown() throws UserCodeExecutionException {} + } + + private static class TeardownExceedsTimeout implements SetupTeardown { + private final Duration timeout; + + private TeardownExceedsTimeout(Duration timeout) { + this.timeout = timeout.plus(Duration.standardSeconds(1L)); + } + + @Override + public void setup() throws UserCodeExecutionException {} + + @Override + public void teardown() throws UserCodeExecutionException { + sleep(timeout); + } + } + + private static class TeardownThrowsUserCodeExecutionException implements SetupTeardown { + @Override + public void setup() throws UserCodeExecutionException {} + + @Override + public void teardown() throws UserCodeExecutionException { + throw new UserCodeExecutionException(""); + } + } + + private static class TeardownThrowsUserCodeQuotaException implements SetupTeardown { + @Override + public void setup() throws UserCodeExecutionException {} + + @Override + public void teardown() throws UserCodeExecutionException { + throw new UserCodeQuotaException(""); + } + } + + private static class TeardownThrowsUserCodeTimeoutException implements SetupTeardown { + @Override + public void setup() throws UserCodeExecutionException {} + + @Override + public void teardown() throws UserCodeExecutionException { + throw new UserCodeExecutionException(""); + } + } + + private static void assertPipelineThrows( + Class clazz, TestPipeline p) { + + // Because we need to wrap in a timeout via a java Future, exceptions are thrown as + // UncheckedExecutionException + UncheckedExecutionException error = assertThrows(UncheckedExecutionException.class, p::run); + + // Iterate through the stack trace to assert ErrorT is among stack. + assertTrue( + error.toString(), Throwables.getCausalChain(error).stream().anyMatch(clazz::isInstance)); + } + + private static PCollection countStackTracesOf( + PCollection failures, Class clazz) { + return failures + .apply( + "stackTrace " + clazz.getSimpleName(), + MapElements.into(strings()).via(failure -> checkStateNotNull(failure).getStackTrace())) + .apply( + "filter " + clazz.getSimpleName(), Filter.by(input -> input.contains(clazz.getName()))) + .apply("count " + clazz.getSimpleName(), Count.globally()); + } + + private static void sleep(Duration timeout) { + try { + Thread.sleep(timeout.getMillis()); + } catch (InterruptedException ignored) { + } + } +}