Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OrderedListState support for SparkRunner #33212

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
* Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)).
* This enables initial Java GroupIntoBatches support.
* Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)).
* Support OrderedListState in SparkRunner ([#33211](https://github.com/apache/beam/issues/33211)).

## Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.GroupingState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.TimestampedValue;

/** Helpers for merging state. */
@SuppressWarnings({
Expand Down Expand Up @@ -108,6 +110,44 @@ public static <T, W extends BoundedWindow> void mergeBags(
}
}

/** Merge all ordered list state in {@code address} across all merging windows in {@code context}. */
public static <K, T, W extends BoundedWindow> void mergeOrderedLists(
    MergingStateAccessor<K, W> context, StateTag<OrderedListState<T>> address) {
  mergeOrderedLists(context.accessInEachMergingWindow(address).values(), context.access(address));
}

/**
 * Merge all ordered list state in {@code sources} into {@code result}, clearing every source
 * other than {@code result} afterwards. Sources are prefetched before being read.
 */
public static <T, W extends BoundedWindow> void mergeOrderedLists(
    Collection<OrderedListState<T>> sources, OrderedListState<T> result) {
  if (sources.isEmpty()) {
    return; // Nothing to merge.
  }
  // Collect (and prefetch) every source that is not the result itself.
  final List<ReadableState<Iterable<TimestampedValue<T>>>> pending =
      new ArrayList<>(sources.size());
  for (OrderedListState<T> source : sources) {
    if (source.equals(result)) {
      continue;
    }
    prefetchRead(source);
    pending.add(source);
  }
  if (pending.isEmpty()) {
    return; // Result already holds all the values.
  }
  // Copy the prefetched values into the result...
  for (ReadableState<Iterable<TimestampedValue<T>>> source : pending) {
    for (TimestampedValue<T> value : source.read()) {
      result.add(value);
    }
  }
  // ...then clear every source we merged from.
  for (OrderedListState<T> source : sources) {
    if (!source.equals(result)) {
      source.clear();
    }
  }
}

/** Merge all set state in {@code address} across all windows under merge. */
public static <K, T, W extends BoundedWindow> void mergeSets(
MergingStateAccessor<K, W> context, StateTag<SetState<T>> address) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.runners.core;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItems;
Expand All @@ -43,16 +44,19 @@
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.GroupingState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -76,6 +80,8 @@ public abstract class StateInternalsTest {
StateTags.bag("stringBag", StringUtf8Coder.of());
private static final StateTag<SetState<String>> STRING_SET_ADDR =
StateTags.set("stringSet", StringUtf8Coder.of());
private static final StateTag<OrderedListState<String>> STRING_ORDERED_LIST_ADDR =
StateTags.orderedList("stringOrderedList", StringUtf8Coder.of());
private static final StateTag<MapState<String, Integer>> STRING_MAP_ADDR =
StateTags.map("stringMap", StringUtf8Coder.of(), VarIntCoder.of());
private static final StateTag<WatermarkHoldState> WATERMARK_EARLIEST_ADDR =
Expand Down Expand Up @@ -187,6 +193,99 @@ public void testMergeBagIntoNewNamespace() throws Exception {
assertThat(bag2.read(), Matchers.emptyIterable());
}

@Test
public void testOrderedList() {
  final OrderedListState<String> orderedList =
      underTest.state(NAMESPACE_1, STRING_ORDERED_LIST_ADDR);

  // The same (namespace, tag) pair must resolve to the same state cell.
  assertThat(orderedList, equalTo(underTest.state(NAMESPACE_1, STRING_ORDERED_LIST_ADDR)));
  assertThat(orderedList, not(equalTo(underTest.state(NAMESPACE_2, STRING_ORDERED_LIST_ADDR))));

  assertThat(orderedList.read(), Matchers.emptyIterable());

  final Instant base = new Instant(0);
  final TimestampedValue<String> world =
      TimestampedValue.of("world", base.plus(Duration.millis(1)));
  final TimestampedValue<String> hello = TimestampedValue.of("hello", base);

  orderedList.add(world);
  assertThat(orderedList.read(), containsInAnyOrder(world));

  // Reads come back sorted by timestamp, regardless of insertion order.
  orderedList.add(hello);
  assertThat(orderedList.read(), contains(hello, world));

  final TimestampedValue<String> outside =
      TimestampedValue.of("ignore", base.plus(Duration.millis(10)));
  orderedList.add(outside);

  // readRange is inclusive of minTimestamp and exclusive of limitTimestamp.
  final Iterable<TimestampedValue<String>> range =
      orderedList.readRange(base, base.plus(Duration.millis(2L)));
  assertThat(range, contains(hello, world));
  assertThat(range, not(contains(outside)));

  orderedList.clear();
  assertThat(orderedList.read(), Matchers.emptyIterable());
  assertThat(underTest.state(NAMESPACE_1, STRING_ORDERED_LIST_ADDR), equalTo(orderedList));
}

@Test
public void testOrderedListIsEmpty() {
  final OrderedListState<String> orderedList =
      underTest.state(NAMESPACE_1, STRING_ORDERED_LIST_ADDR);

  assertThat(orderedList.isEmpty().read(), Matchers.is(true));

  // The readable handle tracks later mutations; it is not a snapshot.
  final ReadableState<Boolean> emptyFuture = orderedList.isEmpty();
  orderedList.add(TimestampedValue.of("hello", new Instant(0)));
  assertThat(emptyFuture.read(), Matchers.is(false));

  orderedList.clear();
  assertThat(emptyFuture.read(), Matchers.is(true));
}

@Test
public void testMergeOrderedListIntoSource() {
  final OrderedListState<String> list1 = underTest.state(NAMESPACE_1, STRING_ORDERED_LIST_ADDR);
  final OrderedListState<String> list2 = underTest.state(NAMESPACE_2, STRING_ORDERED_LIST_ADDR);
  final Instant base = new Instant();

  final TimestampedValue<String> hello = TimestampedValue.of("Hello", base);
  final TimestampedValue<String> world =
      TimestampedValue.of("World", base.plus(Duration.millis(1L)));
  final TimestampedValue<String> bang =
      TimestampedValue.of("!", base.plus(Duration.millis(5L)));

  list1.add(world);
  list2.add(hello);
  list1.add(bang);

  // Merge the second list into the first, in place.
  StateMerging.mergeOrderedLists(Arrays.asList(list1, list2), list1);

  // The merged result is timestamp-ordered and the other source is emptied.
  assertThat(list1.read(), contains(hello, world, bang));
  assertThat(list2.read(), Matchers.emptyIterable());
}

@Test
public void testMergeOrderedListIntoNewNamespace() {
  final OrderedListState<String> list1 = underTest.state(NAMESPACE_1, STRING_ORDERED_LIST_ADDR);
  final OrderedListState<String> list2 = underTest.state(NAMESPACE_2, STRING_ORDERED_LIST_ADDR);
  final OrderedListState<String> target = underTest.state(NAMESPACE_3, STRING_ORDERED_LIST_ADDR);
  final Instant base = new Instant();

  final TimestampedValue<String> hello = TimestampedValue.of("Hello", base);
  final TimestampedValue<String> world =
      TimestampedValue.of("World", base.plus(Duration.millis(1L)));
  final TimestampedValue<String> bang =
      TimestampedValue.of("!", base.plus(Duration.millis(5L)));

  list1.add(world);
  list2.add(hello);
  list1.add(bang);

  // Merge both sources into a third, previously-empty namespace.
  StateMerging.mergeOrderedLists(Arrays.asList(list1, list2), target);

  assertThat(target.read(), contains(hello, world, bang));
  // Both sources are cleared after the merge.
  assertThat(list1.read(), Matchers.emptyIterable());
  assertThat(list2.read(), Matchers.emptyIterable());
}

@Test
public void testSet() throws Exception {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ public int hashCode() {
private static class FlinkOrderedListState<T> implements OrderedListState<T> {
private final StateNamespace namespace;
private final String namespaceKey;
private final String stateId;
private final ListStateDescriptor<TimestampedValue<T>> flinkStateDescriptor;
private final KeyedStateBackend<ByteBuffer> flinkStateBackend;

Expand All @@ -466,6 +467,7 @@ private static class FlinkOrderedListState<T> implements OrderedListState<T> {
SerializablePipelineOptions pipelineOptions) {
this.namespace = namespace;
this.namespaceKey = namespace.stringKey();
this.stateId = stateId;
this.flinkStateBackend = flinkStateBackend;
this.flinkStateDescriptor =
new ListStateDescriptor<>(
Expand Down Expand Up @@ -571,6 +573,27 @@ public void clear() {
throw new RuntimeException("Error clearing state.", e);
}
}

@Override
twosom marked this conversation as resolved.
Show resolved Hide resolved
public boolean equals(@Nullable Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}

final FlinkOrderedListState<?> that = (FlinkOrderedListState<?>) o;

return namespace.equals(that.namespace) && stateId.equals(that.stateId);
}

@Override
public int hashCode() {
int result = namespace.hashCode();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can use Objects.hashCode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kennknowles
The current implementation follows the same hashCode semantics used in FlinkStateInternals. I'm a bit unclear whether suggesting Objects.hashCode means we should replace the current implementation, or if you're suggesting that overriding hashCode is unnecessary altogether.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overriding equality and hashcode should always occur together, so it is necessary to override hashCode.

I was just suggesting this pattern instead of doing your own math:

return Objects.hashCode(isSuccess, site, throwable);

result = 31 * result + stateId.hashCode();
return result;
}
}

private static class FlinkBagState<T> implements BagState<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,20 @@ public void testSetReadable() {}
@Override
@Ignore
public void testMapReadable() {}

// NOTE(review): disabled — OrderedListState does not appear to be exercised by this runner's
// StateInternals under test; confirm support before enabling.
@Override
@Ignore
public void testOrderedList() {}

// NOTE(review): disabled — OrderedListState does not appear to be supported here; see the
// sibling ignored OrderedList tests. TODO confirm.
@Override
@Ignore
public void testOrderedListIsEmpty() {}

// NOTE(review): disabled — merging OrderedListState is not exercised for this runner's
// StateInternals. TODO confirm.
@Override
@Ignore
public void testMergeOrderedListIntoSource() {}

// NOTE(review): disabled — merging OrderedListState is not exercised for this runner's
// StateInternals. TODO confirm.
@Override
@Ignore
public void testMergeOrderedListIntoNewNamespace() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.function.Function;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
Expand All @@ -37,6 +38,7 @@
import org.apache.beam.sdk.coders.SetCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.GroupingState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.sdk.state.OrderedListState;
Expand All @@ -53,8 +55,12 @@
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;
Expand Down Expand Up @@ -149,8 +155,7 @@ public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
/**
 * Binds ordered-list state for the current key and namespace, backed by
 * {@link SparkOrderedListState}.
 */
@Override
public <T> OrderedListState<T> bindOrderedList(
    String id, StateSpec<OrderedListState<T>> spec, Coder<T> elemCoder) {
  return new SparkOrderedListState<>(namespace, id, elemCoder);
}

@Override
Expand Down Expand Up @@ -622,4 +627,82 @@ public Boolean read() {
};
}
}

private final class SparkOrderedListState<T> extends AbstractState<List<TimestampedValue<T>>>
implements OrderedListState<T> {

private SparkOrderedListState(StateNamespace namespace, String id, Coder<T> coder) {
super(namespace, id, ListCoder.of(TimestampedValue.TimestampedValueCoder.of(coder)));
}

private SortedMap<Instant, TimestampedValue<T>> readAsMap() {
final List<TimestampedValue<T>> listValues =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for each additional kind of state is to efficiently offer a novel form of state access. The state access here has the same performance characteristics as ValueState. It is actually better for the runner to reject a pipeline than to run it with performance characteristics that don't match the expected performance contract.

Is there some underlying mechanism in Spark that could implement OrderedListState efficiently and scalably?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with your point. Let me share my thoughts on why I chose this implementation.

I've noticed that ListState/OrderedListState is mostly used in situations where writes happen much more frequently than reads. That's why I went with ArrayList instead of SortedMap - it's simply better at handling these frequent writes.

When it comes to reading data, it usually happens in just a couple of scenarios - either during OnTimer execution or when the list hits a certain size. So even if the read performance takes a small hit, it's not really going to affect the overall performance much.

It's also worth mentioning that FlinkOrderedListState uses the same approach, which gives me confidence in this design choice.

That's why I think the current implementation makes more sense for real-world usage patterns.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. If Flink is implemented then it is OK with me to follow that precedent. My point was that this does not actually add capability that is more than ValueState provides. It is just a minor API wrapper adjustment - still useful but not the main purpose.

So we can merge with this design. But if you think about following up, here is how we would really like this to behave:

  • add should call some native Spark API that writes the element without reading the list
  • readRange should only read the requested range, ideally seeking in near-constant time (aka without a scan or sort)
  • clearRange should also seek in near-constant time
  • isEmpty should not read the list

MoreObjects.firstNonNull(this.readValue(), Lists.newArrayList());
final SortedMap<Instant, TimestampedValue<T>> sortedMap = Maps.newTreeMap();
for (TimestampedValue<T> value : listValues) {
sortedMap.put(value.getTimestamp(), value);
}
return sortedMap;
}

@Override
public Iterable<TimestampedValue<T>> readRange(Instant minTimestamp, Instant limitTimestamp) {
return this.readAsMap().subMap(minTimestamp, limitTimestamp).values();
}

@Override
public void clearRange(Instant minTimestamp, Instant limitTimestamp) {
final SortedMap<Instant, TimestampedValue<T>> sortedMap = this.readAsMap();
sortedMap.subMap(minTimestamp, limitTimestamp).clear();
this.writeValue(Lists.newArrayList(sortedMap.values()));
}

@Override
public OrderedListState<T> readRangeLater(Instant minTimestamp, Instant limitTimestamp) {
return this;
}

@Override
public void add(TimestampedValue<T> value) {
final List<TimestampedValue<T>> listValue =
MoreObjects.firstNonNull(this.readValue(), Lists.newArrayList());
listValue.add(value);
this.writeValue(listValue);
}

@Override
public ReadableState<Boolean> isEmpty() {
return new ReadableState<Boolean>() {
@Override
public Boolean read() {
final List<TimestampedValue<T>> listValue = readValue();
return listValue == null || listValue.isEmpty();
}

@Override
public ReadableState<Boolean> readLater() {
return this;
}
};
}

@Override
public Iterable<TimestampedValue<T>> read() {
return this.readAsMap().values();
}

@Override
public GroupingState<TimestampedValue<T>, Iterable<TimestampedValue<T>>> readLater() {
return this;
}

@Override
public void clear() {
final List<TimestampedValue<T>> listValue = this.readValue();
if (listValue != null) {
listValue.clear();
this.writeValue(listValue);
}
}
}
}
Loading