Skip to content

Commit

Permalink
Allows customization in the Flink State binding
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Liu committed Feb 28, 2024
1 parent 0145270 commit 2a3711d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.apache.beam.runners.flink.translation.utils.Workarounds;
import org.apache.beam.runners.flink.translation.wrappers.streaming.stableinput.BufferingDoFnRunner;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateBinders;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StructuredCoder;
Expand Down Expand Up @@ -552,7 +553,7 @@ private void earlyBindStateIfNeeded() throws IllegalArgumentException, IllegalAc
if (doFn != null) {
DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
FlinkStateInternals.EarlyBinder earlyBinder =
new FlinkStateInternals.EarlyBinder(getKeyedStateBackend(), serializedOptions);
FlinkStateBinders.getEarlyBinder(getKeyedStateBackend(), serializedOptions, stepName);
for (DoFnSignature.StateDeclaration value : signature.stateDeclarations().values()) {
StateSpec<?> spec =
(StateSpec<?>) signature.stateDeclarations().get(value.id()).field().get(doFn);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.flink.translation.wrappers.streaming.state;

import java.util.ServiceLoader;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.flink.runtime.state.KeyedStateBackend;

/**
 * LinkedIn-only: allow custom configuration of {@link
 * org.apache.flink.api.common.state.StateDescriptor} during the Beam state binding.
 *
 * <p>A custom {@link Registrar} is looked up once, at class-initialization time, through {@link
 * ServiceLoader}. When no implementation is registered, the stock {@link
 * FlinkStateInternals.EarlyBinder} is used.
 */
@SuppressWarnings({"rawtypes", "nullness"})
public class FlinkStateBinders {

  /**
   * Extension point for supplying a custom {@link org.apache.beam.sdk.state.StateBinder}.
   *
   * <p>Implementations are discovered with {@link ServiceLoader#load(Class)}.
   */
  public interface Registrar {
    /**
     * Creates the early binder to use for one step.
     *
     * @param keyedStateBackend the Flink keyed state backend the states are created in
     * @param pipelineOptions the serialized pipeline options
     * @param stepName the name of the step whose state is being bound
     * @return the binder used to eagerly register the user states
     */
    FlinkStateInternals.EarlyBinder getEarlyBinder(
        KeyedStateBackend keyedStateBackend,
        SerializablePipelineOptions pipelineOptions,
        String stepName);
  }

  /**
   * The single registered {@link Registrar}, or {@code null} when none is registered.
   *
   * <p>Per Guava's {@code Iterables.getOnlyElement(iterable, defaultValue)} contract, more than
   * one registered implementation makes this initializer throw {@link IllegalArgumentException}.
   */
  private static final Registrar CUSTOM_REGISTRAR =
      Iterables.getOnlyElement(ServiceLoader.load(Registrar.class), null);

  /**
   * Returns the {@link FlinkStateInternals.EarlyBinder} that creates the user states from the
   * Flink state backend.
   *
   * @param keyedStateBackend the Flink keyed state backend the states are created in
   * @param pipelineOptions the serialized pipeline options
   * @param stepName the name of the step; only forwarded to a custom {@link Registrar}, the
   *     default binder does not receive it
   * @return the binder from the registered {@link Registrar} if there is one, otherwise a new
   *     default {@link FlinkStateInternals.EarlyBinder}
   */
  public static FlinkStateInternals.EarlyBinder getEarlyBinder(
      KeyedStateBackend keyedStateBackend,
      SerializablePipelineOptions pipelineOptions,
      String stepName) {
    // No customization registered: fall back to the built-in binder.
    if (CUSTOM_REGISTRAR == null) {
      return new FlinkStateInternals.EarlyBinder(keyedStateBackend, pipelineOptions);
    }
    return CUSTOM_REGISTRAR.getEarlyBinder(keyedStateBackend, pipelineOptions, stepName);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1581,7 +1581,7 @@ public EarlyBinder(
@Override
public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
StringSerializer.INSTANCE,
new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder, pipelineOptions)));
} catch (Exception e) {
Expand All @@ -1594,7 +1594,7 @@ public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Cod
@Override
public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
StringSerializer.INSTANCE,
new ListStateDescriptor<>(id, new CoderTypeSerializer<>(elemCoder, pipelineOptions)));
} catch (Exception e) {
Expand All @@ -1607,7 +1607,7 @@ public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T>
@Override
public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
StringSerializer.INSTANCE,
new MapStateDescriptor<>(
id,
Expand All @@ -1626,7 +1626,7 @@ public <KeyT, ValueT> org.apache.beam.sdk.state.MapState<KeyT, ValueT> bindMap(
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
StringSerializer.INSTANCE,
new MapStateDescriptor<>(
id,
Expand All @@ -1642,7 +1642,7 @@ public <KeyT, ValueT> org.apache.beam.sdk.state.MapState<KeyT, ValueT> bindMap(
public <T> OrderedListState<T> bindOrderedList(
String id, StateSpec<OrderedListState<T>> spec, Coder<T> elemCoder) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
StringSerializer.INSTANCE,
new ListStateDescriptor<>(
id,
Expand Down Expand Up @@ -1671,7 +1671,7 @@ public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCom
Coder<AccumT> accumCoder,
Combine.CombineFn<InputT, AccumT, OutputT> combineFn) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
StringSerializer.INSTANCE,
new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, pipelineOptions)));
} catch (Exception e) {
Expand All @@ -1688,7 +1688,7 @@ CombiningState<InputT, AccumT, OutputT> bindCombiningWithContext(
Coder<AccumT> accumCoder,
CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
StringSerializer.INSTANCE,
new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(accumCoder, pipelineOptions)));
} catch (Exception e) {
Expand All @@ -1701,7 +1701,7 @@ CombiningState<InputT, AccumT, OutputT> bindCombiningWithContext(
public WatermarkHoldState bindWatermark(
String id, StateSpec<WatermarkHoldState> spec, TimestampCombiner timestampCombiner) {
try {
keyedStateBackend.getOrCreateKeyedState(
getOrCreateKeyedState(
VoidNamespaceSerializer.INSTANCE,
new MapStateDescriptor<>(
"watermark-holds",
Expand All @@ -1712,5 +1712,11 @@ public WatermarkHoldState bindWatermark(
}
return null;
}

/**
 * Obtains the keyed state for the given descriptor by delegating to {@code
 * keyedStateBackend.getOrCreateKeyedState}.
 *
 * <p>NOTE(review): this is {@code protected}, presumably so a subclass can adjust the {@code
 * stateDescriptor} before the state is created; confirm against the subclasses.
 *
 * @param namespaceSerializer serializer for the state namespace
 * @param stateDescriptor descriptor identifying the state to look up or create
 * @return the state returned by the backend, cast to {@code StateT}
 * @throws Exception if the backend call throws
 */
protected <NamespaceT, StateT extends org.apache.flink.api.common.state.State, T>
    StateT getOrCreateKeyedState(
        TypeSerializer<NamespaceT> namespaceSerializer,
        StateDescriptor<StateT, T> stateDescriptor)
        throws Exception {
  final StateT boundState =
      (StateT) keyedStateBackend.getOrCreateKeyedState(namespaceSerializer, stateDescriptor);
  return boundState;
}
}
}

0 comments on commit 2a3711d

Please sign in to comment.