Skip to content

Commit

Permalink
Change WindmillStateReader to not batch OrderedListFetches for the sa…
Browse files Browse the repository at this point in the history
…me family and tag. Fix issue with MultimapState delayed fetches due to batching.
  • Loading branch information
scwhittle committed Oct 31, 2023
1 parent 8b31859 commit 97354ee
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -317,47 +317,88 @@ private void delayUnbatchableMultimapFetches(
}
}

public void startBatchAndBlock() {
// First, drain work out of the pending lookups into a set. These will be the items we fetch.
private void delayUnbatchableOrderedListFetches(
List<StateTag<?>> orderedListTags, HashSet<StateTag<?>> toFetch) {
// Each KeyedGetDataRequest can have at most 1 TagOrderedListRequest per <tag, state_family>
// pair, thus we need to delay unbatchable ordered list requests of the same stateFamily and tag
// into later batches.

Map<String, Map<ByteString, List<StateTag<?>>>> groupedTags =
orderedListTags.stream()
.collect(
Collectors.groupingBy(
StateTag::getStateFamily, Collectors.groupingBy(StateTag::getTag)));

for (Map<ByteString, List<StateTag<?>>> familyTags : groupedTags.values()) {
for (List<StateTag<?>> tags : familyTags.values()) {
StateTag<?> first = tags.remove(0);
toFetch.add(first);
// Add the rest of the reads for the state family and tags back to pending.
pendingLookups.addAll(tags);
}
}
}

private HashSet<StateTag<?>> buildFetchSet() {
HashSet<StateTag<?>> toFetch = Sets.newHashSet();
try {
List<StateTag<?>> multimapTags = Lists.newArrayList();
while (!pendingLookups.isEmpty()) {
StateTag<?> stateTag = pendingLookups.poll();
if (stateTag == null) {
break;
}
if (stateTag.getKind() == Kind.MULTIMAP_ALL
|| stateTag.getKind() == Kind.MULTIMAP_SINGLE_ENTRY) {
multimapTags.add(stateTag);
continue;
}
if (!toFetch.add(stateTag)) {
throw new IllegalStateException("Duplicate tags being fetched.");
}
List<StateTag<?>> multimapTags = Lists.newArrayList();
List<StateTag<?>> orderedListTags = Lists.newArrayList();
while (!pendingLookups.isEmpty()) {
StateTag<?> stateTag = pendingLookups.poll();
if (stateTag == null) {
break;
}
if (!multimapTags.isEmpty()) {
delayUnbatchableMultimapFetches(multimapTags, toFetch);
if (stateTag.getKind() == Kind.MULTIMAP_ALL
|| stateTag.getKind() == Kind.MULTIMAP_SINGLE_ENTRY) {
multimapTags.add(stateTag);
continue;
}
if (stateTag.getKind() == Kind.ORDERED_LIST) {
orderedListTags.add(stateTag);
continue;
}

if (!toFetch.add(stateTag)) {
throw new IllegalStateException("Duplicate tags being fetched.");
}
}
if (!multimapTags.isEmpty()) {
delayUnbatchableMultimapFetches(multimapTags, toFetch);
}
if (!orderedListTags.isEmpty()) {
delayUnbatchableOrderedListFetches(orderedListTags, toFetch);
}
return toFetch;
}

// If we failed to drain anything, some other thread pulled it off the queue. We have no work
// to do.
public void performReads() {
while (true) {
HashSet<StateTag<?>> toFetch = buildFetchSet();
if (toFetch.isEmpty()) {
return;
}

KeyedGetDataResponse response = tryGetDataFromWindmill(toFetch);

// Removes tags from toFetch as they are processed.
consumeResponse(response, toFetch);
} catch (Exception e) {
// Set up all the remaining futures for this key to throw an exception. This ensures that if
// the exception is caught that all futures have been completed and do not block.
for (StateTag<?> stateTag : toFetch) {
waiting.get(stateTag).future.setException(e);
try {
KeyedGetDataResponse response = tryGetDataFromWindmill(toFetch);
// Removes tags from toFetch as they are processed.
consumeResponse(response, toFetch);
if (!toFetch.isEmpty()) {
throw new IllegalStateException(
"Didn't receive responses for all pending fetches. Missing: " + toFetch);
}
} catch (Exception e) {
// Set up all the remaining futures for this key to throw an exception. This ensures that if
// the exception is caught that all futures have been completed and do not block.
for (StateTag<?> stateTag : toFetch) {
waiting.get(stateTag).future.setException(e);
}
// Also setup futures that may have been added back if they were not batched.
while (true) {
@Nullable StateTag<?> stateTag = pendingLookups.poll();
if (stateTag == null) break;
waiting.get(stateTag).future.setException(e);
}
throw new RuntimeException(e);
}

throw new RuntimeException(e);
}
}

Expand Down Expand Up @@ -643,11 +684,6 @@ private void consumeResponse(KeyedGetDataResponse response, Set<StateTag<?>> toF
consumeMultimapSingleEntry(entry, entryTag);
}
}

if (!toFetch.isEmpty()) {
throw new IllegalStateException(
"Didn't receive responses for all pending fetches. Missing: " + toFetch);
}
}

/** The deserialized values in {@code bag} as a read-only array list. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public WrappedFuture(WindmillStateReader reader, Future<T> delegate) {
public T get() throws InterruptedException, ExecutionException {
if (!delegate().isDone() && reader != null) {
// Only one thread per reader, so no race here.
reader.startBatchAndBlock();
reader.performReads();
}
reader = null;
return super.get();
Expand All @@ -56,7 +56,7 @@ public T get(long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
if (!delegate().isDone() && reader != null) {
// Only one thread per reader, so no race here.
reader.startBatchAndBlock();
reader.performReads();
}
reader = null;
return super.get(timeout, unit);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ public class WindmillStateReaderTest {
private static final ByteString STATE_KEY_2 = ByteString.copyFromUtf8("key2");
private static final String STATE_FAMILY = "family";

private static final String STATE_FAMILY2 = "family2";

private static void assertNoReader(Object obj) throws Exception {
WindmillStateTestUtils.assertNoReference(obj, WindmillStateReader.class);
}
Expand Down Expand Up @@ -993,15 +995,19 @@ public void testReadSortedList() throws Exception {
public void testReadSortedListRanges() throws Exception {
Future<Iterable<TimestampedValue<Integer>>> future1 =
underTest.orderedListFuture(Range.closedOpen(0L, 5L), STATE_KEY_1, STATE_FAMILY, INT_CODER);
// Should be put into a subsequent batch as it has the same key and state family.
Future<Iterable<TimestampedValue<Integer>>> future2 =
underTest.orderedListFuture(Range.closedOpen(5L, 6L), STATE_KEY_1, STATE_FAMILY, INT_CODER);
Future<Iterable<TimestampedValue<Integer>>> future3 =
underTest.orderedListFuture(
Range.closedOpen(6L, 10L), STATE_KEY_1, STATE_FAMILY, INT_CODER);
Range.closedOpen(6L, 10L), STATE_KEY_2, STATE_FAMILY, INT_CODER);
Future<Iterable<TimestampedValue<Integer>>> future4 =
underTest.orderedListFuture(
Range.closedOpen(11L, 12L), STATE_KEY_2, STATE_FAMILY2, INT_CODER);
Mockito.verifyNoMoreInteractions(mockWindmill);

// Fetch the entire list.
Windmill.KeyedGetDataRequest.Builder expectedRequest =
Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
Windmill.KeyedGetDataRequest.newBuilder()
.setKey(DATA_KEY)
.setShardingKey(SHARDING_KEY)
Expand All @@ -1015,18 +1021,31 @@ public void testReadSortedListRanges() throws Exception {
.setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES))
.addSortedListsToFetch(
Windmill.TagSortedListFetchRequest.newBuilder()
.setTag(STATE_KEY_1)
.setTag(STATE_KEY_2)
.setStateFamily(STATE_FAMILY)
.addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6))
.addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10))
.setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES))
.addSortedListsToFetch(
Windmill.TagSortedListFetchRequest.newBuilder()
.setTag(STATE_KEY_2)
.setStateFamily(STATE_FAMILY2)
.addFetchRanges(SortedListRange.newBuilder().setStart(11).setLimit(12))
.setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES));

Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
Windmill.KeyedGetDataRequest.newBuilder()
.setKey(DATA_KEY)
.setShardingKey(SHARDING_KEY)
.setWorkToken(WORK_TOKEN)
.setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
.addSortedListsToFetch(
Windmill.TagSortedListFetchRequest.newBuilder()
.setTag(STATE_KEY_1)
.setStateFamily(STATE_FAMILY)
.addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10))
.addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6))
.setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES));

Windmill.KeyedGetDataResponse.Builder response =
Windmill.KeyedGetDataResponse.Builder response1 =
Windmill.KeyedGetDataResponse.newBuilder()
.setKey(DATA_KEY)
.addTagSortedLists(
Expand All @@ -1038,41 +1057,41 @@ public void testReadSortedListRanges() throws Exception {
.addFetchRanges(SortedListRange.newBuilder().setStart(0).setLimit(5)))
.addTagSortedLists(
Windmill.TagSortedListFetchResponse.newBuilder()
.setTag(STATE_KEY_1)
.setTag(STATE_KEY_2)
.setStateFamily(STATE_FAMILY)
.addEntries(
SortedListEntry.newBuilder().setValue(intData(6)).setSortKey(6000).setId(5))
.addEntries(
SortedListEntry.newBuilder().setValue(intData(7)).setSortKey(7000).setId(7))
.addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6)))
SortedListEntry.newBuilder().setValue(intData(8)).setSortKey(8000).setId(8))
.addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10)))
.addTagSortedLists(
Windmill.TagSortedListFetchResponse.newBuilder()
.setTag(STATE_KEY_2)
.setStateFamily(STATE_FAMILY2)
.addFetchRanges(SortedListRange.newBuilder().setStart(11).setLimit(12)));

Windmill.KeyedGetDataResponse.Builder response2 =
Windmill.KeyedGetDataResponse.newBuilder()
.setKey(DATA_KEY)
.addTagSortedLists(
Windmill.TagSortedListFetchResponse.newBuilder()
.setTag(STATE_KEY_1)
.setStateFamily(STATE_FAMILY)
.addEntries(
SortedListEntry.newBuilder().setValue(intData(8)).setSortKey(8000).setId(8))
.addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10)));

Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
.thenReturn(response.build());
SortedListEntry.newBuilder().setValue(intData(6)).setSortKey(6000).setId(5))
.addEntries(
SortedListEntry.newBuilder().setValue(intData(7)).setSortKey(7000).setId(7))
.addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6)));

{
Iterable<TimestampedValue<Integer>> results = future1.get();
Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
for (TimestampedValue<Integer> unused : results) {
// Iterate over the results to force loading all the pages.
}
Mockito.verifyNoMoreInteractions(mockWindmill);
assertThat(results, Matchers.contains(TimestampedValue.of(5, Instant.ofEpochMilli(5))));
assertNoReader(future1);
}
Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
.thenReturn(response1.build());
Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
.thenReturn(response2.build());

// Trigger reads of batching. By fetching future2 which is not part of the first batch we ensure
// that all batches are fetched.
{
Iterable<TimestampedValue<Integer>> results = future2.get();
Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
for (TimestampedValue<Integer> unused : results) {
// Iterate over the results to force loading all the pages.
}
Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());
Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build());
Mockito.verifyNoMoreInteractions(mockWindmill);
assertThat(
results,
Expand All @@ -1082,16 +1101,23 @@ public void testReadSortedListRanges() throws Exception {
assertNoReader(future2);
}

{
Iterable<TimestampedValue<Integer>> results = future1.get();
assertThat(results, Matchers.contains(TimestampedValue.of(5, Instant.ofEpochMilli(5))));
assertNoReader(future1);
}

{
Iterable<TimestampedValue<Integer>> results = future3.get();
Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
for (TimestampedValue<Integer> unused : results) {
// Iterate over the results to force loading all the pages.
}
Mockito.verifyNoMoreInteractions(mockWindmill);
assertThat(results, Matchers.contains(TimestampedValue.of(8, Instant.ofEpochMilli(8))));
assertNoReader(future3);
}

{
Iterable<TimestampedValue<Integer>> results = future4.get();
assertThat(results, Matchers.emptyIterable());
assertNoReader(future4);
}
}

@Test
Expand Down

0 comments on commit 97354ee

Please sign in to comment.