Skip to content

Commit

Permalink
Use Knuth–Morris–Pratt algorithm for delimiter search in TextIO (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
baeminbo authored and reeba212 committed Dec 4, 2024
1 parent 96ae2c9 commit a872c67
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 77 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@

## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)).

## Security Fixes
* Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
Expand Down
226 changes: 150 additions & 76 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
*/
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.NoSuchElementException;
import org.apache.beam.sdk.coders.Coder;
Expand Down Expand Up @@ -116,8 +118,13 @@ static class TextBasedReader extends FileBasedReader<String> {

private final byte @Nullable [] delimiter;
private final int skipHeaderLines;
private final ByteArrayOutputStream str;

// Used to build up results that span buffers. It may contain the delimiter as a suffix.
private final SubstringByteArrayOutputStream str;

// Buffer for text read from the underlying file.
private final byte[] buffer;
// A wrapper of the `buffer` field;
private final ByteBuffer byteBuffer;

private ReadableByteChannel inChannel;
Expand All @@ -129,17 +136,24 @@ static class TextBasedReader extends FileBasedReader<String> {
private int bufferPosn = 0; // the current position in the buffer
private boolean skipLineFeedAtStart; // skip an LF if at the start of the next buffer

// Finder for custom delimiter.
private @Nullable KMPDelimiterFinder delimiterFinder;

private TextBasedReader(TextSource source, byte[] delimiter) {
this(source, delimiter, 0);
}

private TextBasedReader(TextSource source, byte[] delimiter, int skipHeaderLines) {
super(source);
this.buffer = new byte[READ_BUFFER_SIZE];
this.str = new ByteArrayOutputStream();
this.str = new SubstringByteArrayOutputStream();
this.byteBuffer = ByteBuffer.wrap(buffer);
this.delimiter = delimiter;
this.skipHeaderLines = skipHeaderLines;

if (delimiter != null) {
delimiterFinder = new KMPDelimiterFinder(delimiter);
}
}

@Override
Expand Down Expand Up @@ -377,106 +391,60 @@ private boolean readDefaultLine() throws IOException {
return true;
}

/**
* Loosely based upon <a
* href="https://github.com/hanborq/hadoop/blob/master/src/core/org/apache/hadoop/util/LineReader.java">Hadoop
* LineReader.java</a>
*
* <p>Note that this implementation fixes an issue where a partial match against the delimiter
* would have been lost if the delimiter crossed at the buffer boundaries during reading.
*/
private boolean readCustomLine() throws IOException {
assert !eof;
checkState(!eof);
checkNotNull(delimiter);
checkNotNull(
delimiterFinder, "DelimiterFinder must not be null if custom delimiter is used.");

long bytesConsumed = 0;
int delPosn = 0;
EOF:
for (; ; ) {
int startPosn = bufferPosn; // starting from where we left off the last time
delimiterFinder.reset();

// Read the next chunk from the file, ensure that we read at least one byte
// or reach EOF.
while (bufferPosn >= bufferLength) {
startPosn = bufferPosn = 0;
while (true) {
if (bufferPosn >= bufferLength) {
bufferPosn = 0;
byteBuffer.clear();
bufferLength = inChannel.read(byteBuffer);

// If we are at EOF then try to create the last value from the buffer.
do {
bufferLength = inChannel.read(byteBuffer);
} while (bufferLength == 0);

if (bufferLength < 0) {
eof = true;

// Write any partial delimiter now that we are at EOF
if (delPosn != 0) {
str.write(delimiter, 0, delPosn);
}

// Don't return an empty record if the file ends with a delimiter
if (str.size() == 0) {
return false;
}

// Not ending with a delimiter.
currentValue = str.toString(StandardCharsets.UTF_8.name());
break EOF;
break;
}
}

int prevDelPosn = delPosn;
DELIMITER_MATCH:
{
if (delPosn > 0) {
// slow-path: Handle the case where we only matched part of the delimiter, possibly
// adding that to str fixing up any partially consumed delimiter if we don't match the
// whole delimiter
for (; bufferPosn < bufferLength; ++bufferPosn) {
if (buffer[bufferPosn] == delimiter[delPosn]) {
delPosn++;
if (delPosn == delimiter.length) {
bufferPosn++;
break DELIMITER_MATCH; // Skip matching the delimiter using the fast path
}
} else {
// Add to str any previous partial delimiter since we didn't match the whole
// delimiter
str.write(delimiter, 0, prevDelPosn);
if (buffer[bufferPosn] == delimiter[0]) {
delPosn = 1;
} else {
delPosn = 0;
}
break; // Leave this loop and use the fast-path delimiter matching
}
}
}

// fast-path: Look for the delimiter within the buffer
for (; bufferPosn < bufferLength; ++bufferPosn) {
if (buffer[bufferPosn] == delimiter[delPosn]) {
delPosn++;
if (delPosn == delimiter.length) {
bufferPosn++;
break;
}
} else if (buffer[bufferPosn] == delimiter[0]) {
delPosn = 1;
} else {
delPosn = 0;
}
int startPosn = bufferPosn;
boolean delimiterFound = false;
for (; bufferPosn < bufferLength; ++bufferPosn) {
if (delimiterFinder.feed(buffer[bufferPosn])) {
++bufferPosn;
delimiterFound = true;
break;
}
}

int readLength = bufferPosn - startPosn;
bytesConsumed += readLength;
int appendLength = readLength - (delPosn - prevDelPosn);
if (delPosn < delimiter.length) {
// Append the prefix of the value to str skipping the partial delimiter
str.write(buffer, startPosn, appendLength);
if (!delimiterFound) {
str.write(buffer, startPosn, readLength);
} else {
if (str.size() == 0) {
// Optimize for the common case where the string is wholly contained within the buffer
currentValue = new String(buffer, startPosn, appendLength, StandardCharsets.UTF_8);
currentValue =
new String(
buffer, startPosn, readLength - delimiter.length, StandardCharsets.UTF_8);
} else {
str.write(buffer, startPosn, appendLength);
currentValue = str.toString(StandardCharsets.UTF_8.name());
str.write(buffer, startPosn, readLength);
currentValue = str.toString(0, str.size() - delimiter.length, StandardCharsets.UTF_8);
}
break;
}
Expand All @@ -487,4 +455,110 @@ private boolean readCustomLine() throws IOException {
return true;
}
}

/**
* This class is created to avoid multiple bytes-copy when making a substring of the output.
* Without this class, it requires two bytes copies.
*
* <pre>{@code
* ByteArrayOutputStream out = ...;
* byte[] buffer = out.toByteArray(); // 1st-copy
* String s = new String(buffer, offset, length); // 2nd-copy
* }</pre>
*/
static class SubstringByteArrayOutputStream extends ByteArrayOutputStream {
public String toString(int offset, int length, Charset charset) {
if (offset < 0) {
throw new IllegalArgumentException("offset is negative: " + offset);
}
if (offset > count) {
throw new IllegalArgumentException(
"offset exceeds the buffer limit. offset: " + offset + ", limit: " + count);
}

if (length < 0) {
throw new IllegalArgumentException("length is negative: " + length);
}

if (offset + length > count) {
throw new IllegalArgumentException(
"offset + length exceeds the buffer limit. offset: "
+ offset
+ ", length: "
+ length
+ ", limit: "
+ count);
}

return new String(buf, offset, length, charset);
}
}

/**
* @see <a
* href="https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm">Knuth–Morris–Pratt
* algorithm</a>
*/
static class KMPDelimiterFinder {
private final byte[] delimiter;
private final int[] table;
int delimiterOffset; // the current position in delimiter

public KMPDelimiterFinder(byte[] delimiter) {
this.delimiter = delimiter;
this.table = new int[delimiter.length];
compile();
}

public boolean feed(byte b) {
// Modified "Description of pseudocode for the search algorithm" in Wikipedia
while (true) {
if (b == delimiter[delimiterOffset]) {
++delimiterOffset;
if (delimiterOffset == delimiter.length) {
// return when the first occurrence is found.
delimiterOffset = 0;
return true;
}
return false;
}

delimiterOffset = table[delimiterOffset];
if (delimiterOffset < 0) {
++delimiterOffset;
return false;
}
}
}

public void reset() {
delimiterOffset = 0;
}

private void compile() {
// the current position in table
int pos = 1;
// the zero-based index in delimiter of the next character of the current candidate substring
int cnd = 0;

table[0] = -1;

while (pos < delimiter.length) {
if (delimiter[pos] == delimiter[cnd]) {
table[pos] = table[cnd];
} else {
table[pos] = cnd;
while (cnd >= 0 && delimiter[pos] != delimiter[cnd]) {
cnd = table[cnd];
}
}

++pos;
++cnd;
}

// We don't need the table entry at (pos + 1) in "Description of pseudocode for the
// table-building algorithm" in Wikipedia because we only checks the first occurrence.
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.zip.GZIPOutputStream;
import java.util.zip.ZipEntry;
Expand Down Expand Up @@ -696,6 +697,73 @@ public void testReadStringsWithCustomDelimiter() throws Exception {
p.run();
}

@Test
@Category(NeedsRunner.class)
public void testReadStringsWithCustomDelimiter_NestedDelimiter() throws IOException {
// Test for https://github.com/apache/beam/issues/32251
String delimiter = "AABAAC";
String text = "0AABAAC1AABAABAAC2AABAAABAAC";
List<String> expected = Arrays.asList("0", "1AAB", "2AABA");
assertEquals(
expected, Arrays.asList(Pattern.compile(delimiter, Pattern.LITERAL).split(text)));

File tmpFile = tempFolder.newFile("tmpfile.txt");
String filename = tmpFile.getPath();

try (Writer writer = Files.newBufferedWriter(tmpFile.toPath(), UTF_8)) {
writer.write(text);
}

PAssert.that(
p.apply(
TextIO.read()
.from(filename)
.withDelimiter(delimiter.getBytes(StandardCharsets.UTF_8))))
.containsInAnyOrder(expected);
p.run();
}

@Test
@Category(NeedsRunner.class)
public void testReadStringsWithCustomDelimiter_PartialDelimiterMatchedAtBoundary()
throws IOException {
// Test for https://github.com/apache/beam/issues/32249
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 8190; ++i) {
sb.append('0');
}
sb.append('A'); // index 8190
sb.append('B'); // index 8191
sb.append('C'); // index 8192
for (int i = 8193; i < 16400; ++i) {
sb.append('0');
}

String text = sb.toString();
String delimiter = "ABCDE";

// check the text is not split by the delimiter
assertEquals(
Collections.singletonList(text),
Arrays.asList(Pattern.compile(delimiter, Pattern.LITERAL).split(text)));

File tmpFile = tempFolder.newFile("tmpfile.txt");
String filename = tmpFile.getPath();

try (Writer writer = Files.newBufferedWriter(tmpFile.toPath(), UTF_8)) {
writer.write(text);
}

// Expects no IndexOutOfBoundsException
PAssert.that(
p.apply(
TextIO.read()
.from(filename)
.withDelimiter(delimiter.getBytes(StandardCharsets.UTF_8))))
.containsInAnyOrder(text);
p.run();
}

@Test
@Category(NeedsRunner.class)
public void testReadStrings() throws Exception {
Expand Down
Loading

0 comments on commit a872c67

Please sign in to comment.