Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Knuth–Morris–Pratt algorithm for delimiter search in TextIO #32398

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)).

## Security Fixes
* Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
Expand Down
226 changes: 150 additions & 76 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
*/
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.NoSuchElementException;
import org.apache.beam.sdk.coders.Coder;
Expand Down Expand Up @@ -116,8 +118,13 @@ static class TextBasedReader extends FileBasedReader<String> {

private final byte @Nullable [] delimiter;
private final int skipHeaderLines;
private final ByteArrayOutputStream str;

// Used to build up results that span buffers. It may contain the delimiter as a suffix.
private final SubstringByteArrayOutputStream str;

// Buffer for text read from the underlying file.
private final byte[] buffer;
// A wrapper of the `buffer` field;
private final ByteBuffer byteBuffer;

private ReadableByteChannel inChannel;
Expand All @@ -129,17 +136,24 @@ static class TextBasedReader extends FileBasedReader<String> {
private int bufferPosn = 0; // the current position in the buffer
private boolean skipLineFeedAtStart; // skip an LF if at the start of the next buffer

// Finder for custom delimiter.
private @Nullable KMPDelimiterFinder delimiterFinder;

private TextBasedReader(TextSource source, byte[] delimiter) {
this(source, delimiter, 0);
}

private TextBasedReader(TextSource source, byte[] delimiter, int skipHeaderLines) {
super(source);
this.buffer = new byte[READ_BUFFER_SIZE];
this.str = new ByteArrayOutputStream();
this.str = new SubstringByteArrayOutputStream();
this.byteBuffer = ByteBuffer.wrap(buffer);
this.delimiter = delimiter;
this.skipHeaderLines = skipHeaderLines;

if (delimiter != null) {
delimiterFinder = new KMPDelimiterFinder(delimiter);
}
}

@Override
Expand Down Expand Up @@ -377,106 +391,60 @@ private boolean readDefaultLine() throws IOException {
return true;
}

/**
* Loosely based upon <a
* href="https://github.com/hanborq/hadoop/blob/master/src/core/org/apache/hadoop/util/LineReader.java">Hadoop
* LineReader.java</a>
*
* <p>Note that this implementation fixes an issue where a partial match against the delimiter
* would have been lost if the delimiter crossed at the buffer boundaries during reading.
*/
private boolean readCustomLine() throws IOException {
assert !eof;
checkState(!eof);
checkNotNull(delimiter);
checkNotNull(
delimiterFinder, "DelimiterFinder must not be null if custom delimiter is used.");

long bytesConsumed = 0;
int delPosn = 0;
EOF:
for (; ; ) {
int startPosn = bufferPosn; // starting from where we left off the last time
delimiterFinder.reset();

// Read the next chunk from the file, ensure that we read at least one byte
// or reach EOF.
while (bufferPosn >= bufferLength) {
startPosn = bufferPosn = 0;
while (true) {
if (bufferPosn >= bufferLength) {
bufferPosn = 0;
byteBuffer.clear();
bufferLength = inChannel.read(byteBuffer);

// If we are at EOF then try to create the last value from the buffer.
do {
bufferLength = inChannel.read(byteBuffer);
} while (bufferLength == 0);

if (bufferLength < 0) {
eof = true;
scwhittle marked this conversation as resolved.
Show resolved Hide resolved

// Write any partial delimiter now that we are at EOF
if (delPosn != 0) {
str.write(delimiter, 0, delPosn);
}

// Don't return an empty record if the file ends with a delimiter
if (str.size() == 0) {
return false;
}

// Not ending with a delimiter.
currentValue = str.toString(StandardCharsets.UTF_8.name());
break EOF;
break;
}
}

int prevDelPosn = delPosn;
DELIMITER_MATCH:
{
if (delPosn > 0) {
// slow-path: Handle the case where we only matched part of the delimiter, possibly
// adding that to str fixing up any partially consumed delimiter if we don't match the
// whole delimiter
for (; bufferPosn < bufferLength; ++bufferPosn) {
if (buffer[bufferPosn] == delimiter[delPosn]) {
delPosn++;
if (delPosn == delimiter.length) {
bufferPosn++;
break DELIMITER_MATCH; // Skip matching the delimiter using the fast path
}
} else {
// Add to str any previous partial delimiter since we didn't match the whole
// delimiter
str.write(delimiter, 0, prevDelPosn);
if (buffer[bufferPosn] == delimiter[0]) {
delPosn = 1;
} else {
delPosn = 0;
}
break; // Leave this loop and use the fast-path delimiter matching
}
}
}

// fast-path: Look for the delimiter within the buffer
for (; bufferPosn < bufferLength; ++bufferPosn) {
if (buffer[bufferPosn] == delimiter[delPosn]) {
delPosn++;
if (delPosn == delimiter.length) {
bufferPosn++;
break;
}
} else if (buffer[bufferPosn] == delimiter[0]) {
delPosn = 1;
} else {
delPosn = 0;
}
int startPosn = bufferPosn;
boolean delimiterFound = false;
for (; bufferPosn < bufferLength; ++bufferPosn) {
scwhittle marked this conversation as resolved.
Show resolved Hide resolved
if (delimiterFinder.feed(buffer[bufferPosn])) {
++bufferPosn;
delimiterFound = true;
break;
}
}

int readLength = bufferPosn - startPosn;
bytesConsumed += readLength;
int appendLength = readLength - (delPosn - prevDelPosn);
if (delPosn < delimiter.length) {
// Append the prefix of the value to str skipping the partial delimiter
str.write(buffer, startPosn, appendLength);
if (!delimiterFound) {
str.write(buffer, startPosn, readLength);
} else {
if (str.size() == 0) {
// Optimize for the common case where the string is wholly contained within the buffer
currentValue = new String(buffer, startPosn, appendLength, StandardCharsets.UTF_8);
currentValue =
new String(
buffer, startPosn, readLength - delimiter.length, StandardCharsets.UTF_8);
} else {
str.write(buffer, startPosn, appendLength);
currentValue = str.toString(StandardCharsets.UTF_8.name());
str.write(buffer, startPosn, readLength);
currentValue = str.toString(0, str.size() - delimiter.length, StandardCharsets.UTF_8);
scwhittle marked this conversation as resolved.
Show resolved Hide resolved
}
break;
}
Expand All @@ -487,4 +455,110 @@ private boolean readCustomLine() throws IOException {
return true;
}
}

/**
* This class is created to avoid multiple bytes-copy when making a substring of the output.
* Without this class, it requires two bytes copies.
*
* <pre>{@code
* ByteArrayOutputStream out = ...;
* byte[] buffer = out.toByteArray(); // 1st-copy
* String s = new String(buffer, offset, length); // 2nd-copy
* }</pre>
*/
static class SubstringByteArrayOutputStream extends ByteArrayOutputStream {
public String toString(int offset, int length, Charset charset) {
if (offset < 0) {
throw new IllegalArgumentException("offset is negative: " + offset);
}
if (offset > count) {
throw new IllegalArgumentException(
"offset exceeds the buffer limit. offset: " + offset + ", limit: " + count);
}

if (length < 0) {
throw new IllegalArgumentException("length is negative: " + length);
}

if (offset + length > count) {
throw new IllegalArgumentException(
"offset + length exceeds the buffer limit. offset: "
+ offset
+ ", length: "
+ length
+ ", limit: "
+ count);
}

return new String(buf, offset, length, charset);
}
}

/**
* @see <a
* href="https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm">Knuth–Morris–Pratt
* algorithm</a>
*/
static class KMPDelimiterFinder {
private final byte[] delimiter;
private final int[] table;
int delimiterOffset; // the current position in delimiter

public KMPDelimiterFinder(byte[] delimiter) {
this.delimiter = delimiter;
this.table = new int[delimiter.length];
compile();
}

public boolean feed(byte b) {
// Modified "Description of pseudocode for the search algorithm" in Wikipedia
while (true) {
if (b == delimiter[delimiterOffset]) {
++delimiterOffset;
if (delimiterOffset == delimiter.length) {
// return when the first occurrence is found.
delimiterOffset = 0;
return true;
}
return false;
}

delimiterOffset = table[delimiterOffset];
if (delimiterOffset < 0) {
++delimiterOffset;
return false;
}
}
}

public void reset() {
delimiterOffset = 0;
}

private void compile() {
// the current position in table
int pos = 1;
// the zero-based index in delimiter of the next character of the current candidate substring
int cnd = 0;

table[0] = -1;

while (pos < delimiter.length) {
if (delimiter[pos] == delimiter[cnd]) {
table[pos] = table[cnd];
} else {
table[pos] = cnd;
while (cnd >= 0 && delimiter[pos] != delimiter[cnd]) {
cnd = table[cnd];
}
}

++pos;
++cnd;
}

// We don't need the table entry at (pos + 1) in "Description of pseudocode for the
// table-building algorithm" in Wikipedia because we only checks the first occurrence.
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.zip.GZIPOutputStream;
import java.util.zip.ZipEntry;
Expand Down Expand Up @@ -696,6 +697,73 @@ public void testReadStringsWithCustomDelimiter() throws Exception {
p.run();
}

@Test
@Category(NeedsRunner.class)
public void testReadStringsWithCustomDelimiter_NestedDelimiter() throws IOException {
// Test for https://github.com/apache/beam/issues/32251
String delimiter = "AABAAC";
String text = "0AABAAC1AABAABAAC2AABAAABAAC";
List<String> expected = Arrays.asList("0", "1AAB", "2AABA");
assertEquals(
expected, Arrays.asList(Pattern.compile(delimiter, Pattern.LITERAL).split(text)));

File tmpFile = tempFolder.newFile("tmpfile.txt");
String filename = tmpFile.getPath();

try (Writer writer = Files.newBufferedWriter(tmpFile.toPath(), UTF_8)) {
writer.write(text);
}

PAssert.that(
p.apply(
TextIO.read()
.from(filename)
.withDelimiter(delimiter.getBytes(StandardCharsets.UTF_8))))
.containsInAnyOrder(expected);
p.run();
}

@Test
@Category(NeedsRunner.class)
public void testReadStringsWithCustomDelimiter_PartialDelimiterMatchedAtBoundary()
throws IOException {
// Test for https://github.com/apache/beam/issues/32249
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 8190; ++i) {
sb.append('0');
}
sb.append('A'); // index 8190
sb.append('B'); // index 8191
sb.append('C'); // index 8192
for (int i = 8193; i < 16400; ++i) {
sb.append('0');
}

String text = sb.toString();
String delimiter = "ABCDE";

// check the text is not split by the delimiter
assertEquals(
Collections.singletonList(text),
Arrays.asList(Pattern.compile(delimiter, Pattern.LITERAL).split(text)));

File tmpFile = tempFolder.newFile("tmpfile.txt");
String filename = tmpFile.getPath();

try (Writer writer = Files.newBufferedWriter(tmpFile.toPath(), UTF_8)) {
writer.write(text);
}

// Expects no IndexOutOfBoundsException
PAssert.that(
p.apply(
TextIO.read()
.from(filename)
.withDelimiter(delimiter.getBytes(StandardCharsets.UTF_8))))
.containsInAnyOrder(text);
p.run();
}

@Test
@Category(NeedsRunner.class)
public void testReadStrings() throws Exception {
Expand Down
Loading
Loading