diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java index 39ef63c8f7e9..c19bc251bc30 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReader.java @@ -21,6 +21,8 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -75,6 +77,11 @@ public class FlinkUnboundedSourceReader private int currentReaderIndex; private volatile boolean shouldEmitWatermark; + /** Pending checkpoints which have not been acknowledged yet. */ + private transient LinkedHashMap> pendingCheckpoints; + /** Keep a maximum of 32 checkpoints for {@code CheckpointMark.finalizeCheckpoint()}. */ + private static final int MAX_NUMBER_PENDING_CHECKPOINTS = 32; + public FlinkUnboundedSourceReader( String stepName, SourceReaderContext context, @@ -84,6 +91,7 @@ public FlinkUnboundedSourceReader( this.readers = new ArrayList<>(); this.dataAvailableFutureRef = new AtomicReference<>(DUMMY_FUTURE); this.currentReaderIndex = 0; + pendingCheckpoints = new LinkedHashMap<>(); } @VisibleForTesting @@ -97,6 +105,7 @@ protected FlinkUnboundedSourceReader( this.readers = new ArrayList<>(); this.dataAvailableFutureRef = new AtomicReference<>(DUMMY_FUTURE); this.currentReaderIndex = 0; + pendingCheckpoints = new LinkedHashMap<>(); } @Override @@ -217,6 +226,50 @@ protected Source.Reader createReader(@Nonnull FlinkSourceSplit sourceSplit return createUnboundedSourceReader(beamSource, sourceSplit.getSplitState()); } + @Override + public List> snapshotState(long checkpointId) { + + List checkpointMarks = new ArrayList<>(allReaders().size()); + allReaders() + .forEach( + (splitId, readerAndOutput) -> { + UnboundedSource.UnboundedReader reader = asUnbounded(readerAndOutput.reader); + checkpointMarks.add(reader.getCheckpointMark()); + }); + + // cleanup old pending checkpoints and add new checkpoint + int diff = pendingCheckpoints.size() - MAX_NUMBER_PENDING_CHECKPOINTS; + if (diff >= 0) { + for (Iterator iterator = pendingCheckpoints.keySet().iterator(); diff >= 0; diff--) { + iterator.next(); + iterator.remove(); + } + } + pendingCheckpoints.put(checkpointId, checkpointMarks); + return super.snapshotState(checkpointId); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + + List checkpointMarks = pendingCheckpoints.get(checkpointId); + if (checkpointMarks != null) { + + // remove old checkpoints including the current one + Iterator iterator = pendingCheckpoints.keySet().iterator(); + long currentId; + do { + currentId = iterator.next(); + iterator.remove(); + } while (currentId != checkpointId); + + // confirm all marks + for (UnboundedSource.CheckpointMark mark : checkpointMarks) { + mark.finalizeCheckpoint(); + } + } + } + // -------------- private helper methods ---------------- private void emitRecord( diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReaderTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReaderTest.java index 0ae5b407a157..98afccd61eab 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReaderTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/source/unbounded/FlinkUnboundedSourceReaderTest.java @@ -282,6 +282,31 @@ public void testPendingBytesMetric() throws Exception { } } + @Test + public void testCheckMarksFinalized() throws Exception { + + final int numSplits = 2; + final int numRecordsPerSplit = 10; + + List>> splits = + createSplits(numSplits, numRecordsPerSplit, 0); + RecordsValidatingOutput validatingOutput = new RecordsValidatingOutput(splits); + // Create a reader, take a snapshot. + try (SourceReader< + WindowedValue>>, + FlinkSourceSplit>> + reader = createReader()) { + List finalizeTracker = new ArrayList<>(); + TestCountingSource.setFinalizeTracker(finalizeTracker); + pollAndValidate(reader, splits, validatingOutput, numSplits * numRecordsPerSplit / 2); + assertTrue(finalizeTracker.isEmpty()); + reader.snapshotState(0L); + // notifyCheckpointComplete is normally called by the SourceOperator + reader.notifyCheckpointComplete(0L); + assertFalse(finalizeTracker.isEmpty()); + } + } + // --------------- private helper classes ----------------- /** A source whose advance() method only returns true occasionally. */ private static class DummySource extends TestCountingSource {