Skip to content

Commit

Permalink
Fix RowCoderGenerator to use the encodingPositions when encoding and …
Browse files Browse the repository at this point in the history
…decoding the bit set representing null fields.
  • Loading branch information
scwhittle committed Sep 11, 2024
1 parent ebcb2db commit a52564a
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;

/** A sub-class of SchemaCoder that can only encode {@link Row} instances. */
Expand All @@ -35,7 +36,12 @@ public static RowCoder of(Schema schema) {

/** Override encoding positions for the given schema. */
public static void overrideEncodingPositions(UUID uuid, Map<String, Integer> encodingPositions) {
SchemaCoder.overrideEncodingPositions(uuid, encodingPositions);
RowCoderGenerator.overrideEncodingPositions(uuid, encodingPositions);
}

@VisibleForTesting
static void clearGeneratedRowCoders() {
RowCoderGenerator.clearRowCoderCache();
}

private RowCoder(Schema schema) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Map;
import java.util.UUID;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import net.bytebuddy.ByteBuddy;
import net.bytebuddy.description.modifier.FieldManifestation;
import net.bytebuddy.description.modifier.Ownership;
Expand All @@ -53,10 +54,14 @@
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.util.StringUtils;
import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A utility for automatically generating a {@link Coder} for {@link Row} objects corresponding to a
Expand Down Expand Up @@ -109,30 +114,113 @@ public abstract class RowCoderGenerator {
private static final String CODERS_FIELD_NAME = "FIELD_CODERS";
private static final String POSITIONS_FIELD_NAME = "FIELD_ENCODING_POSITIONS";

static class WithStackTrace<T> {
private final T value;
private final String stackTrace;

public WithStackTrace(T value, String stackTrace) {
this.value = value;
this.stackTrace = stackTrace;
}

public T getValue() {
return value;
}

public String getStackTrace() {
return stackTrace;
}
}

// Cache for Coder class that are already generated.
private static final Map<UUID, Coder<Row>> GENERATED_CODERS = Maps.newConcurrentMap();
private static final Map<UUID, Map<String, Integer>> ENCODING_POSITION_OVERRIDES =
Maps.newConcurrentMap();
@GuardedBy("cacheLock")
private static final Map<UUID, WithStackTrace<Coder<Row>>> GENERATED_CODERS = Maps.newHashMap();

@GuardedBy("cacheLock")
private static final Map<UUID, WithStackTrace<Map<String, Integer>>> ENCODING_POSITION_OVERRIDES =
Maps.newHashMap();

private static final Object cacheLock = new Object();

private static final Logger LOG = LoggerFactory.getLogger(RowCoderGenerator.class);

private static String getStackTrace() {
return StringUtils.arrayToNewlines(Thread.currentThread().getStackTrace(), 10);
}

public static void overrideEncodingPositions(UUID uuid, Map<String, Integer> encodingPositions) {
ENCODING_POSITION_OVERRIDES.put(uuid, encodingPositions);
final String stackTrace = getStackTrace();
synchronized (cacheLock) {
@Nullable
WithStackTrace<Map<String, Integer>> previousEncodingPositions =
ENCODING_POSITION_OVERRIDES.put(
uuid, new WithStackTrace<>(encodingPositions, stackTrace));
@Nullable WithStackTrace<Coder<Row>> existingCoder = GENERATED_CODERS.get(uuid);
if (previousEncodingPositions == null) {
if (existingCoder != null) {
LOG.error(
"Received encoding positions for uuid {} too late after creating RowCoder. Created: {}\n Override: {}",
uuid,
existingCoder.getStackTrace(),
stackTrace);
} else {
LOG.info("Received encoding positions {} for uuid {}.", encodingPositions, uuid);
}
} else if (!previousEncodingPositions.getValue().equals(encodingPositions)) {
if (existingCoder == null) {
LOG.error(
"Received differing encoding positions for uuid {} before coder creation. Was {} at {}\n Now {} at {}",
uuid,
previousEncodingPositions.getValue(),
encodingPositions,
previousEncodingPositions.getStackTrace(),
stackTrace);
} else {
LOG.error(
"Received differing encoding positions for uuid {} after coder creation at {}\n. "
+ "Was {} at {}\n Now {} at {}\n",
uuid,
existingCoder.getStackTrace(),
previousEncodingPositions.getValue(),
encodingPositions,
previousEncodingPositions.getStackTrace(),
stackTrace);
}
}
}
}

@VisibleForTesting
static void clearRowCoderCache() {
synchronized (cacheLock) {
GENERATED_CODERS.clear();
}
}

@SuppressWarnings("unchecked")
public static Coder<Row> generate(Schema schema) {
// Using ConcurrentHashMap::computeIfAbsent here would deadlock in case of nested
// coders. Using HashMap::computeIfAbsent generates ConcurrentModificationExceptions in Java 11.
Coder<Row> rowCoder = GENERATED_CODERS.get(schema.getUUID());
if (rowCoder == null) {
String stackTrace = getStackTrace();
UUID uuid = Preconditions.checkNotNull(schema.getUUID());
// Avoid using computeIfAbsent which may cause issues with nested schemas.
synchronized (cacheLock) {
@Nullable WithStackTrace<Coder<Row>> existingRowCoder = GENERATED_CODERS.get(uuid);
if (existingRowCoder != null) {
return existingRowCoder.getValue();
}
TypeDescription.Generic coderType =
TypeDescription.Generic.Builder.parameterizedType(Coder.class, Row.class).build();
DynamicType.Builder<Coder> builder =
(DynamicType.Builder<Coder>) BYTE_BUDDY.subclass(coderType);
builder = implementMethods(schema, builder);

int[] encodingPosToRowIndex = new int[schema.getFieldCount()];
@Nullable
WithStackTrace<Map<String, Integer>> existingEncodingPositions =
ENCODING_POSITION_OVERRIDES.get(uuid);
Map<String, Integer> encodingPositions =
ENCODING_POSITION_OVERRIDES.getOrDefault(schema.getUUID(), schema.getEncodingPositions());
existingEncodingPositions == null
? schema.getEncodingPositions()
: existingEncodingPositions.getValue();
for (int recordIndex = 0; recordIndex < schema.getFieldCount(); ++recordIndex) {
String name = schema.getField(recordIndex).getName();
int encodingPosition = encodingPositions.get(name);
Expand Down Expand Up @@ -163,6 +251,7 @@ public static Coder<Row> generate(Schema schema) {
.withParameters(Coder[].class, int[].class)
.intercept(new GeneratedCoderConstructor());

Coder<Row> rowCoder;
try {
rowCoder =
builder
Expand All @@ -179,9 +268,14 @@ public static Coder<Row> generate(Schema schema) {
| InvocationTargetException e) {
throw new RuntimeException("Unable to generate coder for schema " + schema, e);
}
GENERATED_CODERS.put(schema.getUUID(), rowCoder);
GENERATED_CODERS.put(uuid, new WithStackTrace<>(rowCoder, stackTrace));
LOG.debug(
"Created row coder for uuid {} with encoding positions {} at {}",
uuid,
encodingPositions,
stackTrace);
return rowCoder;
}
return rowCoder;
}

private static class GeneratedCoderConstructor implements Implementation {
Expand Down Expand Up @@ -326,7 +420,7 @@ static void encodeDelegate(
}

// Encode a bitmap for the null fields to save having to encode a bunch of nulls.
NULL_LIST_CODER.encode(scanNullFields(fieldValues), outputStream);
NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream);
for (int encodingPos = 0; encodingPos < fieldValues.length; ++encodingPos) {
@Nullable Object fieldValue = fieldValues[encodingPosToIndex[encodingPos]];
if (fieldValue != null) {
Expand All @@ -348,14 +442,15 @@ static void encodeDelegate(

// Figure out which fields of the Row are null, and returns a BitSet. This allows us to save
// on encoding each null field separately.
private static BitSet scanNullFields(Object[] fieldValues) {
private static BitSet scanNullFields(Object[] fieldValues, int[] encodingPosToIndex) {
Preconditions.checkState(fieldValues.length == encodingPosToIndex.length);
BitSet nullFields = new BitSet(fieldValues.length);
for (int idx = 0; idx < fieldValues.length; ++idx) {
if (fieldValues[idx] == null) {
nullFields.set(idx);
for (int encodingPos = 0; encodingPos < encodingPosToIndex.length; ++encodingPos) {
int fieldIndex = encodingPosToIndex[encodingPos];
if (fieldValues[fieldIndex] == null) {
nullFields.set(encodingPos);
}
}

return nullFields;
}
}
Expand Down Expand Up @@ -425,7 +520,7 @@ static Row decodeDelegate(
// in which case we drop the extra fields.
if (encodingPos < coders.length) {
int rowIndex = encodingPosToIndex[encodingPos];
if (nullFields.get(rowIndex)) {
if (nullFields.get(encodingPos)) {
fieldValues[rowIndex] = null;
} else {
Object fieldValue = coders[encodingPos].decode(inputStream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ public String toString() {
}

// Sets the schema id, and then recursively ensures that all schemas have ids set.
private static void setSchemaIds(Schema schema) {
private static void setSchemaIds(@Nullable Schema schema) {
if (schema == null) {
return;
}
if (schema.getUUID() == null) {
schema.setUUID(UUID.randomUUID());
}
Expand All @@ -187,7 +190,7 @@ private static void setSchemaIds(FieldType fieldType) {
return;

case ARRAY:
case ITERABLE:;
case ITERABLE:
setSchemaIds(fieldType.getCollectionElementType());
return;

Expand Down
Loading

0 comments on commit a52564a

Please sign in to comment.