Skip to content

Commit

Permalink
[#20970] Fix gRPC leak by closing ResidualSource at BoundedToUnbounde…
Browse files Browse the repository at this point in the history
…dSourceAdapter.Reader#init() in Dataflow worker (#28548)
  • Loading branch information
baeminbo authored Oct 11, 2023
1 parent 2bfcb9f commit aedfa46
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
*/
package org.apache.beam.runners.core.construction;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -288,6 +288,15 @@ private void init(
residualElementsList == null
? new ResidualElements(Collections.emptyList())
: new ResidualElements(residualElementsList);

if (this.residualSource != null) {
// Close the current ResidualSource before replacing it; otherwise its reader is never closed and leaks.
try {
this.residualSource.close();
} catch (IOException e) {
LOG.warn("Ignore error at closing ResidualSource", e);
}
}
this.residualSource =
residualSource == null ? null : new ResidualSource(residualSource, options);
}
Expand Down Expand Up @@ -465,7 +474,7 @@ public ResidualSource(BoundedSource<T> residualSource, PipelineOptions options)
}

private boolean advance() throws IOException {
checkArgument(!closed, "advance() call on closed %s", getClass().getName());
checkState(!closed, "advance() call on closed %s", getClass().getName());
if (readerDone) {
return false;
}
Expand Down Expand Up @@ -505,6 +514,7 @@ BoundedSource<T> getSource() {
}

Checkpoint<T> getCheckpointMark() {
checkState(!closed, "getCheckpointMark() call on closed %s", getClass().getName());
if (reader == null) {
// Reader hasn't started, checkpoint the residualSource.
return new Checkpoint<>(null /* residualElements */, residualSource);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter;
Expand Down Expand Up @@ -69,10 +75,14 @@
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Unit tests for {@link UnboundedReadFromBoundedSource}. */
@RunWith(JUnit4.class)
public class UnboundedReadFromBoundedSourceTest {
private static final Logger LOG =
LoggerFactory.getLogger(UnboundedReadFromBoundedSourceTest.class);

@Rule public TemporaryFolder tmpFolder = new TemporaryFolder();

Expand Down Expand Up @@ -280,6 +290,38 @@ public void testReadFromCheckpointBeforeStart() throws Exception {
unboundedSource.createReader(options, checkpoint).getCurrent();
}

/**
 * Reads the first three elements through the adapter's reader, takes a checkpoint, and then
 * verifies via {@code ManagedReaderBoundedSource#cleanQueue} that no reader of the bounded source
 * was garbage collected without having been closed.
 */
@Test
public void testReadersClosedProperly() throws IOException {
  PipelineOptions options = PipelineOptionsFactory.create();
  ManagedReaderBoundedSource boundedSource = new ManagedReaderBoundedSource(0, 10);
  BoundedToUnboundedSourceAdapter<Integer> unboundedSource =
      new BoundedToUnboundedSourceAdapter<>(boundedSource);

  BoundedToUnboundedSourceAdapter<Integer>.Reader reader =
      unboundedSource.createReader(options, new Checkpoint<Integer>(null, boundedSource));

  // Consume elements 0, 1 and 2: the first one via start(), the following ones via advance().
  assertTrue(reader.start());
  assertEquals(0, (int) reader.getCurrent());
  for (int expected = 1; expected < 3; ++expected) {
    assertTrue(reader.advance());
    assertEquals(expected, (int) reader.getCurrent());
  }

  // The checkpoint must carry the seven elements that were not consumed yet: 3 through 9.
  Checkpoint<Integer> checkpoint = reader.getCheckpointMark();
  List<TimestampedValue<Integer>> residualElements = checkpoint.getResidualElements();
  final int firstResidualValue = 3;
  for (int index = 0; index < 7; ++index) {
    int actual = residualElements.get(index).getValue();
    assertEquals(firstResidualValue + index, actual);
  }

  // A WeakReference whose referent is no longer referenced by anything else is not added to its
  // ReferenceQueue immediately, so cleanQueue() (which calls System.gc()) is invoked many times.
  // If a reader was GCed without being closed, cleanQueue() throws a RuntimeException.
  int remainingAttempts = 100;
  while (remainingAttempts-- > 0) {
    boundedSource.cleanQueue();
  }
}

/** Generate byte array of given size. */
private static byte[] generateInput(int size) {
// Arbitrary but fixed seed
Expand All @@ -298,6 +340,7 @@ private static void writeFile(File file, byte[] input) throws IOException {

/** Unsplittable source for use in tests. */
private static class UnsplittableSource extends FileBasedSource<Byte> {

/**
 * Creates a source for the given file or pattern spec.
 *
 * <p>The spec is wrapped with {@code StaticValueProvider.of} and passed, together with {@code
 * minBundleSize}, to the superclass constructor.
 *
 * @param fileOrPatternSpec file name or pattern to read from
 * @param minBundleSize minimum bundle size forwarded unchanged to the superclass
 */
public UnsplittableSource(String fileOrPatternSpec, long minBundleSize) {
super(StaticValueProvider.of(fileOrPatternSpec), minBundleSize);
}
Expand All @@ -323,6 +366,7 @@ public Coder<Byte> getOutputCoder() {
}

private static class UnsplittableReader extends FileBasedReader<Byte> {

ByteBuffer buff = ByteBuffer.allocate(1);
Byte current;
long offset;
Expand Down Expand Up @@ -370,4 +414,140 @@ protected long getCurrentOffset() {
}
}
}

/**
 * An integer-generating bounded source that checks whether its readers are closed properly.
 *
 * <p>Every reader handed out by {@code createReader} is tracked through a {@link WeakReference}
 * registered with a {@link ReferenceQueue}, together with a {@code CloseStatus} that records
 * whether {@code close()} was called on that reader. {@code cleanQueue} drains the queue and
 * throws an exception if it finds a reader that was garbage collected without being closed; it
 * is invoked from {@code createReader}, from {@code CloseStatus#close}, and directly by tests.
 */
private static class ManagedReaderBoundedSource extends BoundedSource<Integer> {

  private final int from;
  private final int to; // exclusive

  // Both are null until the first createReader() call creates them.
  private transient ReferenceQueue<ManagedReader> refQueue;
  private transient Map<Reference<ManagedReader>, CloseStatus> closeStatusMap;

  /**
   * Creates a source producing the integers in {@code [from, to)}.
   *
   * @param from first value produced (inclusive)
   * @param to upper bound (exclusive)
   * @throws IllegalArgumentException if {@code from > to}
   */
  public ManagedReaderBoundedSource(int from, int to) {
    if (from > to) {
      throw new IllegalArgumentException(
          String.format("`from` <= `to`, but got from: %d, to: %d", from, to));
    }
    this.from = from;
    this.to = to;
  }

  /** Never splits: always returns a singleton list containing this source. */
  @Override
  public List<? extends BoundedSource<Integer>> split(
      long desiredBundleSizeBytes, PipelineOptions options) {
    return Collections.singletonList(this);
  }

  /** Returns four bytes per element in {@code [from, to)}. */
  @Override
  public long getEstimatedSizeBytes(PipelineOptions options) {
    // Widen before subtracting so that a large range cannot overflow int arithmetic.
    return ((long) to - from) * 4L;
  }

  /**
   * Creates a new reader and registers a weak reference to it so that {@code cleanQueue} can
   * later detect whether it was garbage collected without being closed.
   */
  @Override
  public BoundedReader<Integer> createReader(PipelineOptions options) {
    // The tracking structures are created lazily on first use.
    if (refQueue == null) {
      refQueue = new ReferenceQueue<>();
      closeStatusMap = new HashMap<>();
    }
    // Check readers that were created earlier and have been GCed in the meantime.
    cleanQueue();

    CloseStatus status = new CloseStatus();
    ManagedReader reader = new ManagedReader(status);
    WeakReference<ManagedReader> reference = new WeakReference<>(reader, refQueue);
    closeStatusMap.put(reference, status);
    LOG.info("Add reference {} for reader {}", reference, reader);
    return reader;
  }

  /**
   * Requests a GC and checks every reader reference that has been enqueued since the last call.
   *
   * <p>Does nothing if no reader has been created by this instance yet.
   *
   * @throws RuntimeException if a garbage-collected reader was never closed; the exception
   *     carries the stack trace of the point where that reader was allocated
   */
  public void cleanQueue() {
    if (refQueue == null) {
      // No reader was ever created here, so there is nothing to poll.
      return;
    }
    System.gc();

    Reference<? extends ManagedReader> reference;
    while ((reference = refQueue.poll()) != null) {
      // A reference is polled at most once, so its bookkeeping entry can be dropped right away.
      CloseStatus closeStatus = closeStatusMap.remove(reference);
      LOG.info("Poll reference: {}, closed: {}", reference, closeStatus.closed);
      closeStatus.throwIfNotClosed();
    }
  }

  /** Records whether the reader it belongs to has been closed. */
  class CloseStatus {

    // Created eagerly at reader allocation so that a later failure points at the allocation site.
    private final RuntimeException allocationStacktrace;

    private boolean closed;

    public CloseStatus() {
      allocationStacktrace =
          new RuntimeException("Previous reader was not closed properly. Reader allocation was");
      closed = false;
    }

    /** Checks previously created readers via {@code cleanQueue}, then marks this one closed. */
    void close() {
      cleanQueue();
      closed = true;
    }

    /** Throws the exception captured at allocation time unless {@code close()} was called. */
    void throwIfNotClosed() {
      if (!closed) {
        throw allocationStacktrace;
      }
    }
  }

  /** Reader over {@code [from, to)} that reports its {@code close()} call to a CloseStatus. */
  class ManagedReader extends BoundedReader<Integer> {

    private final CloseStatus status;

    int current;

    public ManagedReader(CloseStatus status) {
      this.status = status;
    }

    /** Positions the reader at {@code from}; returns false if the range is empty. */
    @Override
    public boolean start() {
      if (from < to) {
        current = from;
        return true;
      } else {
        return false;
      }
    }

    /** Moves to the next integer; returns false once the next value would reach {@code to}. */
    @Override
    public boolean advance() {
      if (current + 1 < to) {
        ++current;
        return true;
      } else {
        return false;
      }
    }

    @Override
    public Integer getCurrent() {
      return current;
    }

    /** Marks the associated CloseStatus as closed. */
    @Override
    public void close() {
      status.close();
    }

    @Override
    public BoundedSource<Integer> getCurrentSource() {
      return ManagedReaderBoundedSource.this;
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,9 @@ public double getRemainingParallelism() {

private static class UnboundedReaderIterator<T>
extends NativeReader.NativeReaderIterator<WindowedValue<ValueWithRecordId<T>>> {
// Do not close reader. The reader is cached in StreamingModeExecutionContext.readerCache, and
// will be reused until the cache is evicted, expired or invalidated.
// See UnboundedReader#iterator().
private final UnboundedSource.UnboundedReader<T> reader;
private final StreamingModeExecutionContext context;
private final boolean started;
Expand Down Expand Up @@ -862,7 +865,9 @@ public WindowedValue<ValueWithRecordId<T>> getCurrent() throws NoSuchElementExce
}

@Override
public void close() {}
public void close() {
// Intentionally a no-op: the reader must not be closed here. Per the comment on the `reader`
// field, it is cached in StreamingModeExecutionContext.readerCache and reused until the cache
// entry is evicted, expired or invalidated.
}

@Override
public NativeReader.Progress getProgress() {
Expand Down

0 comments on commit aedfa46

Please sign in to comment.