diff --git a/CHANGES.md b/CHANGES.md index 9f9b2f6f80f6..e9f6113a181a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -84,7 +84,7 @@ ## Bugfixes -* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)). ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java index eb7c6b2db1f8..97cb0c2b1a0f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextSource.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.io.ByteArrayOutputStream; @@ -24,6 +25,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.channels.SeekableByteChannel; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.NoSuchElementException; import org.apache.beam.sdk.coders.Coder; @@ -116,8 +118,13 @@ static class TextBasedReader extends FileBasedReader { private final byte @Nullable [] delimiter; private final int skipHeaderLines; - private final ByteArrayOutputStream str; + + // Used to build up results that span buffers. It may contain the delimiter as a suffix. + private final SubstringByteArrayOutputStream str; + + // Buffer for text read from the underlying file. 
private final byte[] buffer; + // A wrapper of the `buffer` field; private final ByteBuffer byteBuffer; private ReadableByteChannel inChannel; @@ -129,6 +136,9 @@ static class TextBasedReader extends FileBasedReader { private int bufferPosn = 0; // the current position in the buffer private boolean skipLineFeedAtStart; // skip an LF if at the start of the next buffer + // Finder for custom delimiter. + private @Nullable KMPDelimiterFinder delimiterFinder; + private TextBasedReader(TextSource source, byte[] delimiter) { this(source, delimiter, 0); } @@ -136,10 +146,14 @@ private TextBasedReader(TextSource source, byte[] delimiter) { private TextBasedReader(TextSource source, byte[] delimiter, int skipHeaderLines) { super(source); this.buffer = new byte[READ_BUFFER_SIZE]; - this.str = new ByteArrayOutputStream(); + this.str = new SubstringByteArrayOutputStream(); this.byteBuffer = ByteBuffer.wrap(buffer); this.delimiter = delimiter; this.skipHeaderLines = skipHeaderLines; + + if (delimiter != null) { + delimiterFinder = new KMPDelimiterFinder(delimiter); + } } @Override @@ -377,106 +391,60 @@ private boolean readDefaultLine() throws IOException { return true; } - /** - * Loosely based upon Hadoop - * LineReader.java - * - *

