Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#20970] Fix gRPC leak by closing ResidualSource at BoundedToUnboundedSourceAdapter.Reader#init() in Dataflow worker #28548

Merged
merged 3 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
*/
package org.apache.beam.runners.core.construction;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -288,6 +288,15 @@ private void init(
residualElementsList == null
? new ResidualElements(Collections.emptyList())
: new ResidualElements(residualElementsList);

if (this.residualSource != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking - this means that we are calling init() again without ever calling close() on the reader?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's at Reader.getCheckpointMark().

The getCheckpointMark() is called at bundle finish in Dataflow streaming jobs, and close() will be pending as readers are cached. A reader can be reused in next bundles.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yea based on that the change is obviously correct. Thanks!

// close current residualSource to avoid leak of reader.close() in ResidualSource
try {
this.residualSource.close();
} catch (IOException e) {
LOG.warn("Ignore error at closing ResidualSource", e);
}
}
this.residualSource =
residualSource == null ? null : new ResidualSource(residualSource, options);
}
Expand Down Expand Up @@ -465,7 +474,7 @@ public ResidualSource(BoundedSource<T> residualSource, PipelineOptions options)
}

private boolean advance() throws IOException {
checkArgument(!closed, "advance() call on closed %s", getClass().getName());
checkState(!closed, "advance() call on closed %s", getClass().getName());
if (readerDone) {
return false;
}
Expand Down Expand Up @@ -505,6 +514,7 @@ BoundedSource<T> getSource() {
}

Checkpoint<T> getCheckpointMark() {
checkState(!closed, "getCheckpointMark() call on closed %s", getClass().getName());
if (reader == null) {
// Reader hasn't started, checkpoint the residualSource.
return new Checkpoint<>(null /* residualElements */, residualSource);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource.BoundedToUnboundedSourceAdapter;
Expand Down Expand Up @@ -69,10 +75,14 @@
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Unit tests for {@link UnboundedReadFromBoundedSource}. */
@RunWith(JUnit4.class)
public class UnboundedReadFromBoundedSourceTest {
private static final Logger LOG =
LoggerFactory.getLogger(UnboundedReadFromBoundedSourceTest.class);

@Rule public TemporaryFolder tmpFolder = new TemporaryFolder();

Expand Down Expand Up @@ -280,6 +290,38 @@ public void testReadFromCheckpointBeforeStart() throws Exception {
unboundedSource.createReader(options, checkpoint).getCurrent();
}

@Test
public void testReadersClosedProperly() throws IOException {
ManagedReaderBoundedSource boundedSource = new ManagedReaderBoundedSource(0, 10);
BoundedToUnboundedSourceAdapter<Integer> unboundedSource =
new BoundedToUnboundedSourceAdapter<>(boundedSource);
PipelineOptions options = PipelineOptionsFactory.create();

BoundedToUnboundedSourceAdapter<Integer>.Reader reader =
unboundedSource.createReader(options, new Checkpoint<Integer>(null, boundedSource));

for (int i = 0; i < 3; ++i) {
if (i == 0) {
assertTrue(reader.start());
} else {
assertTrue(reader.advance());
}
assertEquals(i, (int) reader.getCurrent());
}
Checkpoint<Integer> checkpoint = reader.getCheckpointMark();
List<TimestampedValue<Integer>> residualElements = checkpoint.getResidualElements();
for (int i = 0; i < 7; ++i) {
TimestampedValue<Integer> element = residualElements.get(i);
assertEquals(i + 3, (int) element.getValue());
}
for (int i = 0; i < 100; ++i) {
// A WeakReference of an object that no other objects reference are not immediately added to
// ReferenceQueue. To test this, we should run System.gc() multiple times.
// If a reader is GCed without closing, `cleanQueue` throws a RuntimeException.
boundedSource.cleanQueue();
}
}

/** Generate byte array of given size. */
private static byte[] generateInput(int size) {
// Arbitrary but fixed seed
Expand All @@ -298,6 +340,7 @@ private static void writeFile(File file, byte[] input) throws IOException {

/** Unsplittable source for use in tests. */
private static class UnsplittableSource extends FileBasedSource<Byte> {

public UnsplittableSource(String fileOrPatternSpec, long minBundleSize) {
super(StaticValueProvider.of(fileOrPatternSpec), minBundleSize);
}
Expand All @@ -323,6 +366,7 @@ public Coder<Byte> getOutputCoder() {
}

private static class UnsplittableReader extends FileBasedReader<Byte> {

ByteBuffer buff = ByteBuffer.allocate(1);
Byte current;
long offset;
Expand Down Expand Up @@ -370,4 +414,140 @@ protected long getCurrentOffset() {
}
}
}

/**
* An integer generating bounded source. This source class checks if readers are closed properly.
* For that, it manages weak references of readers, and checks at `createReader` and `cleanQueue`
* if readers were closed before GCed. The `cleanQueue` does not change the state in
* `ManagedReaderBoundedSource`, but throws an exception if it finds a reader GCed without
* closing.
*/
private static class ManagedReaderBoundedSource extends BoundedSource<Integer> {

private final int from;
private final int to; // exclusive

private transient ReferenceQueue<ManagedReader> refQueue;
private transient Map<Reference<ManagedReader>, CloseStatus> cloesStatusMap;

public ManagedReaderBoundedSource(int from, int to) {
if (from > to) {
throw new RuntimeException(
String.format("`from` <= `to`, but got from: %d, to: %d", from, to));
}
this.from = from;
this.to = to;
}

@Override
public List<? extends BoundedSource<Integer>> split(
long desiredBundleSizeBytes, PipelineOptions options) {
return Collections.singletonList(this);
}

@Override
public long getEstimatedSizeBytes(PipelineOptions options) {
return (to - from) * 4L;
}

@Override
public BoundedReader<Integer> createReader(PipelineOptions options) {
// Add weak reference to queue to monitor GCed readers. If `CloseStatus` associated with
// reader is not closed, it means a reader was GCed without closing properly. The CloseStatus
// check for GCed readers are done at cleanQueue().
if (refQueue == null) {
refQueue = new ReferenceQueue<>();
cloesStatusMap = new HashMap<>();
}
cleanQueue();

CloseStatus status = new CloseStatus();
ManagedReader reader = new ManagedReader(status);
WeakReference<ManagedReader> reference = new WeakReference<>(reader, refQueue);
cloesStatusMap.put(reference, status);
LOG.info("Add reference {} for reader {}", reference, reader);
return reader;
}

public void cleanQueue() {
System.gc();

Reference<? extends ManagedReader> reference;
while ((reference = refQueue.poll()) != null) {
CloseStatus closeStatus = cloesStatusMap.get(reference);
LOG.info("Poll reference: {}, closed: {}", reference, closeStatus.closed);
closeStatus.throwIfNotClosed();
}
}

class CloseStatus {

private final RuntimeException allocationStacktrace;

private boolean closed;

public CloseStatus() {
allocationStacktrace =
new RuntimeException("Previous reader was not closed properly. Reader allocation was");
closed = false;
}

void close() {
cleanQueue();
closed = true;
}

void throwIfNotClosed() {
if (!closed) {
throw allocationStacktrace;
}
}
}

class ManagedReader extends BoundedReader<Integer> {

private final CloseStatus status;

int current;

public ManagedReader(CloseStatus status) {
this.status = status;
}

@Override
public boolean start() {
if (from < to) {
current = from;
return true;
} else {
return false;
}
}

@Override
public boolean advance() {
if (current + 1 < to) {
++current;
return true;
} else {
return false;
}
}

@Override
public Integer getCurrent() {
return current;
}

@Override
public void close() {
status.close();
}

@Override
public BoundedSource<Integer> getCurrentSource() {
return ManagedReaderBoundedSource.this;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,9 @@ public double getRemainingParallelism() {

private static class UnboundedReaderIterator<T>
extends NativeReader.NativeReaderIterator<WindowedValue<ValueWithRecordId<T>>> {
// Do not close reader. The reader is cached in StreamingModeExecutionContext.readerCache, and
// will be reused until the cache is evicted, expired or invalidated.
// See UnboundedReader#iterator().
private final UnboundedSource.UnboundedReader<T> reader;
private final StreamingModeExecutionContext context;
private final boolean started;
Expand Down Expand Up @@ -862,7 +865,9 @@ public WindowedValue<ValueWithRecordId<T>> getCurrent() throws NoSuchElementExce
}

@Override
public void close() {}
public void close() {
// Don't close reader.
}

@Override
public NativeReader.Progress getProgress() {
Expand Down
Loading