diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java index 07d652992c1c..c28939c59ee2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReader.java @@ -317,47 +317,88 @@ private void delayUnbatchableMultimapFetches( } } - public void startBatchAndBlock() { - // First, drain work out of the pending lookups into a set. These will be the items we fetch. + private void delayUnbatchableOrderedListFetches( + List> orderedListTags, HashSet> toFetch) { + // Each KeyedGetDataRequest can have at most 1 TagOrderedListRequest per + // pair, thus we need to delay unbatchable ordered list requests of the same stateFamily and tag + // into later batches. + + Map>>> groupedTags = + orderedListTags.stream() + .collect( + Collectors.groupingBy( + StateTag::getStateFamily, Collectors.groupingBy(StateTag::getTag))); + + for (Map>> familyTags : groupedTags.values()) { + for (List> tags : familyTags.values()) { + StateTag first = tags.remove(0); + toFetch.add(first); + // Add the rest of the reads for the state family and tags back to pending. + pendingLookups.addAll(tags); + } + } + } + + private HashSet> buildFetchSet() { HashSet> toFetch = Sets.newHashSet(); - try { - List> multimapTags = Lists.newArrayList(); - while (!pendingLookups.isEmpty()) { - StateTag stateTag = pendingLookups.poll(); - if (stateTag == null) { - break; - } - if (stateTag.getKind() == Kind.MULTIMAP_ALL - || stateTag.getKind() == Kind.MULTIMAP_SINGLE_ENTRY) { - multimapTags.add(stateTag); - continue; - } - if (!toFetch.add(stateTag)) { - throw new IllegalStateException("Duplicate tags being fetched."); - } + List> multimapTags = Lists.newArrayList(); + List> orderedListTags = Lists.newArrayList(); + while (!pendingLookups.isEmpty()) { + StateTag stateTag = pendingLookups.poll(); + if (stateTag == null) { + break; } - if (!multimapTags.isEmpty()) { - delayUnbatchableMultimapFetches(multimapTags, toFetch); + if (stateTag.getKind() == Kind.MULTIMAP_ALL + || stateTag.getKind() == Kind.MULTIMAP_SINGLE_ENTRY) { + multimapTags.add(stateTag); + continue; + } + if (stateTag.getKind() == Kind.ORDERED_LIST) { + orderedListTags.add(stateTag); + continue; + } + + if (!toFetch.add(stateTag)) { + throw new IllegalStateException("Duplicate tags being fetched."); } + } + if (!multimapTags.isEmpty()) { + delayUnbatchableMultimapFetches(multimapTags, toFetch); + } + if (!orderedListTags.isEmpty()) { + delayUnbatchableOrderedListFetches(orderedListTags, toFetch); + } + return toFetch; + } - // If we failed to drain anything, some other thread pulled it off the queue. We have no work - // to do. + public void performReads() { + while (true) { + HashSet> toFetch = buildFetchSet(); if (toFetch.isEmpty()) { return; } - - KeyedGetDataResponse response = tryGetDataFromWindmill(toFetch); - - // Removes tags from toFetch as they are processed. - consumeResponse(response, toFetch); - } catch (Exception e) { - // Set up all the remaining futures for this key to throw an exception. This ensures that if - // the exception is caught that all futures have been completed and do not block. - for (StateTag stateTag : toFetch) { - waiting.get(stateTag).future.setException(e); + try { + KeyedGetDataResponse response = tryGetDataFromWindmill(toFetch); + // Removes tags from toFetch as they are processed. + consumeResponse(response, toFetch); + if (!toFetch.isEmpty()) { + throw new IllegalStateException( + "Didn't receive responses for all pending fetches. Missing: " + toFetch); + } + } catch (Exception e) { + // Set up all the remaining futures for this key to throw an exception. This ensures that if + // the exception is caught that all futures have been completed and do not block. + for (StateTag stateTag : toFetch) { + waiting.get(stateTag).future.setException(e); + } + // Also setup futures that may have been added back if they were not batched. + while (true) { + @Nullable StateTag stateTag = pendingLookups.poll(); + if (stateTag == null) break; + waiting.get(stateTag).future.setException(e); + } + throw new RuntimeException(e); } - - throw new RuntimeException(e); } } @@ -643,11 +684,6 @@ private void consumeResponse(KeyedGetDataResponse response, Set> toF consumeMultimapSingleEntry(entry, entryTag); } } - - if (!toFetch.isEmpty()) { - throw new IllegalStateException( - "Didn't receive responses for all pending fetches. Missing: " + toFetch); - } } /** The deserialized values in {@code bag} as a read-only array list. */ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WrappedFuture.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WrappedFuture.java index 035f6ec8e93d..7e894524bef3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WrappedFuture.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WrappedFuture.java @@ -45,7 +45,7 @@ public WrappedFuture(WindmillStateReader reader, Future delegate) { public T get() throws InterruptedException, ExecutionException { if (!delegate().isDone() && reader != null) { // Only one thread per reader, so no race here. - reader.startBatchAndBlock(); + reader.performReads(); } reader = null; return super.get(); @@ -56,7 +56,7 @@ public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { if (!delegate().isDone() && reader != null) { // Only one thread per reader, so no race here. - reader.startBatchAndBlock(); + reader.performReads(); } reader = null; return super.get(timeout, unit); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java index b8c4803a8f34..430e31ee04ff 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateReaderTest.java @@ -87,6 +87,8 @@ public class WindmillStateReaderTest { private static final ByteString STATE_KEY_2 = ByteString.copyFromUtf8("key2"); private static final String STATE_FAMILY = "family"; + private static final String STATE_FAMILY2 = "family2"; + private static void assertNoReader(Object obj) throws Exception { WindmillStateTestUtils.assertNoReference(obj, WindmillStateReader.class); } @@ -993,15 +995,19 @@ public void testReadSortedList() throws Exception { public void testReadSortedListRanges() throws Exception { Future>> future1 = underTest.orderedListFuture(Range.closedOpen(0L, 5L), STATE_KEY_1, STATE_FAMILY, INT_CODER); + // Should be put into a subsequent batch as it has the same key and state family. Future>> future2 = underTest.orderedListFuture(Range.closedOpen(5L, 6L), STATE_KEY_1, STATE_FAMILY, INT_CODER); Future>> future3 = underTest.orderedListFuture( - Range.closedOpen(6L, 10L), STATE_KEY_1, STATE_FAMILY, INT_CODER); + Range.closedOpen(6L, 10L), STATE_KEY_2, STATE_FAMILY, INT_CODER); + Future>> future4 = + underTest.orderedListFuture( + Range.closedOpen(11L, 12L), STATE_KEY_2, STATE_FAMILY2, INT_CODER); Mockito.verifyNoMoreInteractions(mockWindmill); // Fetch the entire list. - Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.Builder expectedRequest1 = Windmill.KeyedGetDataRequest.newBuilder() .setKey(DATA_KEY) .setShardingKey(SHARDING_KEY) @@ -1015,18 +1021,31 @@ public void testReadSortedListRanges() throws Exception { .setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES)) .addSortedListsToFetch( Windmill.TagSortedListFetchRequest.newBuilder() - .setTag(STATE_KEY_1) + .setTag(STATE_KEY_2) .setStateFamily(STATE_FAMILY) - .addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6)) + .addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10)) .setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES)) + .addSortedListsToFetch( + Windmill.TagSortedListFetchRequest.newBuilder() + .setTag(STATE_KEY_2) + .setStateFamily(STATE_FAMILY2) + .addFetchRanges(SortedListRange.newBuilder().setStart(11).setLimit(12)) + .setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES)); + + Windmill.KeyedGetDataRequest.Builder expectedRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) .addSortedListsToFetch( Windmill.TagSortedListFetchRequest.newBuilder() .setTag(STATE_KEY_1) .setStateFamily(STATE_FAMILY) - .addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10)) + .addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6)) .setFetchMaxBytes(WindmillStateReader.MAX_ORDERED_LIST_BYTES)); - Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.Builder response1 = Windmill.KeyedGetDataResponse.newBuilder() .setKey(DATA_KEY) .addTagSortedLists( @@ -1038,41 +1057,41 @@ public void testReadSortedListRanges() throws Exception { .addFetchRanges(SortedListRange.newBuilder().setStart(0).setLimit(5))) .addTagSortedLists( Windmill.TagSortedListFetchResponse.newBuilder() - .setTag(STATE_KEY_1) + .setTag(STATE_KEY_2) .setStateFamily(STATE_FAMILY) .addEntries( - SortedListEntry.newBuilder().setValue(intData(6)).setSortKey(6000).setId(5)) - .addEntries( - SortedListEntry.newBuilder().setValue(intData(7)).setSortKey(7000).setId(7)) - .addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6))) + SortedListEntry.newBuilder().setValue(intData(8)).setSortKey(8000).setId(8)) + .addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10))) + .addTagSortedLists( + Windmill.TagSortedListFetchResponse.newBuilder() + .setTag(STATE_KEY_2) + .setStateFamily(STATE_FAMILY2) + .addFetchRanges(SortedListRange.newBuilder().setStart(11).setLimit(12))); + + Windmill.KeyedGetDataResponse.Builder response2 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) .addTagSortedLists( Windmill.TagSortedListFetchResponse.newBuilder() .setTag(STATE_KEY_1) .setStateFamily(STATE_FAMILY) .addEntries( - SortedListEntry.newBuilder().setValue(intData(8)).setSortKey(8000).setId(8)) - .addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10))); - - Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build())) - .thenReturn(response.build()); + SortedListEntry.newBuilder().setValue(intData(6)).setSortKey(6000).setId(5)) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(7)).setSortKey(7000).setId(7)) + .addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6))); - { - Iterable> results = future1.get(); - Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); - for (TimestampedValue unused : results) { - // Iterate over the results to force loading all the pages. - } - Mockito.verifyNoMoreInteractions(mockWindmill); - assertThat(results, Matchers.contains(TimestampedValue.of(5, Instant.ofEpochMilli(5)))); - assertNoReader(future1); - } + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build())) + .thenReturn(response1.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build())) + .thenReturn(response2.build()); + // Trigger reads of batching. By fetching future2 which is not part of the first batch we ensure + // that all batches are fetched. { Iterable> results = future2.get(); - Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); - for (TimestampedValue unused : results) { - // Iterate over the results to force loading all the pages. - } + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build()); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build()); Mockito.verifyNoMoreInteractions(mockWindmill); assertThat( results, @@ -1082,16 +1101,23 @@ public void testReadSortedListRanges() throws Exception { assertNoReader(future2); } + { + Iterable> results = future1.get(); + assertThat(results, Matchers.contains(TimestampedValue.of(5, Instant.ofEpochMilli(5)))); + assertNoReader(future1); + } + { Iterable> results = future3.get(); - Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); - for (TimestampedValue unused : results) { - // Iterate over the results to force loading all the pages. - } - Mockito.verifyNoMoreInteractions(mockWindmill); assertThat(results, Matchers.contains(TimestampedValue.of(8, Instant.ofEpochMilli(8)))); assertNoReader(future3); } + + { + Iterable> results = future4.get(); + assertThat(results, Matchers.emptyIterable()); + assertNoReader(future4); + } } @Test