Skip to content

Commit

Permalink
Implement Java TestPrismRunner and PrismRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
damondouglas committed Aug 23, 2024
1 parent 1e80815 commit 02c94dd
Show file tree
Hide file tree
Showing 11 changed files with 303 additions and 191 deletions.
1 change: 1 addition & 0 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies {
implementation library.java.slf4j_api
implementation library.java.vendored_grpc_1_60_1
implementation library.java.vendored_guava_32_1_2_jre
compileOnly library.java.hamcrest

testImplementation library.java.junit
testImplementation library.java.mockito_core
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
abstract class PrismExecutor {

private static final Logger LOG = LoggerFactory.getLogger(PrismExecutor.class);
static final String JOB_PORT_FLAG_TEMPLATE = "-job_port=%s";
static final String SERVE_HTTP_FLAG_TEMPLATE = "-serve_http=%s";

protected @MonotonicNonNull Process process;
protected ExecutorService executorService = Executors.newSingleThreadExecutor();
Expand All @@ -71,7 +73,7 @@ void stop() {
}
executorService.shutdown();
try {
boolean ignored = executorService.awaitTermination(1000L, TimeUnit.MILLISECONDS);
boolean ignored = executorService.awaitTermination(5000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException ignored) {
}
if (process == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PortablePipelineOptions;

Expand All @@ -25,6 +26,9 @@
* org.apache.beam.sdk.Pipeline} on the {@link PrismRunner}.
*/
public interface PrismPipelineOptions extends PortablePipelineOptions {

String JOB_PORT_FLAG_NAME = "job_port";

@Description(
"Path or URL to a prism binary, or zipped binary for the current "
+ "platform (Operating System and Architecture). May also be an Apache "
Expand All @@ -41,4 +45,10 @@ public interface PrismPipelineOptions extends PortablePipelineOptions {
String getPrismVersionOverride();

void setPrismVersionOverride(String prismVersionOverride);

/**
 * Whether the Prism web UI should be enabled. Defaults to {@code true}.
 *
 * <p>{@code PrismRunner#startPrism} formats this value into the {@code -serve_http=%s} flag that
 * is passed to the Prism binary.
 */
@Description("Enable or disable Prism Web UI")
@Default.Boolean(true)
Boolean getEnableWebUI();

/** Sets whether the Prism web UI should be enabled; see {@link #getEnableWebUI()}. */
void setEnableWebUI(Boolean enableWebUI);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.IOException;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.metrics.MetricResults;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;

/**
Expand All @@ -29,14 +28,8 @@
*/
class PrismPipelineResult implements PipelineResult {

static PrismPipelineResult of(PipelineResult delegate, PrismExecutor executor) {
return new PrismPipelineResult(delegate, executor::stop);
}

private final PipelineResult delegate;
private final Runnable cancel;
private @Nullable MetricResults terminalMetrics;
private @Nullable State terminalState;

/**
* Instantiate the {@link PipelineResult} from the {@code delegate} and a {@code cancel} to be
Expand All @@ -50,9 +43,6 @@ static PrismPipelineResult of(PipelineResult delegate, PrismExecutor executor) {
/** Forwards the result of the delegate {@link PipelineResult#getState}. */
@Override
public State getState() {
if (terminalState != null) {
return terminalState;
}
return delegate.getState();
}

Expand All @@ -64,8 +54,6 @@ public State getState() {
@Override
public State cancel() throws IOException {
State state = delegate.cancel();
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
return state;
}
Expand All @@ -78,8 +66,6 @@ public State cancel() throws IOException {
@Override
public State waitUntilFinish(Duration duration) {
State state = delegate.waitUntilFinish(duration);
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
return state;
}
Expand All @@ -92,18 +78,13 @@ public State waitUntilFinish(Duration duration) {
@Override
public State waitUntilFinish() {
State state = delegate.waitUntilFinish();
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
return state;
}

/** Forwards the result of the delegate {@link PipelineResult#metrics}. */
@Override
public MetricResults metrics() {
if (terminalMetrics != null) {
return terminalMetrics;
}
return delegate.metrics();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
*/
package org.apache.beam.runners.prism;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.IOException;
import java.net.ServerSocket;
import java.util.Arrays;
import org.apache.beam.runners.portability.PortableRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
Expand All @@ -34,34 +40,38 @@
* submit to an already running Prism service, use the {@link PortableRunner} with the {@link
* PortablePipelineOptions#getJobEndpoint()} option instead. Prism is a {@link
* org.apache.beam.runners.portability.PortableRunner} maintained at <a
* href="https://github.com/apache/beam/tree/master/sdks/go/cmd/prism">sdks/go/cmd/prism</a>.
* href="https://github.com/apache/beam/tree/master/sdks/go/cmd/prism">sdks/go/cmd/prism</a>. For
* testing, use {@link TestPrismRunner}.
*/
// TODO(https://github.com/apache/beam/issues/31793): add public modifier after finalizing
// PrismRunner. Depends on: https://github.com/apache/beam/issues/31402 and
// https://github.com/apache/beam/issues/31792.
class PrismRunner extends PipelineRunner<PipelineResult> {
public class PrismRunner extends PipelineRunner<PipelineResult> {

private static final Logger LOG = LoggerFactory.getLogger(PrismRunner.class);

private static final String DEFAULT_PRISM_ENDPOINT = "localhost:8073";

private final PortableRunner internal;
private final PrismPipelineOptions prismPipelineOptions;

private PrismRunner(PortableRunner internal, PrismPipelineOptions prismPipelineOptions) {
this.internal = internal;
protected PrismRunner(PrismPipelineOptions prismPipelineOptions) {
this.prismPipelineOptions = prismPipelineOptions;
}

/** Exposes the {@link PrismPipelineOptions} this runner was constructed with. */
PrismPipelineOptions getPrismPipelineOptions() {
  return this.prismPipelineOptions;
}

/**
* Invoked from {@link Pipeline#run} where {@link PrismRunner} instantiates using {@link
* PrismPipelineOptions} configuration details.
*/
public static PrismRunner fromOptions(PipelineOptions options) {
PrismPipelineOptions prismPipelineOptions = options.as(PrismPipelineOptions.class);
validate(prismPipelineOptions);
assignDefaultsIfNeeded(prismPipelineOptions);
PortableRunner internal = PortableRunner.fromOptions(options);
return new PrismRunner(internal, prismPipelineOptions);
return new PrismRunner(prismPipelineOptions);
}

/**
 * Rejects options that already carry a job endpoint.
 *
 * @param options the options to check
 * @throws IllegalArgumentException if {@link PrismPipelineOptions#getJobEndpoint()} is neither
 *     null nor empty
 */
private static void validate(PrismPipelineOptions options) {
  String jobEndpoint = options.getJobEndpoint();
  boolean endpointUnset = Strings.isNullOrEmpty(jobEndpoint);
  checkArgument(
      endpointUnset, "when specifying --jobEndpoint, use --runner=PortableRunner instead");
}

@Override
Expand All @@ -72,15 +82,44 @@ public PipelineResult run(Pipeline pipeline) {
prismPipelineOptions.getDefaultEnvironmentType(),
prismPipelineOptions.getJobEndpoint());

return internal.run(pipeline);
try {
PrismExecutor executor = startPrism();
PortableRunner delegate = PortableRunner.fromOptions(prismPipelineOptions);
return new PrismPipelineResult(delegate.run(pipeline), executor::stop);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

/**
 * Launches a local Prism process and, once it is confirmed alive, points the job endpoint of
 * {@link #prismPipelineOptions} at it.
 *
 * <p>The Prism binary is resolved before a port is probed, so that the time between finding a
 * free port and Prism binding to it is as short as possible. The job endpoint is only written to
 * the options after the process has started, so a failed launch leaves the options unchanged.
 *
 * @return the {@link PrismExecutor} that owns the running process; the caller is responsible for
 *     calling {@code stop()} on it
 * @throws IOException if no free port can be probed; resolving or executing the binary presumably
 *     reports failures this way as well (not visible from here)
 * @throws IllegalStateException if the process is not alive right after it was launched
 */
PrismExecutor startPrism() throws IOException {
  // Resolve first: findAvailablePort() releases the port before returning, so any work done
  // between probing and launching gives other processes a chance to grab the port.
  PrismLocator locator = new PrismLocator(prismPipelineOptions);
  String command = locator.resolve();

  int port = findAvailablePort();
  String portFlag = String.format(PrismExecutor.JOB_PORT_FLAG_TEMPLATE, port);
  String serveHttpFlag =
      String.format(
          PrismExecutor.SERVE_HTTP_FLAG_TEMPLATE, prismPipelineOptions.getEnableWebUI());

  PrismExecutor executor =
      PrismExecutor.builder()
          .setCommand(command)
          .setArguments(Arrays.asList(portFlag, serveHttpFlag))
          .build();
  executor.execute();

  boolean alive = executor.isAlive();
  if (!alive) {
    // Release whatever the executor still holds before reporting the failure.
    executor.stop();
  }
  checkState(
      alive,
      "Prism is not alive after launching: %s %s %s. Check Prism logs for further details.",
      command,
      portFlag,
      serveHttpFlag);

  // Only advertise the endpoint once something is actually running behind it.
  prismPipelineOptions.setJobEndpoint("localhost:" + port);
  return executor;
}

private static void assignDefaultsIfNeeded(PrismPipelineOptions prismPipelineOptions) {
if (Strings.isNullOrEmpty(prismPipelineOptions.getDefaultEnvironmentType())) {
prismPipelineOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_LOOPBACK);
}
if (Strings.isNullOrEmpty(prismPipelineOptions.getJobEndpoint())) {
prismPipelineOptions.setJobEndpoint(DEFAULT_PRISM_ENDPOINT);
}

/**
 * Finds a currently unused local port by binding a throwaway {@link ServerSocket} to port 0 and
 * reading back the port it was given.
 *
 * <p>The socket is closed before this method returns, so the port is not reserved for the caller.
 *
 * @return the port number the probe socket was bound to
 * @throws IOException if the probe socket cannot be opened or closed
 */
private static int findAvailablePort() throws IOException {
  try (ServerSocket probe = new ServerSocket(0)) {
    int ephemeralPort = probe.getLocalPort();
    return ephemeralPort;
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import com.google.auto.service.AutoService;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.runners.PipelineRunnerRegistrar;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;

/**
 * {@link PipelineRunnerRegistrar} that contributes {@link PrismRunner} and {@link TestPrismRunner}.
 * Annotated with {@link AutoService} for {@link PipelineRunnerRegistrar}.
 */
@AutoService(PipelineRunnerRegistrar.class)
public class PrismRunnerRegistrar implements PipelineRunnerRegistrar {

  /** Returns the runner classes provided by this registrar, in declaration order. */
  @Override
  public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
    ImmutableList<Class<? extends PipelineRunner<?>>> runners =
        ImmutableList.of(PrismRunner.class, TestPrismRunner.class);
    return runners;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.testing.TestPipelineOptions;

/**
 * {@link org.apache.beam.sdk.options.PipelineOptions} for use with the {@link TestPrismRunner}.
 *
 * <p>Declares no options of its own; it combines {@link PrismPipelineOptions} with {@link
 * TestPipelineOptions}. {@link TestPrismRunner} reads {@code getTestTimeoutSeconds()} through this
 * interface to bound how long it waits for a pipeline to finish.
 */
public interface TestPrismPipelineOptions extends PrismPipelineOptions, TestPipelineOptions {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
import static org.hamcrest.MatcherAssert.assertThat;

import java.util.function.Supplier;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.options.PipelineOptions;
import org.hamcrest.Matchers;
import org.joda.time.Duration;

/**
 * A {@link PipelineRunner} for tests that execute on <a
 * href="https://github.com/apache/beam/tree/master/sdks/go/cmd/prism">sdks/go/cmd/prism</a>. It
 * delegates execution to a {@link PrismRunner}, waits for the pipeline to finish and asserts that
 * it ended in {@link PipelineResult.State#DONE}. See {@link PrismRunner} for more details.
 */
public class TestPrismRunner extends PipelineRunner<PipelineResult> {

  private final PrismRunner internal;
  private final TestPrismPipelineOptions prismPipelineOptions;

  /**
   * Factory invoked from {@link Pipeline#run}; builds the {@link TestPrismRunner} from the given
   * options viewed as {@link TestPrismPipelineOptions}.
   */
  public static TestPrismRunner fromOptions(PipelineOptions options) {
    TestPrismPipelineOptions testOptions = options.as(TestPrismPipelineOptions.class);
    PrismRunner prismRunner = PrismRunner.fromOptions(options);
    return new TestPrismRunner(prismRunner, testOptions);
  }

  private TestPrismRunner(PrismRunner internal, TestPrismPipelineOptions options) {
    this.internal = internal;
    this.prismPipelineOptions = options;
  }

  /** Returns the {@link TestPrismPipelineOptions} this runner was created with. */
  TestPrismPipelineOptions getTestPrismPipelineOptions() {
    return this.prismPipelineOptions;
  }

  /**
   * Runs the pipeline on the wrapped {@link PrismRunner}, blocks until it finishes (or the
   * configured test timeout elapses) and asserts that the resulting state is {@link
   * PipelineResult.State#DONE}.
   *
   * @return the {@link PipelineResult} produced by the wrapped runner
   */
  @Override
  public PipelineResult run(Pipeline pipeline) {
    PipelineResult pipelineResult = internal.run(pipeline);
    Supplier<PipelineResult.State> waiter = getWaitUntilFinishRunnable(pipelineResult);
    PipelineResult.State finalState = waiter.get();
    assertThat(
        "Pipeline did not succeed. Check Prism logs for further details.",
        finalState,
        Matchers.is(PipelineResult.State.DONE));
    return pipelineResult;
  }

  /**
   * Chooses how to wait on {@code result}: without a limit when no test timeout is configured,
   * otherwise for at most the configured number of seconds.
   */
  private Supplier<PipelineResult.State> getWaitUntilFinishRunnable(PipelineResult result) {
    if (prismPipelineOptions.getTestTimeoutSeconds() == null) {
      return result::waitUntilFinish;
    }
    Long timeoutSeconds = checkStateNotNull(prismPipelineOptions.getTestTimeoutSeconds());
    return () -> {
      Duration timeout = Duration.standardSeconds(timeoutSeconds);
      return result.waitUntilFinish(timeout);
    };
  }
}
Loading

0 comments on commit 02c94dd

Please sign in to comment.