Skip to content

Commit

Permalink
Don't re-encode byte[] values in SortValues transform (#31025)
Browse files Browse the repository at this point in the history
* Don't re-encode byte[] values in SortValues transform

* checkstyle

* Apply code review comments
  • Loading branch information
clairemcginty authored May 30, 2024
1 parent 4a0849b commit 9f3f1c9
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.io.IOException;
import java.util.Iterator;
import javax.annotation.Nonnull;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.DoFn;
Expand Down Expand Up @@ -131,6 +133,20 @@ private static <PrimaryKeyT, SecondaryKeyT, ValueT> Coder<ValueT> getValueCoder(
return getSecondaryKeyValueCoder(inputCoder).getValueCoder();
}

private static <T> T elementOf(Coder<T> coder, byte[] bytes) throws CoderException {
if (coder instanceof ByteArrayCoder) {
return (T) bytes;
}
return CoderUtils.decodeFromByteArray(coder, bytes);
}

private static <T> byte[] bytesOf(Coder<T> coder, T element) throws CoderException {
if (element instanceof byte[]) {
return (byte[]) element;
}
return CoderUtils.encodeToByteArray(coder, element);
}

private static class SortValuesDoFn<PrimaryKeyT, SecondaryKeyT, ValueT>
extends DoFn<
KV<PrimaryKeyT, Iterable<KV<SecondaryKeyT, ValueT>>>,
Expand All @@ -156,9 +172,7 @@ public void processElement(ProcessContext c) {
Sorter sorter = BufferedExternalSorter.create(sorterOptions);
for (KV<SecondaryKeyT, ValueT> record : records) {
sorter.add(
KV.of(
CoderUtils.encodeToByteArray(keyCoder, record.getKey()),
CoderUtils.encodeToByteArray(valueCoder, record.getValue())));
KV.of(bytesOf(keyCoder, record.getKey()), bytesOf(valueCoder, record.getValue())));
}

c.output(KV.of(c.element().getKey(), new DecodingIterable(sorter.sort())));
Expand Down Expand Up @@ -197,9 +211,9 @@ public boolean hasNext() {
public KV<SecondaryKeyT, ValueT> next() {
KV<byte[], byte[]> next = iterator.next();
try {
return KV.of(
CoderUtils.decodeFromByteArray(keyCoder, next.getKey()),
CoderUtils.decodeFromByteArray(valueCoder, next.getValue()));
SecondaryKeyT secondaryKey = elementOf(keyCoder, next.getKey());
ValueT value = elementOf(valueCoder, next.getValue());
return KV.of(secondaryKey, value);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.is;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
Expand Down Expand Up @@ -67,30 +70,141 @@ public void testSecondaryKeySorting() {
grouped.apply(SortValues.create(BufferedExternalSorter.options()));

PAssert.that(groupedAndSorted)
.satisfies(new AssertThatHasExpectedContentsForTestSecondaryKeySorting());
.satisfies(
new AssertThatHasExpectedContentsForTestSecondaryKeySorting<>(
Arrays.asList(
KV.of(
"key1",
Arrays.asList(
KV.of("secondaryKey1", 10),
KV.of("secondaryKey2", 20),
KV.of("secondaryKey3", 30))),
KV.of(
"key2",
Arrays.asList(KV.of("secondaryKey1", 100), KV.of("secondaryKey2", 200))))));

p.run();
}

@Test
public void testSecondaryKeyByteOptimization() {
PCollection<KV<String, KV<byte[], Integer>>> input =
p.apply(
Create.of(
Arrays.asList(
KV.of("key1", KV.of("secondaryKey2".getBytes(StandardCharsets.UTF_8), 20)),
KV.of("key2", KV.of("secondaryKey2".getBytes(StandardCharsets.UTF_8), 200)),
KV.of("key1", KV.of("secondaryKey3".getBytes(StandardCharsets.UTF_8), 30)),
KV.of("key1", KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), 10)),
KV.of("key2", KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), 100)))));

// Group by Key, bringing <SecondaryKey, Value> pairs for the same Key together.
PCollection<KV<String, Iterable<KV<byte[], Integer>>>> grouped =
input.apply(GroupByKey.create());

// For every Key, sort the iterable of <SecondaryKey, Value> pairs by SecondaryKey.
PCollection<KV<String, Iterable<KV<byte[], Integer>>>> groupedAndSorted =
grouped.apply(SortValues.create(BufferedExternalSorter.options()));

