Skip to content

Commit

Permalink
Implement Java TestPrismRunner and PrismRunner (apache#32294)
Browse files Browse the repository at this point in the history
  • Loading branch information
damondouglas authored and reeba212 committed Dec 4, 2024
1 parent 32b2d57 commit df45125
Show file tree
Hide file tree
Showing 12 changed files with 336 additions and 201 deletions.
1 change: 1 addition & 0 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies {
implementation library.java.slf4j_api
implementation library.java.vendored_grpc_1_60_1
implementation library.java.vendored_guava_32_1_2_jre
compileOnly library.java.hamcrest

testImplementation library.java.junit
testImplementation library.java.mockito_core
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
abstract class PrismExecutor {

private static final Logger LOG = LoggerFactory.getLogger(PrismExecutor.class);
static final String IDLE_SHUTDOWN_TIMEOUT = "-idle_shutdown_timeout=%s";
static final String JOB_PORT_FLAG_TEMPLATE = "-job_port=%s";
static final String SERVE_HTTP_FLAG_TEMPLATE = "-serve_http=%s";

protected @MonotonicNonNull Process process;
protected ExecutorService executorService = Executors.newSingleThreadExecutor();
Expand All @@ -71,7 +74,7 @@ void stop() {
}
executorService.shutdown();
try {
boolean ignored = executorService.awaitTermination(1000L, TimeUnit.MILLISECONDS);
boolean ignored = executorService.awaitTermination(5000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException ignored) {
}
if (process == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ private String resolve(Path from, Path to) throws IOException {
}

copyFn.accept(from.toUri().toURL().openStream(), to);
ByteStreams.copy(from.toUri().toURL().openStream(), Files.newOutputStream(to));
try (OutputStream out = Files.newOutputStream(to)) {
ByteStreams.copy(from.toUri().toURL().openStream(), out);
}
Files.setPosixFilePermissions(to, PERMS);

return to.toString();
Expand Down Expand Up @@ -159,16 +161,16 @@ private static void unzip(InputStream from, Path to) {
}

private static void copy(InputStream from, Path to) {
try {
ByteStreams.copy(from, Files.newOutputStream(to));
try (OutputStream out = Files.newOutputStream(to)) {
ByteStreams.copy(from, out);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private static void download(URL from, Path to) {
try {
ByteStreams.copy(from.openStream(), Files.newOutputStream(to));
try (OutputStream out = Files.newOutputStream(to)) {
ByteStreams.copy(from.openStream(), out);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PortablePipelineOptions;

Expand All @@ -25,6 +26,9 @@
* org.apache.beam.sdk.Pipeline} on the {@link PrismRunner}.
*/
public interface PrismPipelineOptions extends PortablePipelineOptions {

String JOB_PORT_FLAG_NAME = "job_port";

@Description(
"Path or URL to a prism binary, or zipped binary for the current "
+ "platform (Operating System and Architecture). May also be an Apache "
Expand All @@ -41,4 +45,17 @@ public interface PrismPipelineOptions extends PortablePipelineOptions {
String getPrismVersionOverride();

void setPrismVersionOverride(String prismVersionOverride);

@Description("Enable or disable Prism Web UI")
@Default.Boolean(true)
Boolean getEnableWebUI();

void setEnableWebUI(Boolean enableWebUI);

@Description(
"Duration, represented as a String, that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Valid time units are \"ns\", \"us\" (or \"µs\"), \"ms\", \"s\", \"m\", \"h\".")
@Default.String("5m")
String getIdleShutdownTimeout();

void setIdleShutdownTimeout(String idleShutdownTimeout);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.IOException;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.metrics.MetricResults;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;

/**
Expand All @@ -29,30 +28,25 @@
*/
class PrismPipelineResult implements PipelineResult {

static PrismPipelineResult of(PipelineResult delegate, PrismExecutor executor) {
return new PrismPipelineResult(delegate, executor::stop);
}

private final PipelineResult delegate;
private final Runnable cancel;
private @Nullable MetricResults terminalMetrics;
private @Nullable State terminalState;
private final Runnable cleanup;

/**
* Instantiate the {@link PipelineResult} from the {@param delegate} and a {@param cancel} to be
* called when stopping the underlying executable Job management service.
*/
PrismPipelineResult(PipelineResult delegate, Runnable cancel) {
this.delegate = delegate;
this.cancel = cancel;
this.cleanup = cancel;
}

Runnable getCleanup() {
return cleanup;
}

/** Forwards the result of the delegate {@link PipelineResult#getState}. */
@Override
public State getState() {
if (terminalState != null) {
return terminalState;
}
return delegate.getState();
}

Expand All @@ -64,9 +58,7 @@ public State getState() {
@Override
public State cancel() throws IOException {
State state = delegate.cancel();
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
this.cleanup.run();
return state;
}

Expand All @@ -78,9 +70,7 @@ public State cancel() throws IOException {
@Override
public State waitUntilFinish(Duration duration) {
State state = delegate.waitUntilFinish(duration);
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
this.cleanup.run();
return state;
}

Expand All @@ -92,18 +82,13 @@ public State waitUntilFinish(Duration duration) {
@Override
public State waitUntilFinish() {
State state = delegate.waitUntilFinish();
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
this.cleanup.run();
return state;
}

/** Forwards the result of the delegate {@link PipelineResult#metrics}. */
@Override
public MetricResults metrics() {
if (terminalMetrics != null) {
return terminalMetrics;
}
return delegate.metrics();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
*/
package org.apache.beam.runners.prism;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.IOException;
import java.net.ServerSocket;
import java.util.Arrays;
import org.apache.beam.runners.portability.PortableRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
Expand All @@ -34,34 +40,38 @@
* submit to an already running Prism service, use the {@link PortableRunner} with the {@link
* PortablePipelineOptions#getJobEndpoint()} option instead. Prism is a {@link
* org.apache.beam.runners.portability.PortableRunner} maintained at <a
* href="https://github.com/apache/beam/tree/master/sdks/go/cmd/prism">sdks/go/cmd/prism</a>.
* href="https://github.com/apache/beam/tree/master/sdks/go/cmd/prism">sdks/go/cmd/prism</a>. For
* testing, use {@link TestPrismRunner}.
*/
// TODO(https://github.com/apache/beam/issues/31793): add public modifier after finalizing
// PrismRunner. Depends on: https://github.com/apache/beam/issues/31402 and
// https://github.com/apache/beam/issues/31792.
class PrismRunner extends PipelineRunner<PipelineResult> {
public class PrismRunner extends PipelineRunner<PipelineResult> {

private static final Logger LOG = LoggerFactory.getLogger(PrismRunner.class);

private static final String DEFAULT_PRISM_ENDPOINT = "localhost:8073";

private final PortableRunner internal;
private final PrismPipelineOptions prismPipelineOptions;

private PrismRunner(PortableRunner internal, PrismPipelineOptions prismPipelineOptions) {
this.internal = internal;
protected PrismRunner(PrismPipelineOptions prismPipelineOptions) {
this.prismPipelineOptions = prismPipelineOptions;
}

PrismPipelineOptions getPrismPipelineOptions() {
return prismPipelineOptions;
}

/**
* Invoked from {@link Pipeline#run} where {@link PrismRunner} instantiates using {@link
* PrismPipelineOptions} configuration details.
*/
public static PrismRunner fromOptions(PipelineOptions options) {
PrismPipelineOptions prismPipelineOptions = options.as(PrismPipelineOptions.class);
validate(prismPipelineOptions);
assignDefaultsIfNeeded(prismPipelineOptions);
PortableRunner internal = PortableRunner.fromOptions(options);
return new PrismRunner(internal, prismPipelineOptions);
return new PrismRunner(prismPipelineOptions);
}

private static void validate(PrismPipelineOptions options) {
checkArgument(
Strings.isNullOrEmpty(options.getJobEndpoint()),
"when specifying --jobEndpoint, use --runner=PortableRunner instead");
}

@Override
Expand All @@ -72,15 +82,47 @@ public PipelineResult run(Pipeline pipeline) {
prismPipelineOptions.getDefaultEnvironmentType(),
prismPipelineOptions.getJobEndpoint());

return internal.run(pipeline);
try {
PrismExecutor executor = startPrism();
PortableRunner delegate = PortableRunner.fromOptions(prismPipelineOptions);
return new PrismPipelineResult(delegate.run(pipeline), executor::stop);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

PrismExecutor startPrism() throws IOException {
PrismLocator locator = new PrismLocator(prismPipelineOptions);
int port = findAvailablePort();
String portFlag = String.format(PrismExecutor.JOB_PORT_FLAG_TEMPLATE, port);
String serveHttpFlag =
String.format(
PrismExecutor.SERVE_HTTP_FLAG_TEMPLATE, prismPipelineOptions.getEnableWebUI());
String idleShutdownTimeoutFlag =
String.format(
PrismExecutor.IDLE_SHUTDOWN_TIMEOUT, prismPipelineOptions.getIdleShutdownTimeout());
String endpoint = "localhost:" + port;
prismPipelineOptions.setJobEndpoint(endpoint);
String command = locator.resolve();
PrismExecutor executor =
PrismExecutor.builder()
.setCommand(command)
.setArguments(Arrays.asList(portFlag, serveHttpFlag, idleShutdownTimeoutFlag))
.build();
executor.execute();
checkState(executor.isAlive());
return executor;
}

private static void assignDefaultsIfNeeded(PrismPipelineOptions prismPipelineOptions) {
if (Strings.isNullOrEmpty(prismPipelineOptions.getDefaultEnvironmentType())) {
prismPipelineOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_LOOPBACK);
}
if (Strings.isNullOrEmpty(prismPipelineOptions.getJobEndpoint())) {
prismPipelineOptions.setJobEndpoint(DEFAULT_PRISM_ENDPOINT);
}

private static int findAvailablePort() throws IOException {
try (ServerSocket socket = new ServerSocket(0)) {
return socket.getLocalPort();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import com.google.auto.service.AutoService;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.runners.PipelineRunnerRegistrar;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;

/**
* Registers {@link PrismRunner} and {@link TestPrismRunner} with {@link PipelineRunnerRegistrar}.
*/
@AutoService(PipelineRunnerRegistrar.class)
public class PrismRunnerRegistrar implements PipelineRunnerRegistrar {

@Override
public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
return ImmutableList.of(PrismRunner.class, TestPrismRunner.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.testing.TestPipelineOptions;

/** {@link org.apache.beam.sdk.options.PipelineOptions} for use with the {@link TestPrismRunner}. */
public interface TestPrismPipelineOptions extends PrismPipelineOptions, TestPipelineOptions {}
Loading

0 comments on commit df45125

Please sign in to comment.