Simplify budget distribution logic and new worker metadata consumption #32775

Merged: 4 commits, Oct 21, 2024

Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import java.io.Closeable;
import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;

@Internal
@ThreadSafe
Contributor: For the class to be thread-safe, the provided supplier needs to be thread-safe. Can we add a comment on the supplier?

// TODO (m-trieu): replace Supplier<Stream> with Stream after github.com/apache/beam/pull/32774/ is
// merged
final class GlobalDataStreamSender implements Closeable, Supplier<GetDataStream> {
private final Endpoint endpoint;
private final Supplier<GetDataStream> delegate;
private volatile boolean started;

GlobalDataStreamSender(Supplier<GetDataStream> delegate, Endpoint endpoint) {
// Ensures that the Supplier is thread-safe
this.delegate = Suppliers.memoize(delegate::get);
this.started = false;
this.endpoint = endpoint;
}

@Override
public GetDataStream get() {
if (!started) {
Contributor: This is not thread-safe; the reads and writes to started need to be synchronized.

Contributor: started is volatile, but as far as I can tell it isn't controlling anything useful at the moment. Should started be removed for now instead? Or should close() be changed to do something only if started?

Contributor Author: The change I wanted is actually here: https://github.com/apache/beam/pull/32774/files#diff-561fe80cd3d4e69975cab3d41268f5eb6cda8f583f9d1e5dfac91334efb351e0

But that depends on the other PR, so I'll just guard close() for now.

started = true;
}

return delegate.get();
}

@Override
public void close() {
if (started) {
delegate.get().shutdown();
}
}

Endpoint endpoint() {
return endpoint;
}
}
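
The review thread above turns on two points: the delegate must be created at most once under concurrent get() calls, and close() should not force creation of a stream that was never used. A minimal sketch of one way to get both with plain synchronization follows. FakeStream and LazyStreamHolder are hypothetical names for illustration, not classes from this PR, and the sketch does not claim to mirror what #32774 does.

```java
import java.io.Closeable;
import java.util.function.Supplier;

/** Hypothetical stand-in for a stream that is expensive to create. */
final class FakeStream {
  private volatile boolean shutdown;

  void shutdown() {
    shutdown = true;
  }

  boolean isShutdown() {
    return shutdown;
  }
}

/** Creates the stream at most once and shuts it down only if it was ever created. */
final class LazyStreamHolder implements Closeable, Supplier<FakeStream> {
  private final Supplier<FakeStream> factory;
  private FakeStream stream; // guarded by "this"; null until the first get()
  private boolean closed; // guarded by "this"

  LazyStreamHolder(Supplier<FakeStream> factory) {
    this.factory = factory;
  }

  @Override
  public synchronized FakeStream get() {
    if (closed) {
      throw new IllegalStateException("holder is closed");
    }
    if (stream == null) {
      // First caller pays the creation cost; later callers reuse the instance.
      stream = factory.get();
    }
    return stream;
  }

  @Override
  public synchronized void close() {
    closed = true;
    if (stream != null) {
      // Only touch the stream if get() was ever called.
      stream.shutdown();
    }
  }
}
```

Because the null check and the assignment happen under the same monitor, no separate started flag is needed: whether the stream field is set answers the same question.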
Original file line number Diff line number Diff line change
@@ -18,47 +18,37 @@
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import com.google.auto.value.AutoValue;
import java.util.function.Supplier;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

/**
* Represents the current state of connections to Streaming Engine. Connections are updated when
* backend workers assigned to the key ranges being processed by this user worker change during
* Represents the current state of connections to the Streaming Engine backend. Backends are updated
* when backend workers assigned to the key ranges being processed by this user worker change during
* pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other
* backend updates.
*/
@AutoValue
abstract class StreamingEngineConnectionState {
static final StreamingEngineConnectionState EMPTY = builder().build();
abstract class StreamingEngineBackends {
static final StreamingEngineBackends EMPTY = builder().build();

static Builder builder() {
return new AutoValue_StreamingEngineConnectionState.Builder()
.setWindmillConnections(ImmutableMap.of())
return new AutoValue_StreamingEngineBackends.Builder()
.setWindmillStreams(ImmutableMap.of())
.setGlobalDataStreams(ImmutableMap.of());
}

abstract ImmutableMap<Endpoint, WindmillConnection> windmillConnections();

abstract ImmutableMap<WindmillConnection, WindmillStreamSender> windmillStreams();
abstract ImmutableMap<Endpoint, WindmillStreamSender> windmillStreams();

/** Mapping of GlobalDataIds and the direct GetDataStreams used to fetch them. */
abstract ImmutableMap<String, Supplier<GetDataStream>> globalDataStreams();
abstract ImmutableMap<String, GlobalDataStreamSender> globalDataStreams();

@AutoValue.Builder
abstract static class Builder {
public abstract Builder setWindmillConnections(
ImmutableMap<Endpoint, WindmillConnection> value);

public abstract Builder setWindmillStreams(
ImmutableMap<WindmillConnection, WindmillStreamSender> value);
public abstract Builder setWindmillStreams(ImmutableMap<Endpoint, WindmillStreamSender> value);

public abstract Builder setGlobalDataStreams(
ImmutableMap<String, Supplier<GetDataStream>> value);
ImmutableMap<String, GlobalDataStreamSender> value);

public abstract StreamingEngineConnectionState build();
public abstract StreamingEngineBackends build();
}
}
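
With windmillStreams now keyed directly by Endpoint, applying a new endpoint set amounts to a map reconciliation: keep senders whose endpoint is still present, create senders for new endpoints, and close the rest. The sketch below shows that idea in plain Java. BackendReconciler, CountingSender, and the generic key type are hypothetical, and the harness code touched by this PR may order or batch these steps differently.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/** Hypothetical sender that only records whether it was closed. */
final class CountingSender implements AutoCloseable {
  boolean closed;

  @Override
  public void close() {
    closed = true;
  }
}

final class BackendReconciler {
  private BackendReconciler() {}

  /** Returns the senders for the wanted keys, closing senders whose key disappeared. */
  static <K> Map<K, CountingSender> reconcile(
      Map<K, CountingSender> current, Set<K> wanted, Function<K, CountingSender> factory) {
    Map<K, CountingSender> next = new HashMap<>();
    for (K key : wanted) {
      // Reuse a live sender when the key survives; otherwise create one.
      CountingSender existing = current.get(key);
      next.put(key, existing != null ? existing : factory.apply(key));
    }
    for (Map.Entry<K, CountingSender> entry : current.entrySet()) {
      if (!wanted.contains(entry.getKey())) {
        entry.getValue().close();
      }
    }
    return next;
  }
}
```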
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import java.io.Closeable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
@@ -49,7 +50,7 @@
* {@link GetWorkBudget} is set.
*
* <p>Once started, the underlying streams are "alive" until they are manually closed via {@link
* #closeAllStreams()}.
* #close()}.
*
* <p>If closed, it means that the backend endpoint is no longer in the worker set. Once closed,
* these instances are not reused.
@@ -59,7 +60,7 @@
*/
@Internal
@ThreadSafe
final class WindmillStreamSender implements GetWorkBudgetSpender {
final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable {
private final AtomicBoolean started;
private final AtomicReference<GetWorkBudget> getWorkBudget;
private final Supplier<GetWorkStream> getWorkStream;
@@ -103,9 +104,9 @@ private WindmillStreamSender(
connection,
withRequestBudget(getWorkRequest, getWorkBudget.get()),
streamingEngineThrottleTimers.getWorkThrottleTimer(),
() -> FixedStreamHeartbeatSender.create(getDataStream.get()),
() -> getDataClientFactory.apply(getDataStream.get()),
workCommitter,
FixedStreamHeartbeatSender.create(getDataStream.get()),
getDataClientFactory.apply(getDataStream.get()),
workCommitter.get(),
workItemScheduler));
}

@@ -141,7 +142,8 @@ void startStreams() {
started.set(true);
}

void closeAllStreams() {
@Override
public void close() {
// Supplier<Stream>.get() starts the stream which is an expensive operation as it initiates the
// streaming RPCs by possibly making calls over the network. Do not close the streams unless
// they have already been started.
@@ -154,18 +156,13 @@ void closeAllStreams() {
}

@Override
public void adjustBudget(long itemsDelta, long bytesDelta) {
getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta));
public void setBudget(long items, long bytes) {
getWorkBudget.set(getWorkBudget.get().apply(items, bytes));
if (started.get()) {
Contributor: Since started is set inside startStreams() after getWorkStream.get() without a mutex guarding them, started can still be false here if setBudget is called by a different thread while startStreams() is in the middle of execution. Do we need to change the synchronization in this class?

Contributor Author: That is handled in PR #32774, where I added synchronization. Here this is still the memoized Supplier.get(), which is thread-safe.

getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);
getWorkStream.get().setBudget(items, bytes);
}
}

@Override
public GetWorkBudget remainingBudget() {
return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get();
}

long getAndResetThrottleTime() {
return streamingEngineThrottleTimers.getAndResetThrottleTime();
}
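
The concern in the thread above is a budget update that lands while startStreams() is still running and is therefore never forwarded to the stream. One way to rule that out, sketched here under the assumption that coarse locking is acceptable, is to guard the started flag and the cached budget with the same monitor. RecordingStream and GuardedSender are hypothetical names; the synchronization added in #32774 is not shown in this diff and may differ.

```java
/** Hypothetical stream that records the last budget it was given. */
final class RecordingStream {
  long items;
  long bytes;

  void setBudget(long items, long bytes) {
    this.items = items;
    this.bytes = bytes;
  }
}

final class GuardedSender {
  private final RecordingStream stream = new RecordingStream();
  private long items; // guarded by "this"
  private long bytes; // guarded by "this"
  private boolean started; // guarded by "this"

  /** Pushes whatever budget was cached before the stream was live, then marks it started. */
  synchronized void start() {
    if (!started) {
      stream.setBudget(items, bytes);
      started = true;
    }
  }

  /** Caches the budget and forwards it only once the stream has been started. */
  synchronized void setBudget(long items, long bytes) {
    this.items = items;
    this.bytes = bytes;
    if (started) {
      stream.setBudget(items, bytes);
    }
  }

  RecordingStream stream() {
    return stream;
  }
}
```

Since start() and setBudget() cannot interleave, an update either arrives before start() and is flushed by it, or arrives after and is forwarded directly.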
Original file line number Diff line number Diff line change
@@ -17,8 +17,8 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;

import com.google.auto.value.AutoValue;
import java.net.Inet6Address;
@@ -27,8 +27,8 @@
import java.util.Map;
import java.util.Optional;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,6 +41,14 @@
public abstract class WindmillEndpoints {
private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class);

public static WindmillEndpoints none() {
return WindmillEndpoints.builder()
.setVersion(Long.MAX_VALUE)
Contributor: min seems safer. Otherwise, if none() were somehow observed, the logic that ensures the version is increasing means we'd never process another endpoint set.

Contributor Author: Done.

.setWindmillEndpoints(ImmutableSet.of())
.setGlobalDataEndpoints(ImmutableMap.of())
.build();
}
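
The reviewer's point about the sentinel can be made concrete. If a consumer only accepts endpoint sets whose version is strictly greater than the last one it applied, then seeding it with Long.MAX_VALUE blocks every later update, while Long.MIN_VALUE lets any real version through. VersionGate is a hypothetical class written for this illustration; the actual version check is not shown in this diff.

```java
import java.util.concurrent.atomic.AtomicLong;

/** Accepts an update only if its version is strictly newer than the last accepted one. */
final class VersionGate {
  private final AtomicLong lastVersion;

  VersionGate(long initialVersion) {
    this.lastVersion = new AtomicLong(initialVersion);
  }

  boolean tryAdvance(long newVersion) {
    while (true) {
      long seen = lastVersion.get();
      if (newVersion <= seen) {
        return false; // stale or duplicate update
      }
      if (lastVersion.compareAndSet(seen, newVersion)) {
        return true;
      }
      // Lost a race with another thread; re-read and retry.
    }
  }
}
```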

public static WindmillEndpoints from(
Windmill.WorkerMetadataResponse workerMetadataResponseProto) {
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers =
@@ -53,14 +61,15 @@ public static WindmillEndpoints from(
endpoint.getValue(),
workerMetadataResponseProto.getExternalEndpoint())));

ImmutableList<WindmillEndpoints.Endpoint> windmillServers =
ImmutableSet<WindmillEndpoints.Endpoint> windmillServers =
workerMetadataResponseProto.getWorkEndpointsList().stream()
.map(
endpointProto ->
Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint()))
.collect(toImmutableList());
.collect(toImmutableSet());

return WindmillEndpoints.builder()
.setVersion(workerMetadataResponseProto.getMetadataVersion())
.setGlobalDataEndpoints(globalDataServers)
.setWindmillEndpoints(windmillServers)
.build();
@@ -123,6 +132,9 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
directEndpointAddress.getHostAddress(), (int) endpointProto.getPort()));
}

/** Version of the endpoints which increases with every modification. */
public abstract long version();

/**
* Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key
* is a global data tag and the value is the endpoint where the data associated with the global
@@ -138,7 +150,7 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
* Windmill servers. Returns a list of endpoints used to communicate with the corresponding
* Windmill servers.
*/
public abstract ImmutableList<Endpoint> windmillEndpoints();
public abstract ImmutableSet<Endpoint> windmillEndpoints();

/**
* Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with
@@ -204,13 +216,15 @@ public abstract static class Builder {

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setVersion(long version);

public abstract Builder setGlobalDataEndpoints(
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers);

public abstract Builder setWindmillEndpoints(
ImmutableList<WindmillEndpoints.Endpoint> windmillServers);
ImmutableSet<WindmillEndpoints.Endpoint> windmillServers);

abstract ImmutableList.Builder<WindmillEndpoints.Endpoint> windmillEndpointsBuilder();
abstract ImmutableSet.Builder<WindmillEndpoints.Endpoint> windmillEndpointsBuilder();

public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) {
windmillEndpointsBuilder().add(endpoint);
Original file line number Diff line number Diff line change
@@ -19,38 +19,36 @@

import com.google.auto.value.AutoOneOf;
import com.google.auto.value.AutoValue;
import java.net.Inet6Address;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;

/** Used to create channels to communicate with Streaming Engine via gRpc. */
@AutoOneOf(WindmillServiceAddress.Kind.class)
public abstract class WindmillServiceAddress {
public static WindmillServiceAddress create(Inet6Address ipv6Address) {
return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address);
}

public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) {
return AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress);
}

public abstract Kind getKind();
public static WindmillServiceAddress create(
AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) {
return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress(
authenticatedGcpServiceAddress);
}

public abstract Inet6Address ipv6();
public abstract Kind getKind();

public abstract HostAndPort gcpServiceAddress();

public abstract AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress();

public static WindmillServiceAddress create(
AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) {
return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress(
authenticatedGcpServiceAddress);
public final HostAndPort getServiceAddress() {
return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS
? gcpServiceAddress()
: authenticatedGcpServiceAddress().gcpServiceAddress();
}

public enum Kind {
IPV6,
GCP_SERVICE_ADDRESS,
// TODO(m-trieu): Use for direct connections when ALTS is enabled.
AUTHENTICATED_GCP_SERVICE_ADDRESS
}
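
getServiceAddress() relies on the one-of pattern: exactly one alternative is populated, and the code checks getKind() before touching an accessor. The hand-written sketch below mimics that shape without the AutoOneOf annotation processor. ServiceAddressSketch and its String-based addresses are hypothetical simplifications, and throwing UnsupportedOperationException from the wrong accessor is an assumption about how the generated class behaves.

```java
/** Hand-rolled one-of: holds either a plain or an authenticated service address. */
final class ServiceAddressSketch {
  enum Kind {
    GCP_SERVICE_ADDRESS,
    AUTHENTICATED_GCP_SERVICE_ADDRESS
  }

  private final Kind kind;
  private final String hostAndPort;

  private ServiceAddressSketch(Kind kind, String hostAndPort) {
    this.kind = kind;
    this.hostAndPort = hostAndPort;
  }

  static ServiceAddressSketch plain(String hostAndPort) {
    return new ServiceAddressSketch(Kind.GCP_SERVICE_ADDRESS, hostAndPort);
  }

  static ServiceAddressSketch authenticated(String hostAndPort) {
    return new ServiceAddressSketch(Kind.AUTHENTICATED_GCP_SERVICE_ADDRESS, hostAndPort);
  }

  Kind getKind() {
    return kind;
  }

  String gcpServiceAddress() {
    requireKind(Kind.GCP_SERVICE_ADDRESS);
    return hostAndPort;
  }

  String authenticatedGcpServiceAddress() {
    requireKind(Kind.AUTHENTICATED_GCP_SERVICE_ADDRESS);
    return hostAndPort;
  }

  /** Kind-agnostic accessor: picks the populated alternative so callers need not branch. */
  String getServiceAddress() {
    return kind == Kind.GCP_SERVICE_ADDRESS
        ? gcpServiceAddress()
        : authenticatedGcpServiceAddress();
  }

  private void requireKind(Kind expected) {
    if (kind != expected) {
      throw new UnsupportedOperationException("kind is " + kind);
    }
  }
}
```

The two-way conditional is only safe while there are exactly two kinds; adding a third alternative would send it down the authenticated branch and fail at the accessor.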

Original file line number Diff line number Diff line change
@@ -56,10 +56,11 @@ public interface WindmillStream {
@ThreadSafe
interface GetWorkStream extends WindmillStream {
/** Adjusts the {@link GetWorkBudget} for the stream. */
void adjustBudget(long itemsDelta, long bytesDelta);
void setBudget(GetWorkBudget newBudget);

/** Returns the remaining in-flight {@link GetWorkBudget}. */
GetWorkBudget remainingBudget();
default void setBudget(long newItems, long newBytes) {
setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build());
}
}
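
Replacing adjustBudget with a delta by setBudget with an absolute value makes budget updates idempotent: repeating a call leaves the stream in the same state instead of compounding. That in turn lets a distributor simply recompute each stream's share and overwrite it. The sketch below splits a total evenly, rounding up. EvenBudgetSplitter, Spender, and RecordingSpender are hypothetical, and the rounding rule is an assumption rather than the logic in this PR.

```java
import java.util.List;

/** Hypothetical minimal view of something that accepts an absolute budget. */
interface Spender {
  void setBudget(long items, long bytes);
}

/** Hypothetical spender that records the last budget it was given. */
final class RecordingSpender implements Spender {
  long items;
  long bytes;

  @Override
  public void setBudget(long items, long bytes) {
    this.items = items;
    this.bytes = bytes;
  }
}

final class EvenBudgetSplitter {
  private EvenBudgetSplitter() {}

  /** Gives every spender the same absolute share of the total, rounded up. */
  static void distribute(List<? extends Spender> spenders, long totalItems, long totalBytes) {
    if (spenders.isEmpty()) {
      return;
    }
    long itemsEach = ceilDiv(totalItems, spenders.size());
    long bytesEach = ceilDiv(totalBytes, spenders.size());
    for (Spender spender : spenders) {
      spender.setBudget(itemsEach, bytesEach);
    }
  }

  /** Ceiling division for a non-negative total and positive n, written to avoid overflow. */
  static long ceilDiv(long total, long n) {
    return total / n + (total % n == 0 ? 0 : 1);
  }
}
```

Calling distribute twice with the same totals leaves every spender unchanged, which would not hold if the per-stream call applied deltas.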

/** Interface for streaming GetDataRequests to Windmill. */