PAssert.that(groupedAndSorted)
.satisfies(
new AssertThatHasExpectedContentsForTestSecondaryKeySorting<>(
Arrays.asList(
KV.of(
"key1",
Arrays.asList(
KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), 10),
KV.of("secondaryKey2".getBytes(StandardCharsets.UTF_8), 20),
KV.of("secondaryKey3".getBytes(StandardCharsets.UTF_8), 30))),
KV.of(
"key2",
Arrays.asList(
KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), 100),
KV.of("secondaryKey2".getBytes(StandardCharsets.UTF_8), 200))))));

p.run();
}

@Test
public void testSecondaryKeyAndValueByteOptimization() {
PCollection<KV<String, KV<byte[], byte[]>>> input =
p.apply(
Create.of(
Arrays.asList(
KV.of(
"key1",
KV.of("secondaryKey2".getBytes(StandardCharsets.UTF_8), new byte[] {1})),
KV.of(
"key2",
KV.of("secondaryKey2".getBytes(StandardCharsets.UTF_8), new byte[] {2})),
KV.of(
"key1",
KV.of("secondaryKey3".getBytes(StandardCharsets.UTF_8), new byte[] {3})),
KV.of(
"key1",
KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), new byte[] {4})),
KV.of(
"key2",
KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), new byte[] {5})))));

// Group by Key, bringing <SecondaryKey, Value> pairs for the same Key together.
PCollection<KV<String, Iterable<KV<byte[], byte[]>>>> grouped =
input.apply(GroupByKey.create());

// For every Key, sort the iterable of <SecondaryKey, Value> pairs by SecondaryKey.
PCollection<KV<String, Iterable<KV<byte[], byte[]>>>> groupedAndSorted =
grouped.apply(SortValues.create(BufferedExternalSorter.options()));

PAssert.that(groupedAndSorted)
.satisfies(
new AssertThatHasExpectedContentsForTestSecondaryKeySorting<>(
Arrays.asList(
KV.of(
"key1",
Arrays.asList(
KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), new byte[] {4}),
KV.of("secondaryKey2".getBytes(StandardCharsets.UTF_8), new byte[] {1}),
KV.of(
"secondaryKey3".getBytes(StandardCharsets.UTF_8), new byte[] {3}))),
KV.of(
"key2",
Arrays.asList(
KV.of("secondaryKey1".getBytes(StandardCharsets.UTF_8), new byte[] {5}),
KV.of(
"secondaryKey2".getBytes(StandardCharsets.UTF_8),
new byte[] {2}))))));

p.run();
}

static class AssertThatHasExpectedContentsForTestSecondaryKeySorting
implements SerializableFunction<Iterable<KV<String, Iterable<KV<String, Integer>>>>, Void> {
static class AssertThatHasExpectedContentsForTestSecondaryKeySorting<SecondaryKeyT, ValueT>
implements SerializableFunction<
Iterable<KV<String, Iterable<KV<SecondaryKeyT, ValueT>>>>, Void> {
final List<KV<String, List<KV<SecondaryKeyT, ValueT>>>> expected;

AssertThatHasExpectedContentsForTestSecondaryKeySorting(
List<KV<String, List<KV<SecondaryKeyT, ValueT>>>> expected) {
this.expected = expected;
}

@SuppressWarnings("unchecked")
@Override
public Void apply(Iterable<KV<String, Iterable<KV<String, Integer>>>> actual) {
public Void apply(Iterable<KV<String, Iterable<KV<SecondaryKeyT, ValueT>>>> actual) {
assertThat(
actual,
containsInAnyOrder(
KvMatcher.isKv(
is("key1"),
contains(
KvMatcher.isKv(is("secondaryKey1"), is(10)),
KvMatcher.isKv(is("secondaryKey2"), is(20)),
KvMatcher.isKv(is("secondaryKey3"), is(30)))),
KvMatcher.isKv(
is("key2"),
contains(
KvMatcher.isKv(is("secondaryKey1"), is(100)),
KvMatcher.isKv(is("secondaryKey2"), is(200))))));
expected.stream()
.map(
kv1 ->
KvMatcher.isKv(
is(kv1.getKey()),
contains(
kv1.getValue().stream()
.map(
kv2 ->
KvMatcher.isKv(is(kv2.getKey()), is(kv2.getValue())))
.collect(Collectors.toList()))))
.collect(Collectors.toList())));
return null;
}
}
Expand Down

0 comments on commit 9f3f1c9

Please sign in to comment.