Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Java TestPrismRunner and PrismRunner #32294

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions runners/prism/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies {
implementation library.java.slf4j_api
implementation library.java.vendored_grpc_1_60_1
implementation library.java.vendored_guava_32_1_2_jre
compileOnly library.java.hamcrest

testImplementation library.java.junit
testImplementation library.java.mockito_core
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
abstract class PrismExecutor {

private static final Logger LOG = LoggerFactory.getLogger(PrismExecutor.class);
static final String IDLE_SHUTDOWN_TIMEOUT = "-idle_shutdown_timeout=%s";
static final String JOB_PORT_FLAG_TEMPLATE = "-job_port=%s";
static final String SERVE_HTTP_FLAG_TEMPLATE = "-serve_http=%s";

protected @MonotonicNonNull Process process;
protected ExecutorService executorService = Executors.newSingleThreadExecutor();
Expand All @@ -71,7 +74,7 @@ void stop() {
}
executorService.shutdown();
try {
boolean ignored = executorService.awaitTermination(1000L, TimeUnit.MILLISECONDS);
boolean ignored = executorService.awaitTermination(5000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException ignored) {
}
if (process == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ private String resolve(Path from, Path to) throws IOException {
}

copyFn.accept(from.toUri().toURL().openStream(), to);
ByteStreams.copy(from.toUri().toURL().openStream(), Files.newOutputStream(to));
try (OutputStream out = Files.newOutputStream(to)) {
ByteStreams.copy(from.toUri().toURL().openStream(), out);
}
Files.setPosixFilePermissions(to, PERMS);

return to.toString();
Expand Down Expand Up @@ -159,16 +161,16 @@ private static void unzip(InputStream from, Path to) {
}

private static void copy(InputStream from, Path to) {
try {
ByteStreams.copy(from, Files.newOutputStream(to));
try (OutputStream out = Files.newOutputStream(to)) {
ByteStreams.copy(from, out);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private static void download(URL from, Path to) {
try {
ByteStreams.copy(from.openStream(), Files.newOutputStream(to));
try (OutputStream out = Files.newOutputStream(to)) {
ByteStreams.copy(from.openStream(), out);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PortablePipelineOptions;

Expand All @@ -25,6 +26,9 @@
* org.apache.beam.sdk.Pipeline} on the {@link PrismRunner}.
*/
public interface PrismPipelineOptions extends PortablePipelineOptions {

String JOB_PORT_FLAG_NAME = "job_port";

@Description(
"Path or URL to a prism binary, or zipped binary for the current "
+ "platform (Operating System and Architecture). May also be an Apache "
Expand All @@ -41,4 +45,17 @@ public interface PrismPipelineOptions extends PortablePipelineOptions {
String getPrismVersionOverride();

void setPrismVersionOverride(String prismVersionOverride);

@Description("Enable or disable Prism Web UI")
@Default.Boolean(true)
Boolean getEnableWebUI();

void setEnableWebUI(Boolean enableWebUI);

@Description(
"Duration, represented as a String, that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Valid time units are \"ns\", \"us\" (or \"µs\"), \"ms\", \"s\", \"m\", \"h\".")
@Default.String("5m")
String getIdleShutdownTimeout();

void setIdleShutdownTimeout(String idleShutdownTimeout);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.IOException;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.metrics.MetricResults;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;

/**
Expand All @@ -29,30 +28,25 @@
*/
class PrismPipelineResult implements PipelineResult {

static PrismPipelineResult of(PipelineResult delegate, PrismExecutor executor) {
return new PrismPipelineResult(delegate, executor::stop);
}

private final PipelineResult delegate;
private final Runnable cancel;
private @Nullable MetricResults terminalMetrics;
private @Nullable State terminalState;
private final Runnable cleanup;

/**
* Instantiate the {@link PipelineResult} from the {@code delegate} and a {@code cancel} to be
* called when stopping the underlying executable Job management service.
*/
PrismPipelineResult(PipelineResult delegate, Runnable cancel) {
this.delegate = delegate;
this.cancel = cancel;
this.cleanup = cancel;
}

/**
 * Returns the {@link Runnable} stored in the {@code cleanup} field at construction time.
 *
 * <p>This is the same callback that {@link #cancel()} and both {@code waitUntilFinish} variants
 * run after their call to the delegate returns.
 */
Runnable getCleanup() {
return cleanup;
}

/** Forwards the result of the delegate {@link PipelineResult#getState}. */
@Override
public State getState() {
if (terminalState != null) {
return terminalState;
}
return delegate.getState();
}

Expand All @@ -64,9 +58,7 @@ public State getState() {
@Override
public State cancel() throws IOException {
State state = delegate.cancel();
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
this.cleanup.run();
return state;
}

Expand All @@ -78,9 +70,7 @@ public State cancel() throws IOException {
@Override
public State waitUntilFinish(Duration duration) {
State state = delegate.waitUntilFinish(duration);
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
this.cleanup.run();
return state;
}

Expand All @@ -92,18 +82,13 @@ public State waitUntilFinish(Duration duration) {
@Override
public State waitUntilFinish() {
State state = delegate.waitUntilFinish();
this.terminalMetrics = delegate.metrics();
this.terminalState = state;
this.cancel.run();
this.cleanup.run();
return state;
}

/** Forwards the result of the delegate {@link PipelineResult#metrics}. */
@Override
public MetricResults metrics() {
if (terminalMetrics != null) {
return terminalMetrics;
}
return delegate.metrics();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
*/
package org.apache.beam.runners.prism;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.IOException;
import java.net.ServerSocket;
import java.util.Arrays;
import org.apache.beam.runners.portability.PortableRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
Expand All @@ -34,34 +40,38 @@
* submit to an already running Prism service, use the {@link PortableRunner} with the {@link
* PortablePipelineOptions#getJobEndpoint()} option instead. Prism is a {@link
* org.apache.beam.runners.portability.PortableRunner} maintained at <a
* href="https://github.com/apache/beam/tree/master/sdks/go/cmd/prism">sdks/go/cmd/prism</a>.
* href="https://github.com/apache/beam/tree/master/sdks/go/cmd/prism">sdks/go/cmd/prism</a>. For
* testing, use {@link TestPrismRunner}.
*/
// TODO(https://github.com/apache/beam/issues/31793): add public modifier after finalizing
// PrismRunner. Depends on: https://github.com/apache/beam/issues/31402 and
// https://github.com/apache/beam/issues/31792.
class PrismRunner extends PipelineRunner<PipelineResult> {
public class PrismRunner extends PipelineRunner<PipelineResult> {

private static final Logger LOG = LoggerFactory.getLogger(PrismRunner.class);

private static final String DEFAULT_PRISM_ENDPOINT = "localhost:8073";

private final PortableRunner internal;
private final PrismPipelineOptions prismPipelineOptions;

private PrismRunner(PortableRunner internal, PrismPipelineOptions prismPipelineOptions) {
this.internal = internal;
protected PrismRunner(PrismPipelineOptions prismPipelineOptions) {
this.prismPipelineOptions = prismPipelineOptions;
}

/**
 * Returns the {@link PrismPipelineOptions} instance this runner was constructed with.
 *
 * <p>The returned object is the runner's own reference, not a copy; {@code startPrism} sets the
 * job endpoint on this same instance.
 */
PrismPipelineOptions getPrismPipelineOptions() {
return prismPipelineOptions;
}

/**
* Invoked from {@link Pipeline#run} where {@link PrismRunner} instantiates using {@link
* PrismPipelineOptions} configuration details.
*/
public static PrismRunner fromOptions(PipelineOptions options) {
PrismPipelineOptions prismPipelineOptions = options.as(PrismPipelineOptions.class);
validate(prismPipelineOptions);
assignDefaultsIfNeeded(prismPipelineOptions);
PortableRunner internal = PortableRunner.fromOptions(options);
return new PrismRunner(internal, prismPipelineOptions);
return new PrismRunner(prismPipelineOptions);
}

private static void validate(PrismPipelineOptions options) {
checkArgument(
Strings.isNullOrEmpty(options.getJobEndpoint()),
"when specifying --jobEndpoint, use --runner=PortableRunner instead");
lostluck marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
Expand All @@ -72,15 +82,47 @@ public PipelineResult run(Pipeline pipeline) {
prismPipelineOptions.getDefaultEnvironmentType(),
prismPipelineOptions.getJobEndpoint());

return internal.run(pipeline);
try {
PrismExecutor executor = startPrism();
PortableRunner delegate = PortableRunner.fromOptions(prismPipelineOptions);
return new PrismPipelineResult(delegate.run(pipeline), executor::stop);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

PrismExecutor startPrism() throws IOException {
PrismLocator locator = new PrismLocator(prismPipelineOptions);
int port = findAvailablePort();
String portFlag = String.format(PrismExecutor.JOB_PORT_FLAG_TEMPLATE, port);
String serveHttpFlag =
String.format(
PrismExecutor.SERVE_HTTP_FLAG_TEMPLATE, prismPipelineOptions.getEnableWebUI());
lostluck marked this conversation as resolved.
Show resolved Hide resolved
String idleShutdownTimeoutFlag =
String.format(
PrismExecutor.IDLE_SHUTDOWN_TIMEOUT, prismPipelineOptions.getIdleShutdownTimeout());
String endpoint = "localhost:" + port;
prismPipelineOptions.setJobEndpoint(endpoint);
String command = locator.resolve();
PrismExecutor executor =
PrismExecutor.builder()
.setCommand(command)
.setArguments(Arrays.asList(portFlag, serveHttpFlag, idleShutdownTimeoutFlag))
.build();
executor.execute();
checkState(executor.isAlive());
return executor;
}

private static void assignDefaultsIfNeeded(PrismPipelineOptions prismPipelineOptions) {
if (Strings.isNullOrEmpty(prismPipelineOptions.getDefaultEnvironmentType())) {
prismPipelineOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_LOOPBACK);
}
if (Strings.isNullOrEmpty(prismPipelineOptions.getJobEndpoint())) {
prismPipelineOptions.setJobEndpoint(DEFAULT_PRISM_ENDPOINT);
}

/**
 * Asks the operating system for a currently unused TCP port.
 *
 * <p>Binds a temporary {@link ServerSocket} to port {@code 0}, reads back the port that was
 * actually assigned, and closes the socket before returning. The port is therefore not reserved
 * once this method returns.
 *
 * @return the local port number the temporary socket was bound to
 * @throws IOException if the socket cannot be opened or closed
 */
private static int findAvailablePort() throws IOException {
  try (ServerSocket ephemeral = new ServerSocket(0)) {
    int assignedPort = ephemeral.getLocalPort();
    return assignedPort;
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import com.google.auto.service.AutoService;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.runners.PipelineRunnerRegistrar;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;

/**
 * A {@link PipelineRunnerRegistrar} that contributes the {@link PrismRunner} and the {@link
 * TestPrismRunner}. The class is annotated with {@link AutoService} for the {@link
 * PipelineRunnerRegistrar} service type.
 */
@AutoService(PipelineRunnerRegistrar.class)
public class PrismRunnerRegistrar implements PipelineRunnerRegistrar {

  /**
   * Lists the runner classes provided by this registrar.
   *
   * @return an immutable list holding {@link PrismRunner} followed by {@link TestPrismRunner}
   */
  @Override
  public Iterable<Class<? extends PipelineRunner<?>>> getPipelineRunners() {
    ImmutableList<Class<? extends PipelineRunner<?>>> prismRunners =
        ImmutableList.of(PrismRunner.class, TestPrismRunner.class);
    return prismRunners;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import org.apache.beam.sdk.testing.TestPipelineOptions;

/**
 * {@link org.apache.beam.sdk.options.PipelineOptions} for use with the {@link TestPrismRunner}.
 *
 * <p>Declares no options of its own; it only combines {@link PrismPipelineOptions} with {@link
 * TestPipelineOptions} into a single options type.
 */
public interface TestPrismPipelineOptions extends PrismPipelineOptions, TestPipelineOptions {}
Loading
Loading