-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable Job management for the Prism runner (#32091)
- Loading branch information
1 parent
07e692b
commit 656a296
Showing
2 changed files
with
371 additions
and
0 deletions.
There are no files selected for viewing
160 changes: 160 additions & 0 deletions
160
runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismJobManager.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.beam.runners.prism; | ||
|
||
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; | ||
|
||
import com.google.auto.value.AutoValue; | ||
import java.io.Closeable; | ||
import java.util.Optional; | ||
import java.util.concurrent.TimeUnit; | ||
import org.apache.beam.model.jobmanagement.v1.JobApi; | ||
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; | ||
import org.apache.beam.model.pipeline.v1.Endpoints; | ||
import org.apache.beam.sdk.PipelineResult; | ||
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory; | ||
import org.apache.beam.sdk.options.PortablePipelineOptions; | ||
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; | ||
import org.joda.time.Duration; | ||
|
||
/** | ||
* A wrapper for {@link JobServiceGrpc.JobServiceBlockingStub} that {@link #close}es when {@link | ||
* StateListener#onStateChanged} is invoked with a {@link PipelineResult.State} that is {@link | ||
* PipelineResult.State#isTerminal}. | ||
*/ | ||
@AutoValue | ||
abstract class PrismJobManager implements StateListener, Closeable { | ||
|
||
/** | ||
* Instantiate a {@link PrismJobManager} with {@param options}, assigning {@link #getEndpoint} | ||
* from {@link PortablePipelineOptions#getJobEndpoint} and {@link #getTimeout} from {@link | ||
* PortablePipelineOptions#getJobServerTimeout}. Defaults the instantiations of {@link | ||
* #getManagedChannel} and {@link #getBlockingStub}. See respective getters for more details. | ||
*/ | ||
static PrismJobManager of(PortablePipelineOptions options) { | ||
return builder() | ||
.setEndpoint(options.getJobEndpoint()) | ||
.setTimeout(Duration.standardSeconds(options.getJobServerTimeout())) | ||
.build(); | ||
} | ||
|
||
static Builder builder() { | ||
return new AutoValue_PrismJobManager.Builder(); | ||
} | ||
|
||
/** | ||
* Executes {@link #getBlockingStub()}'s {@link JobServiceGrpc.JobServiceBlockingStub#prepare} | ||
* method. | ||
*/ | ||
JobApi.PrepareJobResponse prepare(JobApi.PrepareJobRequest request) { | ||
return getBlockingStub().prepare(request); | ||
} | ||
|
||
/** | ||
* Executes {@link #getBlockingStub()}'s {@link JobServiceGrpc.JobServiceBlockingStub#run} method. | ||
*/ | ||
JobApi.RunJobResponse run(JobApi.RunJobRequest request) { | ||
return getBlockingStub().run(request); | ||
} | ||
|
||
/** The {@link JobServiceGrpc} endpoint. */ | ||
abstract String getEndpoint(); | ||
|
||
/** The {@link JobServiceGrpc} timeout. */ | ||
abstract Duration getTimeout(); | ||
|
||
/** The {@link #getBlockingStub}'s channel. Defaulted from the {@link #getEndpoint()}. */ | ||
abstract ManagedChannel getManagedChannel(); | ||
|
||
/** The wrapped service defaulted using the {@link #getManagedChannel}. */ | ||
abstract JobServiceGrpc.JobServiceBlockingStub getBlockingStub(); | ||
|
||
/** Shuts down {@link #getManagedChannel}, if not {@link #isShutdown}. */ | ||
@Override | ||
public void close() { | ||
if (isShutdown()) { | ||
return; | ||
} | ||
getManagedChannel().shutdown(); | ||
try { | ||
getManagedChannel().awaitTermination(3000L, TimeUnit.MILLISECONDS); | ||
} catch (InterruptedException ignored) { | ||
} | ||
} | ||
|
||
/** Queries whether {@link #getManagedChannel} {@link ManagedChannel#isShutdown}. */ | ||
boolean isShutdown() { | ||
return getManagedChannel().isShutdown(); | ||
} | ||
|
||
/** | ||
* Override of {@link StateListener#onStateChanged}. Invokes {@link #close} when {@link | ||
* PipelineResult.State} {@link PipelineResult.State#isTerminal}. | ||
*/ | ||
@Override | ||
public void onStateChanged(PipelineResult.State state) { | ||
if (state.isTerminal()) { | ||
close(); | ||
} | ||
} | ||
|
||
@AutoValue.Builder | ||
abstract static class Builder { | ||
|
||
abstract Builder setEndpoint(String endpoint); | ||
|
||
abstract Optional<String> getEndpoint(); | ||
|
||
abstract Builder setTimeout(Duration timeout); | ||
|
||
abstract Optional<Duration> getTimeout(); | ||
|
||
abstract Builder setManagedChannel(ManagedChannel managedChannel); | ||
|
||
abstract Optional<ManagedChannel> getManagedChannel(); | ||
|
||
abstract Builder setBlockingStub(JobServiceGrpc.JobServiceBlockingStub blockingStub); | ||
|
||
abstract Optional<JobServiceGrpc.JobServiceBlockingStub> getBlockingStub(); | ||
|
||
abstract PrismJobManager autoBuild(); | ||
|
||
final PrismJobManager build() { | ||
|
||
checkState(getEndpoint().isPresent(), "endpoint is not set"); | ||
checkState(getTimeout().isPresent(), "timeout is not set"); | ||
|
||
if (!getManagedChannel().isPresent()) { | ||
ManagedChannelFactory channelFactory = ManagedChannelFactory.createDefault(); | ||
|
||
setManagedChannel( | ||
channelFactory.forDescriptor( | ||
Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getEndpoint().get()).build())); | ||
} | ||
|
||
if (!getBlockingStub().isPresent()) { | ||
setBlockingStub( | ||
JobServiceGrpc.newBlockingStub(getManagedChannel().get()) | ||
.withDeadlineAfter(getTimeout().get().getMillis(), TimeUnit.MILLISECONDS) | ||
.withWaitForReady()); | ||
} | ||
|
||
return autoBuild(); | ||
} | ||
} | ||
} |
211 changes: 211 additions & 0 deletions
211
runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismJobManagerTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.beam.runners.prism; | ||
|
||
import static com.google.common.truth.Truth.assertThat; | ||
import static org.junit.Assert.assertThrows; | ||
|
||
import java.io.IOException; | ||
import java.util.Optional; | ||
import org.apache.beam.model.jobmanagement.v1.JobApi; | ||
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc; | ||
import org.apache.beam.model.pipeline.v1.Endpoints; | ||
import org.apache.beam.model.pipeline.v1.RunnerApi; | ||
import org.apache.beam.sdk.Pipeline; | ||
import org.apache.beam.sdk.PipelineResult; | ||
import org.apache.beam.sdk.transforms.Impulse; | ||
import org.apache.beam.sdk.util.construction.PipelineTranslation; | ||
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; | ||
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; | ||
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; | ||
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; | ||
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; | ||
import org.joda.time.Duration; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.TestName; | ||
import org.junit.runner.RunWith; | ||
import org.junit.runners.JUnit4; | ||
|
||
/** Tests for {@link PrismJobManager}. */ | ||
@RunWith(JUnit4.class) | ||
public class PrismJobManagerTest { | ||
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); | ||
|
||
@Rule public TestName testName = new TestName(); | ||
|
||
@Test | ||
public void givenPrepareError_forwardsException_canGracefulShutdown() { | ||
TestJobService service = | ||
new TestJobService().withErrorResponse(new RuntimeException(testName.getMethodName())); | ||
PrismJobManager underTest = prismJobManager(service); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
assertThrows( | ||
RuntimeException.class, | ||
() -> | ||
underTest.prepare( | ||
JobApi.PrepareJobRequest.newBuilder().setPipeline(pipelineOf()).build())); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
underTest.close(); | ||
assertThat(underTest.isShutdown()).isTrue(); | ||
} | ||
|
||
@Test | ||
public void givenPrepareSuccess_forwardsResponse_canGracefulShutdown() { | ||
TestJobService service = | ||
new TestJobService() | ||
.withPrepareJobResponse( | ||
JobApi.PrepareJobResponse.newBuilder() | ||
.setStagingSessionToken("token") | ||
.setPreparationId("preparationId") | ||
.setArtifactStagingEndpoint( | ||
Endpoints.ApiServiceDescriptor.newBuilder() | ||
.setUrl("localhost:1234") | ||
.build()) | ||
.build()); | ||
PrismJobManager underTest = prismJobManager(service); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
JobApi.PrepareJobResponse response = | ||
underTest.prepare(JobApi.PrepareJobRequest.newBuilder().setPipeline(pipelineOf()).build()); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
assertThat(response.getStagingSessionToken()).isEqualTo("token"); | ||
assertThat(response.getPreparationId()).isEqualTo("preparationId"); | ||
underTest.close(); | ||
assertThat(underTest.isShutdown()).isTrue(); | ||
} | ||
|
||
@Test | ||
public void givenRunError_forwardsException_canGracefulShutdown() { | ||
TestJobService service = | ||
new TestJobService().withErrorResponse(new RuntimeException(testName.getMethodName())); | ||
PrismJobManager underTest = prismJobManager(service); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
assertThrows( | ||
RuntimeException.class, | ||
() -> | ||
underTest.run(JobApi.RunJobRequest.newBuilder().setPreparationId("prepareId").build())); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
underTest.close(); | ||
assertThat(underTest.isShutdown()).isTrue(); | ||
} | ||
|
||
@Test | ||
public void givenRunSuccess_forwardsResponse_canGracefulShutdown() { | ||
TestJobService service = | ||
new TestJobService() | ||
.withRunJobResponse(JobApi.RunJobResponse.newBuilder().setJobId("jobId").build()); | ||
PrismJobManager underTest = prismJobManager(service); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
JobApi.RunJobResponse runJobResponse = | ||
underTest.run(JobApi.RunJobRequest.newBuilder().setPreparationId("preparationId").build()); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
assertThat(runJobResponse.getJobId()).isEqualTo("jobId"); | ||
underTest.close(); | ||
assertThat(underTest.isShutdown()).isTrue(); | ||
} | ||
|
||
@Test | ||
public void givenTerminalState_closes() { | ||
PrismJobManager underTest = prismJobManager(new TestJobService()); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
underTest.onStateChanged(PipelineResult.State.RUNNING); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
underTest.onStateChanged(PipelineResult.State.RUNNING); | ||
assertThat(underTest.isShutdown()).isFalse(); | ||
underTest.onStateChanged(PipelineResult.State.CANCELLED); | ||
assertThat(underTest.isShutdown()).isTrue(); | ||
|
||
underTest.close(); | ||
} | ||
|
||
private PrismJobManager prismJobManager(TestJobService service) { | ||
String serverName = InProcessServerBuilder.generateName(); | ||
try { | ||
grpcCleanup.register( | ||
InProcessServerBuilder.forName(serverName) | ||
.directExecutor() | ||
.addService(service) | ||
.build() | ||
.start()); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
|
||
ManagedChannel channel = | ||
grpcCleanup.register(InProcessChannelBuilder.forName(serverName).build()); | ||
|
||
return PrismJobManager.builder() | ||
.setTimeout(Duration.millis(3000L)) | ||
.setEndpoint("ignore") | ||
.setManagedChannel(channel) | ||
.build(); | ||
} | ||
|
||
private static class TestJobService extends JobServiceGrpc.JobServiceImplBase { | ||
|
||
private Optional<JobApi.PrepareJobResponse> prepareJobResponse = Optional.empty(); | ||
private Optional<JobApi.RunJobResponse> runJobResponse = Optional.empty(); | ||
private Optional<RuntimeException> error = Optional.empty(); | ||
|
||
TestJobService withPrepareJobResponse(JobApi.PrepareJobResponse prepareJobResponse) { | ||
this.prepareJobResponse = Optional.of(prepareJobResponse); | ||
return this; | ||
} | ||
|
||
TestJobService withRunJobResponse(JobApi.RunJobResponse runJobResponse) { | ||
this.runJobResponse = Optional.of(runJobResponse); | ||
return this; | ||
} | ||
|
||
TestJobService withErrorResponse(RuntimeException error) { | ||
this.error = Optional.of(error); | ||
return this; | ||
} | ||
|
||
@Override | ||
public void prepare( | ||
JobApi.PrepareJobRequest request, | ||
StreamObserver<JobApi.PrepareJobResponse> responseObserver) { | ||
if (prepareJobResponse.isPresent()) { | ||
responseObserver.onNext(prepareJobResponse.get()); | ||
responseObserver.onCompleted(); | ||
} | ||
if (error.isPresent()) { | ||
responseObserver.onError(error.get()); | ||
} | ||
} | ||
|
||
@Override | ||
public void run( | ||
JobApi.RunJobRequest request, StreamObserver<JobApi.RunJobResponse> responseObserver) { | ||
if (runJobResponse.isPresent()) { | ||
responseObserver.onNext(runJobResponse.get()); | ||
responseObserver.onCompleted(); | ||
} | ||
if (error.isPresent()) { | ||
responseObserver.onError(error.get()); | ||
} | ||
} | ||
} | ||
|
||
private static RunnerApi.Pipeline pipelineOf() { | ||
Pipeline pipeline = Pipeline.create(); | ||
pipeline.apply(Impulse.create()); | ||
return PipelineTranslation.toProto(pipeline); | ||
} | ||
} |