Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Knuth–Morris–Pratt algorithm for delimiter search in TextIO #32398

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)).

## Security Fixes
* Fixed [CVE-YYYY-NNNN](https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN) (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
Expand Down
257 changes: 185 additions & 72 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
*/
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.NoSuchElementException;
import org.apache.beam.sdk.coders.Coder;
Expand All @@ -34,6 +36,7 @@
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Bytes;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
Expand Down Expand Up @@ -116,8 +119,14 @@ static class TextBasedReader extends FileBasedReader<String> {

private final byte @Nullable [] delimiter;
private final int skipHeaderLines;
private final ByteArrayOutputStream str;

// The output stream may contain the delimiter at its end. The delimiter must be excluded when
baeminbo marked this conversation as resolved.
Show resolved Hide resolved
// converting into currentValue.
private final SubstringByteArrayOutputStream str;

// Buffer for text read from the underlying file.
private final byte[] buffer;
// A ByteBuffer wrapper around the `buffer` field.
private final ByteBuffer byteBuffer;

private ReadableByteChannel inChannel;
Expand All @@ -129,17 +138,24 @@ static class TextBasedReader extends FileBasedReader<String> {
private int bufferPosn = 0; // the current position in the buffer
private boolean skipLineFeedAtStart; // skip an LF if at the start of the next buffer

// Finder for custom delimiter.
private @Nullable DelimiterFinder delimiterFinder;

/**
 * Convenience constructor that delegates to the three-argument constructor, passing {@code 0} as
 * {@code skipHeaderLines}.
 *
 * @param source the source passed through to the delegated constructor
 * @param delimiter the custom record delimiter bytes passed through to the delegated constructor
 */
private TextBasedReader(TextSource source, byte[] delimiter) {
this(source, delimiter, 0);
}

private TextBasedReader(TextSource source, byte[] delimiter, int skipHeaderLines) {
super(source);
this.buffer = new byte[READ_BUFFER_SIZE];
this.str = new ByteArrayOutputStream();
this.str = new SubstringByteArrayOutputStream();
this.byteBuffer = ByteBuffer.wrap(buffer);
this.delimiter = delimiter;
this.skipHeaderLines = skipHeaderLines;

if (delimiter != null) {
delimiterFinder = new DelimiterFinder(delimiter);
}
}

@Override
Expand Down Expand Up @@ -377,106 +393,61 @@ private boolean readDefaultLine() throws IOException {
return true;
}

/**
* Loosely based upon <a
* href="https://github.com/hanborq/hadoop/blob/master/src/core/org/apache/hadoop/util/LineReader.java">Hadoop
* LineReader.java</a>
*
* <p>Note that this implementation fixes an issue where a partial match against the delimiter
* would have been lost if the delimiter crossed at the buffer boundaries during reading.
*/
private boolean readCustomLine() throws IOException {
assert !eof;
checkState(!eof);
checkNotNull(delimiter);
checkNotNull(
delimiterFinder, "DelimiterFinder must not be null if custom delimiter is used.");

long bytesConsumed = 0;
int delPosn = 0;
EOF:
for (; ; ) {
int startPosn = bufferPosn; // starting from where we left off the last time
delimiterFinder.reset();

// Read the next chunk from the file, ensure that we read at least one byte
// or reach EOF.
while (true) {
int startPosn = bufferPosn;
while (bufferPosn >= bufferLength) {
startPosn = bufferPosn = 0;
byteBuffer.clear();
bufferLength = inChannel.read(byteBuffer);

// If we are at EOF then try to create the last value from the buffer.
if (bufferLength < 0) {
eof = true;
scwhittle marked this conversation as resolved.
Show resolved Hide resolved

// Write any partial delimiter now that we are at EOF
if (delPosn != 0) {
str.write(delimiter, 0, delPosn);
}

// Don't return an empty record if the file ends with a delimiter
if (str.size() == 0) {
return false;
}

// Not ending with a delimiter.
currentValue = str.toString(StandardCharsets.UTF_8.name());
break EOF;
break;
}
}

int prevDelPosn = delPosn;
DELIMITER_MATCH:
{
if (delPosn > 0) {
// slow-path: Handle the case where we only matched part of the delimiter, possibly
// adding that to str fixing up any partially consumed delimiter if we don't match the
// whole delimiter
for (; bufferPosn < bufferLength; ++bufferPosn) {
if (buffer[bufferPosn] == delimiter[delPosn]) {
delPosn++;
if (delPosn == delimiter.length) {
bufferPosn++;
break DELIMITER_MATCH; // Skip matching the delimiter using the fast path
}
} else {
// Add to str any previous partial delimiter since we didn't match the whole
// delimiter
str.write(delimiter, 0, prevDelPosn);
if (buffer[bufferPosn] == delimiter[0]) {
delPosn = 1;
} else {
delPosn = 0;
}
break; // Leave this loop and use the fast-path delimiter matching
}
}
}
if (eof) {
baeminbo marked this conversation as resolved.
Show resolved Hide resolved
break;
}

// fast-path: Look for the delimiter within the buffer
for (; bufferPosn < bufferLength; ++bufferPosn) {
if (buffer[bufferPosn] == delimiter[delPosn]) {
delPosn++;
if (delPosn == delimiter.length) {
bufferPosn++;
break;
}
} else if (buffer[bufferPosn] == delimiter[0]) {
delPosn = 1;
} else {
delPosn = 0;
}
boolean delimiterFound = false;
for (; bufferPosn < bufferLength; ++bufferPosn) {
scwhittle marked this conversation as resolved.
Show resolved Hide resolved
if (delimiterFinder.feed(buffer[bufferPosn])) {
++bufferPosn;
delimiterFound = true;
break;
}
}

int readLength = bufferPosn - startPosn;
bytesConsumed += readLength;
int appendLength = readLength - (delPosn - prevDelPosn);
if (delPosn < delimiter.length) {
// Append the prefix of the value to str skipping the partial delimiter
str.write(buffer, startPosn, appendLength);
if (!delimiterFound) {
str.write(buffer, startPosn, readLength);
} else {
if (str.size() == 0) {
// Optimize for the common case where the string is wholly contained within the buffer
currentValue = new String(buffer, startPosn, appendLength, StandardCharsets.UTF_8);
currentValue =
new String(
buffer, startPosn, readLength - delimiter.length, StandardCharsets.UTF_8);
} else {
str.write(buffer, startPosn, appendLength);
currentValue = str.toString(StandardCharsets.UTF_8.name());
str.write(buffer, startPosn, readLength);
currentValue = str.toString(0, str.size() - delimiter.length, StandardCharsets.UTF_8);
scwhittle marked this conversation as resolved.
Show resolved Hide resolved
}
break;
}
Expand All @@ -487,4 +458,146 @@ private boolean readCustomLine() throws IOException {
return true;
}
}

/**
 * A {@link ByteArrayOutputStream} that can decode a sub-range of its buffered bytes directly into
 * a {@link String}.
 *
 * <p>This class exists to avoid an extra byte copy when only part of the output is needed. Without
 * it, producing a substring requires two copies:
 *
 * <pre>{@code
 * ByteArrayOutputStream out = ...;
 * byte[] buffer = out.toByteArray(); // 1st-copy
 * String s = new String(buffer, offset, length); // 2nd-copy
 * }</pre>
 */
static class SubstringByteArrayOutputStream extends ByteArrayOutputStream {
  /**
   * Decodes {@code length} bytes of the written content, starting at {@code offset}, using the
   * given charset.
   *
   * <p>The method is {@code synchronized}, like the accessors of {@link ByteArrayOutputStream},
   * because it reads the inherited {@code buf} and {@code count} fields.
   *
   * @param offset index of the first byte to decode; must be in {@code [0, size()]}
   * @param length number of bytes to decode; must be non-negative and must not run past {@code
   *     size()}
   * @param charset the charset used to decode the bytes
   * @return the decoded string
   * @throws IllegalArgumentException if {@code offset} or {@code length} is negative or the
   *     requested range exceeds the bytes written so far
   */
  public synchronized String toString(int offset, int length, Charset charset) {
    if (offset < 0) {
      throw new IllegalArgumentException("offset is negative: " + offset);
    }
    if (offset > count) {
      throw new IllegalArgumentException(
          "offset exceeds the buffer limit. offset: " + offset + ", limit: " + count);
    }

    if (length < 0) {
      throw new IllegalArgumentException("length is negative: " + length);
    }

    // Compare against the remaining room instead of computing offset + length, which can
    // overflow int and let an out-of-range request slip past this validation.
    if (length > count - offset) {
      throw new IllegalArgumentException(
          "offset + length exceeds the buffer limit. offset: "
              + offset
              + ", length: "
              + length
              + ", limit: "
              + count);
    }

    return new String(buf, offset, length, charset);
  }
}

/**
* A state machine to match the delimiter in a byte stream.
*
* <pre>{@code
* DelimiterFinder finder = new DelimiterFinder([65, 65, 66]); // "AAB"
* finder.feed(65); // false. "A"
* finder.feed(66); // false. "AB"
* finder.feed(65); // false. "ABA"
* finder.feed(65); // false. "ABAA"
* finder.feed(66); // true. "ABAAB"
* finder.feed(65); // false. "ABAABA"
*
* }</pre>
*/
static class DelimiterFinder {
private final byte[] delimiter;
private final int[] subsequences;
private final int[][] trans;

int position;

public DelimiterFinder(byte[] delimiter) {
this.delimiter = delimiter;
subsequences = new int[delimiter.length];
trans = new int[delimiter.length][256];
compile();
}

public boolean feed(byte b) {
position = trans[position][byteToIndex(b)];
if (position == delimiter.length) {
position = 0;
return true;
} else {
return false;
}
}

public void reset() {
position = 0;
}

public String describe() {
StringBuilder sb = new StringBuilder();
sb.append("delimiter:\n").append(Bytes.asList(delimiter)).append('\n');
sb.append("subsequences and trans:\n");
for (int i = 0; i < delimiter.length; ++i) {
sb.append(i).append(" (").append(subsequences[i]).append("): ");
for (int b = Byte.MIN_VALUE; b <= Byte.MAX_VALUE; ++b) {
int tran = trans[i][byteToIndex((byte) b)];
if (tran > 0) {
sb.append(b).append(" -> ").append(tran).append(", ");
}
}
sb.append("\n");
}
sb.append("current position: ").append(position).append("\n");
return sb.toString();
}

private void compile() {
// e.g. "AABAAC":
// 0 -> 0, "" -> ""
// 1 -> 0: "A" -> ""
// 2 -> 1: "AA" -> "A"
// 3 -> 0: "AAB" -> "B"
// 4 -> 1: "AABA" -> "A"
// 5 -> 2: "AABAA" -> "AA"

for (int i = 2; i < delimiter.length; i++) {
if (delimiter[i - 1] == delimiter[subsequences[i - 1]]) {
subsequences[i] = subsequences[i - 1] + 1;
} else if (delimiter[i - 1] == delimiter[0]) {
subsequences[i] = 1;
} else {
subsequences[i] = 0;
}
}
// e.g. "AABAAC":
// index (subsequence): trans
// 0 (0): "A" -> 1
// 1 (0): "A" -> 2
// 2 (1): "A" -> 2, "B" -> 3
// 3 (0): "A" -> 4
// 4 (1): "A" -> 5
// 5 (2): "A" -> 2, "B" -> 3, "C" -> 6
trans[0][byteToIndex(delimiter[0])] = 1;

for (int i = 1; i < delimiter.length; i++) {
for (int b = Byte.MIN_VALUE; b <= Byte.MAX_VALUE; b++) {
baeminbo marked this conversation as resolved.
Show resolved Hide resolved
if (b == delimiter[i]) {
trans[i][byteToIndex((byte) b)] = i + 1;
} else {
trans[i][byteToIndex((byte) b)] = trans[subsequences[i]][byteToIndex((byte) b)];
}
}
}
}

private int byteToIndex(byte b) {
return b + 128;
}
}
}
Loading
Loading