From 656a296a82d00f1d17745bbe6a161e75104506e7 Mon Sep 17 00:00:00 2001 From: Damon Date: Wed, 7 Aug 2024 13:16:00 -0700 Subject: [PATCH] Enable Job management for the Prism runner (#32091) --- .../beam/runners/prism/PrismJobManager.java | 160 +++++++++++++ .../runners/prism/PrismJobManagerTest.java | 211 ++++++++++++++++++ 2 files changed, 371 insertions(+) create mode 100644 runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismJobManager.java create mode 100644 runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismJobManagerTest.java diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismJobManager.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismJobManager.java new file mode 100644 index 000000000000..e461e92c4749 --- /dev/null +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismJobManager.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import com.google.auto.value.AutoValue; +import java.io.Closeable; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import org.apache.beam.model.jobmanagement.v1.JobApi; +import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.fn.channel.ManagedChannelFactory; +import org.apache.beam.sdk.options.PortablePipelineOptions; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.joda.time.Duration; + +/** + * A wrapper for {@link JobServiceGrpc.JobServiceBlockingStub} that {@link #close}es when {@link + * StateListener#onStateChanged} is invoked with a {@link PipelineResult.State} that is {@link + * PipelineResult.State#isTerminal}. + */ +@AutoValue +abstract class PrismJobManager implements StateListener, Closeable { + + /** + * Instantiate a {@link PrismJobManager} with {@param options}, assigning {@link #getEndpoint} + * from {@link PortablePipelineOptions#getJobEndpoint} and {@link #getTimeout} from {@link + * PortablePipelineOptions#getJobServerTimeout}. Defaults the instantiations of {@link + * #getManagedChannel} and {@link #getBlockingStub}. See respective getters for more details. + */ + static PrismJobManager of(PortablePipelineOptions options) { + return builder() + .setEndpoint(options.getJobEndpoint()) + .setTimeout(Duration.standardSeconds(options.getJobServerTimeout())) + .build(); + } + + static Builder builder() { + return new AutoValue_PrismJobManager.Builder(); + } + + /** + * Executes {@link #getBlockingStub()}'s {@link JobServiceGrpc.JobServiceBlockingStub#prepare} + * method. + */ + JobApi.PrepareJobResponse prepare(JobApi.PrepareJobRequest request) { + return getBlockingStub().prepare(request); + } + + /** + * Executes {@link #getBlockingStub()}'s {@link JobServiceGrpc.JobServiceBlockingStub#run} method. + */ + JobApi.RunJobResponse run(JobApi.RunJobRequest request) { + return getBlockingStub().run(request); + } + + /** The {@link JobServiceGrpc} endpoint. */ + abstract String getEndpoint(); + + /** The {@link JobServiceGrpc} timeout. */ + abstract Duration getTimeout(); + + /** The {@link #getBlockingStub}'s channel. Defaulted from the {@link #getEndpoint()}. */ + abstract ManagedChannel getManagedChannel(); + + /** The wrapped service defaulted using the {@link #getManagedChannel}. */ + abstract JobServiceGrpc.JobServiceBlockingStub getBlockingStub(); + + /** Shuts down {@link #getManagedChannel}, if not {@link #isShutdown}. */ + @Override + public void close() { + if (isShutdown()) { + return; + } + getManagedChannel().shutdown(); + try { + getManagedChannel().awaitTermination(3000L, TimeUnit.MILLISECONDS); + } catch (InterruptedException ignored) { + } + } + + /** Queries whether {@link #getManagedChannel} {@link ManagedChannel#isShutdown}. */ + boolean isShutdown() { + return getManagedChannel().isShutdown(); + } + + /** + * Override of {@link StateListener#onStateChanged}. Invokes {@link #close} when {@link + * PipelineResult.State} {@link PipelineResult.State#isTerminal}. + */ + @Override + public void onStateChanged(PipelineResult.State state) { + if (state.isTerminal()) { + close(); + } + } + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setEndpoint(String endpoint); + + abstract Optional getEndpoint(); + + abstract Builder setTimeout(Duration timeout); + + abstract Optional getTimeout(); + + abstract Builder setManagedChannel(ManagedChannel managedChannel); + + abstract Optional getManagedChannel(); + + abstract Builder setBlockingStub(JobServiceGrpc.JobServiceBlockingStub blockingStub); + + abstract Optional getBlockingStub(); + + abstract PrismJobManager autoBuild(); + + final PrismJobManager build() { + + checkState(getEndpoint().isPresent(), "endpoint is not set"); + checkState(getTimeout().isPresent(), "timeout is not set"); + + if (!getManagedChannel().isPresent()) { + ManagedChannelFactory channelFactory = ManagedChannelFactory.createDefault(); + + setManagedChannel( + channelFactory.forDescriptor( + Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getEndpoint().get()).build())); + } + + if (!getBlockingStub().isPresent()) { + setBlockingStub( + JobServiceGrpc.newBlockingStub(getManagedChannel().get()) + .withDeadlineAfter(getTimeout().get().getMillis(), TimeUnit.MILLISECONDS) + .withWaitForReady()); + } + + return autoBuild(); + } + } +} diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismJobManagerTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismJobManagerTest.java new file mode 100644 index 000000000000..1e38e4f8d12e --- /dev/null +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismJobManagerTest.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.prism; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.util.Optional; +import org.apache.beam.model.jobmanagement.v1.JobApi; +import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; +import org.apache.beam.model.pipeline.v1.Endpoints; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.joda.time.Duration; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link PrismJobManager}. */ +@RunWith(JUnit4.class) +public class PrismJobManagerTest { + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + @Rule public TestName testName = new TestName(); + + @Test + public void givenPrepareError_forwardsException_canGracefulShutdown() { + TestJobService service = + new TestJobService().withErrorResponse(new RuntimeException(testName.getMethodName())); + PrismJobManager underTest = prismJobManager(service); + assertThat(underTest.isShutdown()).isFalse(); + assertThrows( + RuntimeException.class, + () -> + underTest.prepare( + JobApi.PrepareJobRequest.newBuilder().setPipeline(pipelineOf()).build())); + assertThat(underTest.isShutdown()).isFalse(); + underTest.close(); + assertThat(underTest.isShutdown()).isTrue(); + } + + @Test + public void givenPrepareSuccess_forwardsResponse_canGracefulShutdown() { + TestJobService service = + new TestJobService() + .withPrepareJobResponse( + JobApi.PrepareJobResponse.newBuilder() + .setStagingSessionToken("token") + .setPreparationId("preparationId") + .setArtifactStagingEndpoint( + Endpoints.ApiServiceDescriptor.newBuilder() + .setUrl("localhost:1234") + .build()) + .build()); + PrismJobManager underTest = prismJobManager(service); + assertThat(underTest.isShutdown()).isFalse(); + JobApi.PrepareJobResponse response = + underTest.prepare(JobApi.PrepareJobRequest.newBuilder().setPipeline(pipelineOf()).build()); + assertThat(underTest.isShutdown()).isFalse(); + assertThat(response.getStagingSessionToken()).isEqualTo("token"); + assertThat(response.getPreparationId()).isEqualTo("preparationId"); + underTest.close(); + assertThat(underTest.isShutdown()).isTrue(); + } + + @Test + public void givenRunError_forwardsException_canGracefulShutdown() { + TestJobService service = + new TestJobService().withErrorResponse(new RuntimeException(testName.getMethodName())); + PrismJobManager underTest = prismJobManager(service); + assertThat(underTest.isShutdown()).isFalse(); + assertThrows( + RuntimeException.class, + () -> + underTest.run(JobApi.RunJobRequest.newBuilder().setPreparationId("prepareId").build())); + assertThat(underTest.isShutdown()).isFalse(); + underTest.close(); + assertThat(underTest.isShutdown()).isTrue(); + } + + @Test + public void givenRunSuccess_forwardsResponse_canGracefulShutdown() { + TestJobService service = + new TestJobService() + .withRunJobResponse(JobApi.RunJobResponse.newBuilder().setJobId("jobId").build()); + PrismJobManager underTest = prismJobManager(service); + assertThat(underTest.isShutdown()).isFalse(); + JobApi.RunJobResponse runJobResponse = + underTest.run(JobApi.RunJobRequest.newBuilder().setPreparationId("preparationId").build()); + assertThat(underTest.isShutdown()).isFalse(); + assertThat(runJobResponse.getJobId()).isEqualTo("jobId"); + underTest.close(); + assertThat(underTest.isShutdown()).isTrue(); + } + + @Test + public void givenTerminalState_closes() { + PrismJobManager underTest = prismJobManager(new TestJobService()); + assertThat(underTest.isShutdown()).isFalse(); + underTest.onStateChanged(PipelineResult.State.RUNNING); + assertThat(underTest.isShutdown()).isFalse(); + underTest.onStateChanged(PipelineResult.State.RUNNING); + assertThat(underTest.isShutdown()).isFalse(); + underTest.onStateChanged(PipelineResult.State.CANCELLED); + assertThat(underTest.isShutdown()).isTrue(); + + underTest.close(); + } + + private PrismJobManager prismJobManager(TestJobService service) { + String serverName = InProcessServerBuilder.generateName(); + try { + grpcCleanup.register( + InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(service) + .build() + .start()); + } catch (IOException e) { + throw new RuntimeException(e); + } + + ManagedChannel channel = + grpcCleanup.register(InProcessChannelBuilder.forName(serverName).build()); + + return PrismJobManager.builder() + .setTimeout(Duration.millis(3000L)) + .setEndpoint("ignore") + .setManagedChannel(channel) + .build(); + } + + private static class TestJobService extends JobServiceGrpc.JobServiceImplBase { + + private Optional prepareJobResponse = Optional.empty(); + private Optional runJobResponse = Optional.empty(); + private Optional error = Optional.empty(); + + TestJobService withPrepareJobResponse(JobApi.PrepareJobResponse prepareJobResponse) { + this.prepareJobResponse = Optional.of(prepareJobResponse); + return this; + } + + TestJobService withRunJobResponse(JobApi.RunJobResponse runJobResponse) { + this.runJobResponse = Optional.of(runJobResponse); + return this; + } + + TestJobService withErrorResponse(RuntimeException error) { + this.error = Optional.of(error); + return this; + } + + @Override + public void prepare( + JobApi.PrepareJobRequest request, + StreamObserver responseObserver) { + if (prepareJobResponse.isPresent()) { + responseObserver.onNext(prepareJobResponse.get()); + responseObserver.onCompleted(); + } + if (error.isPresent()) { + responseObserver.onError(error.get()); + } + } + + @Override + public void run( + JobApi.RunJobRequest request, StreamObserver responseObserver) { + if (runJobResponse.isPresent()) { + responseObserver.onNext(runJobResponse.get()); + responseObserver.onCompleted(); + } + if (error.isPresent()) { + responseObserver.onError(error.get()); + } + } + } + + private static RunnerApi.Pipeline pipelineOf() { + Pipeline pipeline = Pipeline.create(); + pipeline.apply(Impulse.create()); + return PipelineTranslation.toProto(pipeline); + } +}