Note that this implementation fixes an issue where a partial match against the delimiter - * would have been lost if the delimiter crossed at the buffer boundaries during reading. - */ private boolean readCustomLine() throws IOException { - assert !eof; + checkState(!eof); + checkNotNull(delimiter); + checkNotNull( + delimiterFinder, "DelimiterFinder must not be null if custom delimiter is used."); long bytesConsumed = 0; - int delPosn = 0; - EOF: - for (; ; ) { - int startPosn = bufferPosn; // starting from where we left off the last time + delimiterFinder.reset(); - // Read the next chunk from the file, ensure that we read at least one byte - // or reach EOF. - while (bufferPosn >= bufferLength) { - startPosn = bufferPosn = 0; + while (true) { + if (bufferPosn >= bufferLength) { + bufferPosn = 0; byteBuffer.clear(); - bufferLength = inChannel.read(byteBuffer); - // If we are at EOF then try to create the last value from the buffer. + do { + bufferLength = inChannel.read(byteBuffer); + } while (bufferLength == 0); + if (bufferLength < 0) { eof = true; - // Write any partial delimiter now that we are at EOF - if (delPosn != 0) { - str.write(delimiter, 0, delPosn); - } - - // Don't return an empty record if the file ends with a delimiter if (str.size() == 0) { return false; } + // Not ending with a delimiter. 
currentValue = str.toString(StandardCharsets.UTF_8.name()); - break EOF; + break; } } - int prevDelPosn = delPosn; - DELIMITER_MATCH: - { - if (delPosn > 0) { - // slow-path: Handle the case where we only matched part of the delimiter, possibly - // adding that to str fixing up any partially consumed delimiter if we don't match the - // whole delimiter - for (; bufferPosn < bufferLength; ++bufferPosn) { - if (buffer[bufferPosn] == delimiter[delPosn]) { - delPosn++; - if (delPosn == delimiter.length) { - bufferPosn++; - break DELIMITER_MATCH; // Skip matching the delimiter using the fast path - } - } else { - // Add to str any previous partial delimiter since we didn't match the whole - // delimiter - str.write(delimiter, 0, prevDelPosn); - if (buffer[bufferPosn] == delimiter[0]) { - delPosn = 1; - } else { - delPosn = 0; - } - break; // Leave this loop and use the fast-path delimiter matching - } - } - } - - // fast-path: Look for the delimiter within the buffer - for (; bufferPosn < bufferLength; ++bufferPosn) { - if (buffer[bufferPosn] == delimiter[delPosn]) { - delPosn++; - if (delPosn == delimiter.length) { - bufferPosn++; - break; - } - } else if (buffer[bufferPosn] == delimiter[0]) { - delPosn = 1; - } else { - delPosn = 0; - } + int startPosn = bufferPosn; + boolean delimiterFound = false; + for (; bufferPosn < bufferLength; ++bufferPosn) { + if (delimiterFinder.feed(buffer[bufferPosn])) { + ++bufferPosn; + delimiterFound = true; + break; } } int readLength = bufferPosn - startPosn; bytesConsumed += readLength; - int appendLength = readLength - (delPosn - prevDelPosn); - if (delPosn < delimiter.length) { - // Append the prefix of the value to str skipping the partial delimiter - str.write(buffer, startPosn, appendLength); + if (!delimiterFound) { + str.write(buffer, startPosn, readLength); } else { if (str.size() == 0) { // Optimize for the common case where the string is wholly contained within the buffer - currentValue = new String(buffer, startPosn, 
appendLength, StandardCharsets.UTF_8); + currentValue = + new String( + buffer, startPosn, readLength - delimiter.length, StandardCharsets.UTF_8); } else { - str.write(buffer, startPosn, appendLength); - currentValue = str.toString(StandardCharsets.UTF_8.name()); + str.write(buffer, startPosn, readLength); + currentValue = str.toString(0, str.size() - delimiter.length, StandardCharsets.UTF_8); } break; } @@ -487,4 +455,110 @@ private boolean readCustomLine() throws IOException { return true; } } + + /** + * This class is created to avoid multiple bytes-copy when making a substring of the output. + * Without this class, it requires two bytes copies. + * + *

<pre>{@code
+   * ByteArrayOutputStream out = ...;
+   * byte[] buffer = out.toByteArray(); // 1st-copy
+   * String s = new String(buffer, offset, length); // 2nd-copy
+   * }</pre>
+ */ + static class SubstringByteArrayOutputStream extends ByteArrayOutputStream { + public String toString(int offset, int length, Charset charset) { + if (offset < 0) { + throw new IllegalArgumentException("offset is negative: " + offset); + } + if (offset > count) { + throw new IllegalArgumentException( + "offset exceeds the buffer limit. offset: " + offset + ", limit: " + count); + } + + if (length < 0) { + throw new IllegalArgumentException("length is negative: " + length); + } + + if (offset + length > count) { + throw new IllegalArgumentException( + "offset + length exceeds the buffer limit. offset: " + + offset + + ", length: " + + length + + ", limit: " + + count); + } + + return new String(buf, offset, length, charset); + } + } + + /** + * @see Knuth–Morris–Pratt + * algorithm + */ + static class KMPDelimiterFinder { + private final byte[] delimiter; + private final int[] table; + int delimiterOffset; // the current position in delimiter + + public KMPDelimiterFinder(byte[] delimiter) { + this.delimiter = delimiter; + this.table = new int[delimiter.length]; + compile(); + } + + public boolean feed(byte b) { + // Modified "Description of pseudocode for the search algorithm" in Wikipedia + while (true) { + if (b == delimiter[delimiterOffset]) { + ++delimiterOffset; + if (delimiterOffset == delimiter.length) { + // return when the first occurrence is found. 
+ delimiterOffset = 0; + return true; + } + return false; + } + + delimiterOffset = table[delimiterOffset]; + if (delimiterOffset < 0) { + ++delimiterOffset; + return false; + } + } + } + + public void reset() { + delimiterOffset = 0; + } + + private void compile() { + // the current position in table + int pos = 1; + // the zero-based index in delimiter of the next character of the current candidate substring + int cnd = 0; + + table[0] = -1; + + while (pos < delimiter.length) { + if (delimiter[pos] == delimiter[cnd]) { + table[pos] = table[cnd]; + } else { + table[pos] = cnd; + while (cnd >= 0 && delimiter[pos] != delimiter[cnd]) { + cnd = table[cnd]; + } + } + + ++pos; + ++cnd; + } + + // We don't need the table entry at (pos + 1) in "Description of pseudocode for the + // table-building algorithm" in Wikipedia because we only checks the first occurrence. + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java index 11e239e2711f..39d039f5bb9b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java @@ -58,6 +58,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.zip.GZIPOutputStream; import java.util.zip.ZipEntry; @@ -696,6 +697,73 @@ public void testReadStringsWithCustomDelimiter() throws Exception { p.run(); } + @Test + @Category(NeedsRunner.class) + public void testReadStringsWithCustomDelimiter_NestedDelimiter() throws IOException { + // Test for https://github.com/apache/beam/issues/32251 + String delimiter = "AABAAC"; + String text = "0AABAAC1AABAABAAC2AABAAABAAC"; + List expected = Arrays.asList("0", "1AAB", "2AABA"); + assertEquals( + expected, Arrays.asList(Pattern.compile(delimiter, Pattern.LITERAL).split(text))); + + File 
tmpFile = tempFolder.newFile("tmpfile.txt"); + String filename = tmpFile.getPath(); + + try (Writer writer = Files.newBufferedWriter(tmpFile.toPath(), UTF_8)) { + writer.write(text); + } + + PAssert.that( + p.apply( + TextIO.read() + .from(filename) + .withDelimiter(delimiter.getBytes(StandardCharsets.UTF_8)))) + .containsInAnyOrder(expected); + p.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testReadStringsWithCustomDelimiter_PartialDelimiterMatchedAtBoundary() + throws IOException { + // Test for https://github.com/apache/beam/issues/32249 + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 8190; ++i) { + sb.append('0'); + } + sb.append('A'); // index 8190 + sb.append('B'); // index 8191 + sb.append('C'); // index 8192 + for (int i = 8193; i < 16400; ++i) { + sb.append('0'); + } + + String text = sb.toString(); + String delimiter = "ABCDE"; + + // check the text is not split by the delimiter + assertEquals( + Collections.singletonList(text), + Arrays.asList(Pattern.compile(delimiter, Pattern.LITERAL).split(text))); + + File tmpFile = tempFolder.newFile("tmpfile.txt"); + String filename = tmpFile.getPath(); + + try (Writer writer = Files.newBufferedWriter(tmpFile.toPath(), UTF_8)) { + writer.write(text); + } + + // Expects no IndexOutOfBoundsException + PAssert.that( + p.apply( + TextIO.read() + .from(filename) + .withDelimiter(delimiter.getBytes(StandardCharsets.UTF_8)))) + .containsInAnyOrder(text); + p.run(); + } + @Test @Category(NeedsRunner.class) public void testReadStrings() throws Exception { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java new file mode 100644 index 000000000000..1cc22445fa7f --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextSourceTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TextSourceTest { + + @Test + public void testSubstringByteArrayOutputStreamSuccessful() throws IOException { + TextSource.SubstringByteArrayOutputStream output = + new TextSource.SubstringByteArrayOutputStream(); + assertEquals("", output.toString(0, 0, StandardCharsets.UTF_8)); + + output.write("ABC".getBytes(StandardCharsets.UTF_8)); + assertEquals("ABC", output.toString(0, 3, StandardCharsets.UTF_8)); + assertEquals("AB", output.toString(0, 2, StandardCharsets.UTF_8)); + assertEquals("BC", output.toString(1, 2, StandardCharsets.UTF_8)); + assertEquals("", output.toString(3, 0, StandardCharsets.UTF_8)); + + output.write("DE".getBytes(StandardCharsets.UTF_8)); + assertEquals("ABCDE", output.toString(0, 5, StandardCharsets.UTF_8)); + } + + @Test + public void 
testSubstringByteArrayIllegalArgumentException() throws IOException { + TextSource.SubstringByteArrayOutputStream output = + new TextSource.SubstringByteArrayOutputStream(); + output.write("ABC".getBytes(StandardCharsets.UTF_8)); + + IllegalArgumentException exception; + exception = + assertThrows( + IllegalArgumentException.class, () -> output.toString(-1, 2, StandardCharsets.UTF_8)); + assertThat(exception.getMessage(), containsString("offset is negative")); + + exception = + assertThrows( + IllegalArgumentException.class, () -> output.toString(4, 0, StandardCharsets.UTF_8)); + assertThat(exception.getMessage(), containsString("offset exceeds the buffer limit")); + + exception = + assertThrows( + IllegalArgumentException.class, () -> output.toString(0, -1, StandardCharsets.UTF_8)); + assertThat(exception.getMessage(), containsString("length is negative")); + + exception = + assertThrows( + IllegalArgumentException.class, () -> output.toString(2, 2, StandardCharsets.UTF_8)); + assertThat(exception.getMessage(), containsString("offset + length exceeds the buffer limit")); + } + + @Test + public void testDelimiterFinder() { + // Simple pattern + assertEquals(Arrays.asList("A", "C"), split("AB", "AABC")); + assertEquals(Arrays.asList("A", "B", "C"), split("AB", "AABBABC")); + + // When mismatched at 2 (zero-indexed), the substring "AA" has subsequence "A" + assertEquals(Arrays.asList("A", "C"), split("AAB", "AAABC")); + assertEquals(Arrays.asList("ABAB"), split("AAB", "ABAB")); + assertEquals(Arrays.asList("A", "A", "B"), split("AAB", "AAABAAABB")); + + // The last byte is the same as the first byte. + assertEquals(Arrays.asList("A", "B"), split("AABA", "AAABAB")); + assertEquals(Arrays.asList("AABBA"), split("AABA", "AABBA")); + + // "ABAB" has subsequence "AB". 
+ assertEquals(Arrays.asList("", "D"), split("ABABC", "ABABCD")); + assertEquals(Arrays.asList("ABABAD"), split("ABABC", "ABABAD")); + assertEquals(Arrays.asList("ABABBD"), split("ABABC", "ABABBD")); + assertEquals(Arrays.asList("AB"), split("ABABC", "ABABABC")); + + // "ABCAB" has subsequence "AB". + assertEquals(Arrays.asList(""), split("ABCABD", "ABCABD")); + assertEquals(Arrays.asList("ABC"), split("ABCABD", "ABCABCABD")); + + // Repetition of 3 bytes pattern. + assertEquals(Arrays.asList("AABAAB"), split("AABAAC", "AABAAB")); + assertEquals(Arrays.asList(""), split("AABAAC", "AABAAC")); + assertEquals(Arrays.asList("A", "D"), split("AABAAC", "AAABAACD")); + assertEquals(Arrays.asList("AAB", "D"), split("AABAAC", "AABAABAACD")); + assertEquals(Arrays.asList("AABA", "D"), split("AABAAC", "AABAAABAACD")); + assertEquals(Arrays.asList("AABAA", "D"), split("AABAAC", "AABAAAABAACD")); + + // Same characters repeated 3 times. + assertEquals(Arrays.asList("AAA"), split("AAAA", "AAA")); + assertEquals(Arrays.asList(""), split("AAAA", "AAAA")); + assertEquals(Arrays.asList("AAAB"), split("AAAA", "AAAB")); + assertEquals(Arrays.asList("", "B"), split("AAAA", "AAAAB")); + + // 3 times repeated pattern followed by a different byte. + assertEquals(Arrays.asList(""), split("AAAB", "AAAB")); + assertEquals(Arrays.asList("A", "B"), split("AAAB", "AAAABB")); + assertEquals(Arrays.asList("AA", "B"), split("AAAB", "AAAAABB")); + assertEquals(Arrays.asList("AAA", "B"), split("AAAB", "AAAAAABB")); + assertEquals(Arrays.asList("AAAA", "B"), split("AAAB", "AAAAAAABB")); + + // Multiple empty strings return. 
+ assertEquals(Arrays.asList("", "", ""), split("AA", "AAAAAA")); + assertEquals(Arrays.asList("", "", ""), split("AB", "ABABAB")); + assertEquals(Arrays.asList("", "", ""), split("AAB", "AABAABAAB")); + } + + List split(String delimiter, String text) { + byte[] delimiterBytes = delimiter.getBytes(StandardCharsets.UTF_8); + TextSource.KMPDelimiterFinder finder = new TextSource.KMPDelimiterFinder(delimiterBytes); + + byte[] textBytes = text.getBytes(StandardCharsets.UTF_8); + + List result = new ArrayList<>(); + int start = 0; + for (int i = 0; i < textBytes.length; i++) { + if (finder.feed(textBytes[i])) { + int nextStart = i + 1; + int end = nextStart - delimiterBytes.length; + String s = new String(textBytes, start, end - start, StandardCharsets.UTF_8); + result.add(s); + start = nextStart; + } + } + + if (start != textBytes.length) { + result.add(new String(textBytes, start, textBytes.length - start, StandardCharsets.UTF_8)); + } + return result; + } +}