From aedfa461fb354e0e97d75694d2a117d9ad505602 Mon Sep 17 00:00:00 2001 From: Minbo Bae <49642083+baeminbo@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:57:24 -0700 Subject: [PATCH] [#20970] Fix gRPC leak by closing ResidualSource at BoundedToUnboundedSourceAdapter.Reader#init() in Dataflow worker (#28548) --- .../UnboundedReadFromBoundedSource.java | 14 +- .../UnboundedReadFromBoundedSourceTest.java | 180 ++++++++++++++++++ .../dataflow/worker/WorkerCustomSources.java | 7 +- 3 files changed, 198 insertions(+), 3 deletions(-) diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSource.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSource.java index 67697636a363..53fad782da96 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSource.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSource.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.core.construction; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.io.IOException; import java.io.InputStream; @@ -288,6 +288,15 @@ private void init( residualElementsList == null ? new ResidualElements(Collections.emptyList()) : new ResidualElements(residualElementsList); + + if (this.residualSource != null) { + // close current residualSource to avoid leak of reader.close() in ResidualSource + try { + this.residualSource.close(); + } catch (IOException e) { + LOG.warn("Ignore error at closing ResidualSource", e); + } + } this.residualSource = residualSource == null ? null : new ResidualSource(residualSource, options); } @@ -465,7 +474,7 @@ public ResidualSource(BoundedSource residualSource, PipelineOptions options) } private boolean advance() throws IOException { - checkArgument(!closed, "advance() call on closed %s", getClass().getName()); + checkState(!closed, "advance() call on closed %s", getClass().getName()); if (readerDone) { return false; } @@ -505,6 +514,7 @@ BoundedSource getSource() { } Checkpoint getCheckpointMark() { + checkState(!closed, "getCheckpointMark() call on closed %s", getClass().getName()); if (reader == null) { // Reader hasn't started, checkpoint the residualSource. return new Checkpoint<>(null /* residualElements */, residualSource); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java index cd4b49262fcb..31f6842a42bc 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnboundedReadFromBoundedSourceTest.java @@ -26,9 +26,15 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import java.util.Random; import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter; @@ -69,10 +75,14 @@ import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Unit tests for {@link UnboundedReadFromBoundedSource}. */ @RunWith(JUnit4.class) public class UnboundedReadFromBoundedSourceTest { + private static final Logger LOG = + LoggerFactory.getLogger(UnboundedReadFromBoundedSourceTest.class); @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); @@ -280,6 +290,38 @@ public void testReadFromCheckpointBeforeStart() throws Exception { unboundedSource.createReader(options, checkpoint).getCurrent(); } + @Test + public void testReadersClosedProperly() throws IOException { + ManagedReaderBoundedSource boundedSource = new ManagedReaderBoundedSource(0, 10); + BoundedToUnboundedSourceAdapter unboundedSource = + new BoundedToUnboundedSourceAdapter<>(boundedSource); + PipelineOptions options = PipelineOptionsFactory.create(); + + BoundedToUnboundedSourceAdapter.Reader reader = + unboundedSource.createReader(options, new Checkpoint(null, boundedSource)); + + for (int i = 0; i < 3; ++i) { + if (i == 0) { + assertTrue(reader.start()); + } else { + assertTrue(reader.advance()); + } + assertEquals(i, (int) reader.getCurrent()); + } + Checkpoint checkpoint = reader.getCheckpointMark(); + List> residualElements = checkpoint.getResidualElements(); + for (int i = 0; i < 7; ++i) { + TimestampedValue element = residualElements.get(i); + assertEquals(i + 3, (int) element.getValue()); + } + for (int i = 0; i < 100; ++i) { + // A WeakReference of an object that no other objects reference are not immediately added to + // ReferenceQueue. To test this, we should run System.gc() multiple times. + // If a reader is GCed without closing, `cleanQueue` throws a RuntimeException. + boundedSource.cleanQueue(); + } + } + /** Generate byte array of given size. */ private static byte[] generateInput(int size) { // Arbitrary but fixed seed @@ -298,6 +340,7 @@ private static void writeFile(File file, byte[] input) throws IOException { /** Unsplittable source for use in tests. */ private static class UnsplittableSource extends FileBasedSource { + public UnsplittableSource(String fileOrPatternSpec, long minBundleSize) { super(StaticValueProvider.of(fileOrPatternSpec), minBundleSize); } @@ -323,6 +366,7 @@ public Coder getOutputCoder() { } private static class UnsplittableReader extends FileBasedReader { + ByteBuffer buff = ByteBuffer.allocate(1); Byte current; long offset; @@ -370,4 +414,140 @@ protected long getCurrentOffset() { } } } + + /** + * An integer generating bounded source. This source class checks if readers are closed properly. + * For that, it manages weak references of readers, and checks at `createReader` and `cleanQueue` + * if readers were closed before GCed. The `cleanQueue` does not change the state in + * `ManagedReaderBoundedSource`, but throws an exception if it finds a reader GCed without + * closing. + */ + private static class ManagedReaderBoundedSource extends BoundedSource { + + private final int from; + private final int to; // exclusive + + private transient ReferenceQueue refQueue; + private transient Map, CloseStatus> cloesStatusMap; + + public ManagedReaderBoundedSource(int from, int to) { + if (from > to) { + throw new RuntimeException( + String.format("`from` <= `to`, but got from: %d, to: %d", from, to)); + } + this.from = from; + this.to = to; + } + + @Override + public List> split( + long desiredBundleSizeBytes, PipelineOptions options) { + return Collections.singletonList(this); + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) { + return (to - from) * 4L; + } + + @Override + public BoundedReader createReader(PipelineOptions options) { + // Add weak reference to queue to monitor GCed readers. If `CloseStatus` associated with + // reader is not closed, it means a reader was GCed without closing properly. The CloseStatus + // check for GCed readers are done at cleanQueue(). + if (refQueue == null) { + refQueue = new ReferenceQueue<>(); + cloesStatusMap = new HashMap<>(); + } + cleanQueue(); + + CloseStatus status = new CloseStatus(); + ManagedReader reader = new ManagedReader(status); + WeakReference reference = new WeakReference<>(reader, refQueue); + cloesStatusMap.put(reference, status); + LOG.info("Add reference {} for reader {}", reference, reader); + return reader; + } + + public void cleanQueue() { + System.gc(); + + Reference reference; + while ((reference = refQueue.poll()) != null) { + CloseStatus closeStatus = cloesStatusMap.get(reference); + LOG.info("Poll reference: {}, closed: {}", reference, closeStatus.closed); + closeStatus.throwIfNotClosed(); + } + } + + class CloseStatus { + + private final RuntimeException allocationStacktrace; + + private boolean closed; + + public CloseStatus() { + allocationStacktrace = + new RuntimeException("Previous reader was not closed properly. Reader allocation was"); + closed = false; + } + + void close() { + cleanQueue(); + closed = true; + } + + void throwIfNotClosed() { + if (!closed) { + throw allocationStacktrace; + } + } + } + + class ManagedReader extends BoundedReader { + + private final CloseStatus status; + + int current; + + public ManagedReader(CloseStatus status) { + this.status = status; + } + + @Override + public boolean start() { + if (from < to) { + current = from; + return true; + } else { + return false; + } + } + + @Override + public boolean advance() { + if (current + 1 < to) { + ++current; + return true; + } else { + return false; + } + } + + @Override + public Integer getCurrent() { + return current; + } + + @Override + public void close() { + status.close(); + } + + @Override + public BoundedSource getCurrentSource() { + return ManagedReaderBoundedSource.this; + } + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java index 872dc1e89a79..a9050236efc8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java @@ -776,6 +776,9 @@ public double getRemainingParallelism() { private static class UnboundedReaderIterator extends NativeReader.NativeReaderIterator>> { + // Do not close reader. The reader is cached in StreamingModeExecutionContext.readerCache, and + // will be reused until the cache is evicted, expired or invalidated. + // See UnboundedReader#iterator(). private final UnboundedSource.UnboundedReader reader; private final StreamingModeExecutionContext context; private final boolean started; @@ -862,7 +865,9 @@ public WindowedValue> getCurrent() throws NoSuchElementExce } @Override - public void close() {} + public void close() { + // Don't close reader. + } @Override public NativeReader.Progress getProgress() {