From 873e0d0978d93ccf628de7889974df2f6824203d Mon Sep 17 00:00:00 2001 From: Alexey Romanenko Date: Fri, 18 Nov 2022 18:27:52 +0100 Subject: [PATCH] [#24292] Create Avro extension for Java SDK --- .../apache/beam/sdk/io/BlockBasedSource.java | 4 +- .../beam/sdk/io/DefaultFilenamePolicy.java | 2 +- .../sdk/io/ReadAllViaFileBasedSource.java | 2 +- sdks/java/extensions/avro/build.gradle | 56 + .../sdk/extensions/avro/coders/AvroCoder.java | 820 +++++++ .../avro/coders/AvroGenericCoder.java | 32 + .../extensions/avro/coders/package-info.java | 29 + .../beam/sdk/extensions/avro/io/AvroIO.java | 2026 +++++++++++++++++ .../avro/io/AvroSchemaIOProvider.java | 150 ++ .../beam/sdk/extensions/avro/io/AvroSink.java | 151 ++ .../sdk/extensions/avro/io/AvroSource.java | 777 +++++++ .../avro/io/ConstantAvroDestination.java | 148 ++ .../avro/io/DynamicAvroDestinations.java | 56 + .../avro/io/SerializableAvroCodecFactory.java | 112 + .../sdk/extensions/avro/io/package-info.java | 26 + .../avro/schemas/AvroRecordSchema.java | 64 + .../AvroPayloadSerializerProvider.java | 44 + .../schemas/io/payloads/package-info.java | 27 + .../extensions/avro/schemas/package-info.java | 29 + .../schemas/utils/AvroByteBuddyUtils.java | 142 ++ .../avro/schemas/utils/AvroUtils.java | 1341 +++++++++++ .../avro/schemas/utils/package-info.java | 26 + .../beam/sdk/extensions/avro/io/user.avsc | 10 + .../sdk/extensions/avro/schemas/test.avsc | 30 + .../extensions/avro/coders/AvroCoderTest.java | 1108 +++++++++ .../avro/coders/AvroCoderTestPojo.java | 51 + .../sdk/extensions/avro/io/AvroIOTest.java | 1587 +++++++++++++ .../avro/io/AvroSchemaIOProviderTest.java | 174 ++ .../extensions/avro/io/AvroSourceTest.java | 846 +++++++ .../io/SerializableAvroCodecFactoryTest.java | 93 + .../avro/schemas/AvroSchemaTest.java | 508 +++++ .../io/AvroPayloadSerializerProviderTest.java | 64 + .../avro/schemas/utils/AvroGenerators.java | 220 ++ .../avro/schemas/utils/AvroUtilsTest.java | 895 ++++++++ settings.gradle.kts | 1 + 35 files changed, 11647 insertions(+), 4 deletions(-) create mode 100644 sdks/java/extensions/avro/build.gradle create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoder.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroGenericCoder.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/package-info.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProvider.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSink.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/ConstantAvroDestination.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/DynamicAvroDestinations.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactory.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/package-info.java create mode 100644 
sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/AvroPayloadSerializerProvider.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/package-info.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/package-info.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java create mode 100644 sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/package-info.java create mode 100644 sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/io/user.avsc create mode 100644 sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/schemas/test.avsc create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTest.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTestPojo.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroIOTest.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProviderTest.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSourceTest.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactoryTest.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/AvroSchemaTest.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/io/AvroPayloadSerializerProviderTest.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroGenerators.java create mode 100644 sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtilsTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java index 1f8501b571e0..e2c626228034 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/BlockBasedSource.java @@ -140,7 +140,7 @@ protected abstract static class Block { * byte of the block is within the range {@code [start, end)}. */ @Experimental(Kind.SOURCE_SINK) - protected abstract static class BlockBasedReader extends FileBasedReader { + public abstract static class BlockBasedReader extends FileBasedReader { private boolean atSplitPoint; protected BlockBasedReader(BlockBasedSource source) { @@ -195,7 +195,7 @@ public final T getCurrent() throws NoSuchElementException { * block boundaries. 
*/ @Override - protected boolean isAtSplitPoint() { + public boolean isAtSplitPoint() { return atSplitPoint; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java index 7556c32d2a65..5803f450aeaa 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/DefaultFilenamePolicy.java @@ -269,7 +269,7 @@ public static DefaultFilenamePolicy fromParams(Params params) { * ".txt", with shardNum = 1 and numShards = 100, the following is produced: * "path/to/output-001-of-100.txt". */ - static ResourceId constructName( + public static ResourceId constructName( ResourceId baseFilename, String shardTemplate, String suffix, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSource.java index 82eca9193fbf..35819b60ebf9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSource.java @@ -47,7 +47,7 @@ public class ReadAllViaFileBasedSource extends PTransform, PCollection> { - protected static final boolean DEFAULT_USES_RESHUFFLE = true; + public static final boolean DEFAULT_USES_RESHUFFLE = true; private final long desiredBundleSizeBytes; private final SerializableFunction> createSource; private final Coder coder; diff --git a/sdks/java/extensions/avro/build.gradle b/sdks/java/extensions/avro/build.gradle new file mode 100644 index 000000000000..dae13cd99728 --- /dev/null +++ b/sdks/java/extensions/avro/build.gradle @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +plugins { id 'org.apache.beam.module' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.extensions.avro', + disableLintWarnings: ['rawtypes'], // Avro-generated test code has raw-type errors + publish: false, + exportJavadoc: false, +) +applyAvroNature() + +description = "Apache Beam :: SDKs :: Java :: Extensions :: Avro" + +// Exclude tests that need a runner +test { + systemProperty "beamUseDummyRunner", "true" + useJUnit { + excludeCategories "org.apache.beam.sdk.testing.NeedsRunner" + } +} + +dependencies { + implementation library.java.byte_buddy + implementation library.java.vendored_guava_26_0_jre + implementation (project(path: ":sdks:java:core", configuration: "shadow")) { + // Exclude Avro dependencies from "core" since Avro support moved to this extension + exclude group: "org.apache.avro", module: "avro" + } + implementation library.java.error_prone_annotations + implementation library.java.avro + implementation library.java.joda_time + testImplementation (project(path: ":sdks:java:core", configuration: "shadowTest")) { + // Exclude Avro dependencies from "core" since Avro support moved to this extension + exclude group: "org.apache.avro", module: "avro" + } + testImplementation library.java.avro_tests + testImplementation library.java.junit + testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") + testRuntimeOnly library.java.slf4j_jdk14 +} \ No newline at end of file diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoder.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoder.java new file mode 100644 index 000000000000..4687eb566424 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoder.java @@ -0,0 +1,820 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.coders; + +import com.google.errorprone.annotations.FormatMethod; +import com.google.errorprone.annotations.FormatString; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import org.apache.avro.AvroRuntimeException; +import org.apache.avro.Conversion; +import org.apache.avro.LogicalType; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.IndexedRecord; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.BinaryEncoder; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.reflect.AvroEncode; +import org.apache.avro.reflect.AvroName; +import org.apache.avro.reflect.AvroSchema; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.avro.reflect.ReflectDatumWriter; +import org.apache.avro.reflect.Union; +import org.apache.avro.specific.SpecificData; +import org.apache.avro.specific.SpecificDatumReader; +import org.apache.avro.specific.SpecificDatumWriter; +import org.apache.avro.specific.SpecificRecord; +import org.apache.avro.util.ClassUtils; +import org.apache.avro.util.Utf8; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderProvider; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.DefaultCoder; +import org.apache.beam.sdk.util.EmptyOnDeserializationThreadLocal; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; + +/** + * A {@link Coder} using Avro binary format. + * + *

Each instance of {@code AvroCoder} encapsulates an Avro schema for objects of type {@code + * T}. + * + *

The Avro schema may be provided explicitly via {@link AvroCoder#of(Class, Schema)} or omitted + * via {@link AvroCoder#of(Class)}, in which case it will be inferred using Avro's {@link + * ReflectData}. + * + *

For complete details about schema generation and how it can be controlled, please see the + * {@link org.apache.avro.reflect} package. Only concrete classes with a no-argument constructor can + * be mapped to Avro records. All inherited fields that are not static or transient are included. + * Fields are not permitted to be null unless annotated by {@link Nullable} or a {@link Union} + * schema containing {@code "null"}. + * + *

To use, specify the {@code Coder} type on a PCollection: + * + *

{@code
+ * PCollection records =
+ *     input.apply(...)
+ *          .setCoder(AvroCoder.of(MyCustomElement.class));
+ * }
+ * + *

or annotate the element class using {@code @DefaultCoder}. + * + *

{@code @DefaultCoder(AvroCoder.class)
+ * public class MyCustomElement {
+ *     ...
+ * }
+ * }
+ * + *
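 + * A minimal sketch, for illustration only, of exercising such a coder directly; {@code CoderUtils}
 + * here refers to the core SDK utility class, and {@code MyCustomElement}/{@code element} are the
 + * hypothetical type and instance from the snippets above:
 + *
 + * {@code
 + * // Hypothetical element type and instance; CoderUtils comes from org.apache.beam.sdk.util.
 + * AvroCoder<MyCustomElement> coder = AvroCoder.of(MyCustomElement.class);
 + * byte[] encoded = CoderUtils.encodeToByteArray(coder, element);
 + * MyCustomElement decoded = CoderUtils.decodeFromByteArray(coder, encoded);
 + * }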

The implementation attempts to determine if the Avro encoding of the given type will satisfy + * the criteria of {@link Coder#verifyDeterministic} by inspecting both the type and the Schema + * provided or generated by Avro. Only coders that are deterministic can be used in {@link + * org.apache.beam.sdk.transforms.GroupByKey} operations. + * + * @param the type of elements handled by this coder + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class AvroCoder extends CustomCoder { + + /** + * Returns an {@code AvroCoder} instance for the provided element type. + * + * @param the element type + */ + public static AvroCoder of(TypeDescriptor type) { + return of(type, true); + } + + /** + * Returns an {@code AvroCoder} instance for the provided element type, respecting whether to use + * Avro's Reflect* or Specific* suite for encoding and decoding. + * + * @param the element type + */ + public static AvroCoder of(TypeDescriptor type, boolean useReflectApi) { + @SuppressWarnings("unchecked") + Class clazz = (Class) type.getRawType(); + return of(clazz, useReflectApi); + } + + /** + * Returns an {@code AvroCoder} instance for the provided element class. + * + * @param the element type + */ + public static AvroCoder of(Class clazz) { + return of(clazz, true); + } + + /** + * Returns an {@code AvroGenericCoder} instance for the Avro schema. The implicit type is + * GenericRecord. + */ + public static AvroGenericCoder of(Schema schema) { + return AvroGenericCoder.of(schema); + } + + /** + * Returns an {@code AvroCoder} instance for the given class, respecting whether to use Avro's + * Reflect* or Specific* suite for encoding and decoding. + * + * @param the element type + */ + public static AvroCoder of(Class type, boolean useReflectApi) { + ClassLoader cl = type.getClassLoader(); + SpecificData data = useReflectApi ? new ReflectData(cl) : new SpecificData(cl); + return of(type, data.getSchema(type), useReflectApi); + } + + /** + * Returns an {@code AvroCoder} instance for the provided element type using the provided Avro + * schema. + * + *

The schema must correspond to the type provided. + * + * @param the element type + */ + public static AvroCoder of(Class type, Schema schema) { + return of(type, schema, true); + } + + /** + * Returns an {@code AvroCoder} instance for the given class and schema, respecting whether to use + * Avro's Reflect* or Specific* suite for encoding and decoding. + * + * @param the element type + */ + public static AvroCoder of(Class type, Schema schema, boolean useReflectApi) { + return new AvroCoder<>(type, schema, useReflectApi); + } + + /** + * Returns a {@link CoderProvider} which uses the {@link AvroCoder} if possible for all types. + * + *

It is unsafe to register this as a {@link CoderProvider} because Avro will reflectively + * accept dangerous types such as {@link Object}. + * + *

This method is invoked reflectively from {@link DefaultCoder}. + */ + @SuppressWarnings("unused") + public static CoderProvider getCoderProvider() { + return new AvroCoderProvider(); + } + + /** + * A {@link CoderProvider} that constructs an {@link AvroCoder} for Avro compatible classes. + * + *

It is unsafe to register this as a {@link CoderProvider} because Avro will reflectively + * accept dangerous types such as {@link Object}. + */ + static class AvroCoderProvider extends CoderProvider { + @Override + public Coder coderFor( + TypeDescriptor typeDescriptor, List> componentCoders) + throws CannotProvideCoderException { + try { + return AvroCoder.of(typeDescriptor); + } catch (AvroRuntimeException e) { + throw new CannotProvideCoderException( + String.format("%s is not compatible with Avro", typeDescriptor), e); + } + } + } + + private final Class type; + private final boolean useReflectApi; + private final SerializableSchemaSupplier schemaSupplier; + private final TypeDescriptor typeDescriptor; + + private final List nonDeterministicReasons; + + // Factories allocated by .get() are thread-safe and immutable. + private static final EncoderFactory ENCODER_FACTORY = EncoderFactory.get(); + private static final DecoderFactory DECODER_FACTORY = DecoderFactory.get(); + + /** + * A {@link Serializable} object that holds the {@link String} version of a {@link Schema}. This + * is paired with the {@link SerializableSchemaSupplier} via {@link Serializable}'s usage of the + * {@link #readResolve} method. + */ + private static class SerializableSchemaString implements Serializable { + private final String schema; + + private SerializableSchemaString(String schema) { + this.schema = schema; + } + + private Object readResolve() throws IOException, ClassNotFoundException { + return new SerializableSchemaSupplier(new Schema.Parser().parse(schema)); + } + } + + /** + * A {@link Serializable} object that delegates to the {@link SerializableSchemaString} via {@link + * Serializable}'s usage of the {@link #writeReplace} method. Kryo doesn't utilize Java's + * serialization and hence is able to encode the {@link Schema} object directly. + */ + private static class SerializableSchemaSupplier implements Serializable, Supplier { + // writeReplace makes this object serializable. This is a limitation of FindBugs as discussed + // here: + // http://stackoverflow.com/questions/26156523/is-writeobject-not-neccesary-using-the-serialization-proxy-pattern + @SuppressFBWarnings("SE_BAD_FIELD") + private final Schema schema; + + private SerializableSchemaSupplier(Schema schema) { + this.schema = schema; + } + + private Object writeReplace() { + return new SerializableSchemaString(schema.toString()); + } + + @Override + public Schema get() { + return schema; + } + } + + /** + * A {@link Serializable} object that lazily supplies a {@link ReflectData} built from the + * appropriate {@link ClassLoader} for the type encoded by this {@link AvroCoder}. + */ + private static class SerializableReflectDataSupplier + implements Serializable, Supplier { + + private final Class clazz; + + private SerializableReflectDataSupplier(Class clazz) { + this.clazz = clazz; + } + + @Override + public ReflectData get() { + ReflectData reflectData = new ReflectData(clazz.getClassLoader()); + reflectData.addLogicalTypeConversion(new JodaTimestampConversion()); + return reflectData; + } + } + + // Cache the old encoder/decoder and let the factories reuse them when possible. To be threadsafe, + // these are ThreadLocal. This code does not need to be re-entrant as AvroCoder does not use + // an inner coder. 
+ private final EmptyOnDeserializationThreadLocal decoder; + private final EmptyOnDeserializationThreadLocal encoder; + private final EmptyOnDeserializationThreadLocal> writer; + private final EmptyOnDeserializationThreadLocal> reader; + + // Lazily re-instantiated after deserialization + private final Supplier reflectData; + + protected AvroCoder(Class type, Schema schema) { + this(type, schema, false); + } + + protected AvroCoder(Class type, Schema schema, boolean useReflectApi) { + this.type = type; + this.useReflectApi = useReflectApi; + this.schemaSupplier = new SerializableSchemaSupplier(schema); + typeDescriptor = TypeDescriptor.of(type); + nonDeterministicReasons = new AvroDeterminismChecker().check(TypeDescriptor.of(type), schema); + + // Decoder and Encoder start off null for each thread. They are allocated and potentially + // reused inside encode/decode. + this.decoder = new EmptyOnDeserializationThreadLocal<>(); + this.encoder = new EmptyOnDeserializationThreadLocal<>(); + + this.reflectData = Suppliers.memoize(new SerializableReflectDataSupplier(getType())); + + // Reader and writer are allocated once per thread per Coder + this.reader = + new EmptyOnDeserializationThreadLocal>() { + private final AvroCoder myCoder = AvroCoder.this; + + @Override + public DatumReader initialValue() { + if (myCoder.getType().equals(GenericRecord.class)) { + return new GenericDatumReader<>(myCoder.getSchema()); + } else if (SpecificRecord.class.isAssignableFrom(myCoder.getType()) && !useReflectApi) { + return new SpecificDatumReader<>(myCoder.getType()); + } + return new ReflectDatumReader<>( + myCoder.getSchema(), myCoder.getSchema(), myCoder.reflectData.get()); + } + }; + + this.writer = + new EmptyOnDeserializationThreadLocal>() { + private final AvroCoder myCoder = AvroCoder.this; + + @Override + public DatumWriter initialValue() { + if (myCoder.getType().equals(GenericRecord.class)) { + return new GenericDatumWriter<>(myCoder.getSchema()); + } else if (SpecificRecord.class.isAssignableFrom(myCoder.getType()) && !useReflectApi) { + return new SpecificDatumWriter<>(myCoder.getType()); + } + return new ReflectDatumWriter<>(myCoder.getSchema(), myCoder.reflectData.get()); + } + }; + } + + /** Returns the type this coder encodes/decodes. */ + public Class getType() { + return type; + } + + public boolean useReflectApi() { + return useReflectApi; + } + + @Override + public void encode(T value, OutputStream outStream) throws IOException { + // Get a BinaryEncoder instance from the ThreadLocal cache and attempt to reuse it. + BinaryEncoder encoderInstance = ENCODER_FACTORY.directBinaryEncoder(outStream, encoder.get()); + // Save the potentially-new instance for reuse later. + encoder.set(encoderInstance); + writer.get().write(value, encoderInstance); + // Direct binary encoder does not buffer any data and need not be flushed. + } + + @Override + public T decode(InputStream inStream) throws IOException { + // Get a BinaryDecoder instance from the ThreadLocal cache and attempt to reuse it. + BinaryDecoder decoderInstance = DECODER_FACTORY.directBinaryDecoder(inStream, decoder.get()); + // Save the potentially-new instance for later. + decoder.set(decoderInstance); + return reader.get().read(null, decoderInstance); + } + + /** + * @throws NonDeterministicException when the type may not be deterministically encoded using the + * given {@link Schema}, the {@code directBinaryEncoder}, and the {@link ReflectDatumWriter} + * or {@link GenericDatumWriter}. 
+ */ + @Override + public void verifyDeterministic() throws NonDeterministicException { + if (!nonDeterministicReasons.isEmpty()) { + throw new NonDeterministicException(this, nonDeterministicReasons); + } + } + + /** Returns the schema used by this coder. */ + public Schema getSchema() { + return schemaSupplier.get(); + } + + @Override + public TypeDescriptor getEncodedTypeDescriptor() { + return typeDescriptor; + } + + /** + * Helper class encapsulating the various pieces of state maintained by the recursive walk used + * for checking if the encoding will be deterministic. + */ + private static class AvroDeterminismChecker { + + // Reasons that the original type are not deterministic. This accumulates + // the actual output. + private List reasons = new ArrayList<>(); + + // Types that are currently "open". Used to make sure we don't have any + // recursive types. Note that we assume that all occurrences of a given type + // are equal, rather than tracking pairs of type + schema. + private Set> activeTypes = new HashSet<>(); + + // Similarly to how we record active types, we record the schemas we visit + // to make sure we don't encounter recursive fields. + private Set activeSchemas = new HashSet<>(); + + /** Report an error in the current context. */ + @FormatMethod + private void reportError(String context, @FormatString String fmt, Object... args) { + String message = String.format(fmt, args); + reasons.add(context + ": " + message); + } + + /** + * Classes that are serialized by Avro as a String include + * + *

    + *
+     * <ul>
+     *   <li>Subtypes of CharSequence (including String, Avro's mutable Utf8, etc.)
+     *   <li>Several predefined classes (BigDecimal, BigInteger, URI, URL)
+     *   <li>Classes annotated with @Stringable (uses their #toString() and a String constructor)
+     * </ul>
+ * + *

Rather than determine which of these cases are deterministic, we list some classes that + * definitely are, and treat any others as non-deterministic. + */ + private static final Set> DETERMINISTIC_STRINGABLE_CLASSES = new HashSet<>(); + + static { + // CharSequences: + DETERMINISTIC_STRINGABLE_CLASSES.add(String.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(Utf8.class); + + // Explicitly Stringable: + DETERMINISTIC_STRINGABLE_CLASSES.add(java.math.BigDecimal.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(java.math.BigInteger.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(java.net.URI.class); + DETERMINISTIC_STRINGABLE_CLASSES.add(java.net.URL.class); + + // Classes annotated with @Stringable: + } + + /** Return true if the given type token is a subtype of *any* of the listed parents. */ + private static boolean isSubtypeOf(TypeDescriptor type, Class... parents) { + for (Class parent : parents) { + if (type.isSubtypeOf(TypeDescriptor.of(parent))) { + return true; + } + } + return false; + } + + protected AvroDeterminismChecker() {} + + // The entry point for the check. Should not be recursively called. + public List check(TypeDescriptor type, Schema schema) { + recurse(type.getRawType().getName(), type, schema); + return reasons; + } + + // This is the method that should be recursively called. It sets up the path + // and visited types correctly. + private void recurse(String context, TypeDescriptor type, Schema schema) { + if (type.getRawType().isAnnotationPresent(AvroSchema.class)) { + reportError(context, "Custom schemas are not supported -- remove @AvroSchema."); + return; + } + + if (!activeTypes.add(type)) { + reportError(context, "%s appears recursively", type); + return; + } + + // If the record isn't a true class, but rather a GenericRecord, SpecificRecord, etc. + // with a specified schema, then we need to make the decision based on the generated + // implementations. + if (isSubtypeOf(type, IndexedRecord.class)) { + checkIndexedRecord(context, schema, null); + } else { + doCheck(context, type, schema); + } + + activeTypes.remove(type); + } + + private void doCheck(String context, TypeDescriptor type, Schema schema) { + switch (schema.getType()) { + case ARRAY: + checkArray(context, type, schema); + break; + case ENUM: + // Enums should be deterministic, since they depend only on the ordinal. + break; + case FIXED: + // Depending on the implementation of GenericFixed, we don't know how + // the given field will be encoded. So, we assume that it isn't + // deterministic. + reportError(context, "FIXED encodings are not guaranteed to be deterministic"); + break; + case MAP: + checkMap(context, type, schema); + break; + case RECORD: + if (!(type.getType() instanceof Class)) { + reportError(context, "Cannot determine type from generic %s due to erasure", type); + return; + } + checkRecord(type, schema); + break; + case UNION: + checkUnion(context, type, schema); + break; + case STRING: + checkString(context, type); + break; + case BOOLEAN: + case BYTES: + case DOUBLE: + case INT: + case FLOAT: + case LONG: + case NULL: + // For types that Avro encodes using one of the above primitives, we assume they are + // deterministic. + break; + default: + // In any other case (eg., new types added to Avro) we cautiously return + // false. 
+ reportError(context, "Unknown schema type %s may be non-deterministic", schema.getType()); + break; + } + } + + private void checkString(String context, TypeDescriptor type) { + // For types that are encoded as strings, we need to make sure they're in an approved + // list. For other types that are annotated @Stringable, Avro will just use the + // #toString() methods, which has no guarantees of determinism. + if (!DETERMINISTIC_STRINGABLE_CLASSES.contains(type.getRawType())) { + reportError(context, "%s may not have deterministic #toString()", type); + } + } + + private static final Schema AVRO_NULL_SCHEMA = Schema.create(Schema.Type.NULL); + + private void checkUnion(String context, TypeDescriptor type, Schema schema) { + final List unionTypes = schema.getTypes(); + + if (!type.getRawType().isAnnotationPresent(Union.class)) { + // First check for @Nullable field, which shows up as a union of field type and null. + if (unionTypes.size() == 2 && unionTypes.contains(AVRO_NULL_SCHEMA)) { + // Find the Schema that is not NULL and recursively check that it is deterministic. + Schema nullableFieldSchema = + unionTypes.get(0).equals(AVRO_NULL_SCHEMA) ? unionTypes.get(1) : unionTypes.get(0); + doCheck(context, type, nullableFieldSchema); + return; + } + + // Otherwise report a schema error. + reportError(context, "Expected type %s to have @Union annotation", type); + return; + } + + // Errors associated with this union will use the base class as their context. + String baseClassContext = type.getRawType().getName(); + + // For a union, we need to make sure that each possible instantiation is deterministic. + for (Schema concrete : unionTypes) { + @SuppressWarnings("unchecked") + TypeDescriptor unionType = TypeDescriptor.of(ReflectData.get().getClass(concrete)); + + recurse(baseClassContext, unionType, concrete); + } + } + + private void checkRecord(TypeDescriptor type, Schema schema) { + // For a record, we want to make sure that all the fields are deterministic. + Class clazz = type.getRawType(); + for (Schema.Field fieldSchema : schema.getFields()) { + Field field = getField(clazz, fieldSchema.name()); + String fieldContext = field.getDeclaringClass().getName() + "#" + field.getName(); + + if (field.isAnnotationPresent(AvroEncode.class)) { + reportError( + fieldContext, "Custom encoders may be non-deterministic -- remove @AvroEncode"); + continue; + } + + if (!IndexedRecord.class.isAssignableFrom(field.getType()) + && field.isAnnotationPresent(AvroSchema.class)) { + // TODO: We should be able to support custom schemas on POJO fields, but we shouldn't + // need to, so we just allow it in the case of IndexedRecords. + reportError( + fieldContext, "Custom schemas are only supported for subtypes of IndexedRecord."); + continue; + } + + TypeDescriptor fieldType = type.resolveType(field.getGenericType()); + recurse(fieldContext, fieldType, fieldSchema.schema()); + } + } + + private void checkIndexedRecord( + String context, Schema schema, @Nullable String specificClassStr) { + + if (!activeSchemas.add(schema)) { + reportError(context, "%s appears recursively", schema.getName()); + return; + } + + switch (schema.getType()) { + case ARRAY: + // Generic Records use GenericData.Array to implement arrays, which is + // essentially an ArrayList, and therefore ordering is deterministic. + // The array is thus deterministic if the elements are deterministic. 
+ checkIndexedRecord(context, schema.getElementType(), null); + break; + case ENUM: + // Enums are deterministic because they encode as a single integer. + break; + case FIXED: + // In the case of GenericRecords, FIXED is deterministic because it + // encodes/decodes as a Byte[]. + break; + case MAP: + reportError( + context, + "GenericRecord and SpecificRecords use a HashMap to represent MAPs," + + " so it is non-deterministic"); + break; + case RECORD: + for (Schema.Field field : schema.getFields()) { + checkIndexedRecord( + schema.getName() + "." + field.name(), + field.schema(), + field.getProp(SpecificData.CLASS_PROP)); + } + break; + case STRING: + // GenericDatumWriter#findStringClass will use a CharSequence or a String + // for each string, so it is deterministic. + + // SpecificCompiler#getStringType will use java.lang.String, org.apache.avro.util.Utf8, + // or java.lang.CharSequence, unless SpecificData.CLASS_PROP overrides that. + if (specificClassStr != null) { + Class specificClass; + try { + specificClass = ClassUtils.forName(specificClassStr); + if (!DETERMINISTIC_STRINGABLE_CLASSES.contains(specificClass)) { + reportError( + context, + "Specific class %s is not known to be deterministic", + specificClassStr); + } + } catch (ClassNotFoundException e) { + reportError( + context, "Specific class %s is not known to be deterministic", specificClassStr); + } + } + break; + case UNION: + for (Schema subschema : schema.getTypes()) { + checkIndexedRecord(subschema.getName(), subschema, null); + } + break; + case BOOLEAN: + case BYTES: + case DOUBLE: + case INT: + case FLOAT: + case LONG: + case NULL: + // For types that Avro encodes using one of the above primitives, we assume they are + // deterministic. + break; + default: + reportError(context, "Unknown schema type %s may be non-deterministic", schema.getType()); + break; + } + + activeSchemas.remove(schema); + } + + private void checkMap(String context, TypeDescriptor type, Schema schema) { + if (!isSubtypeOf(type, SortedMap.class)) { + reportError(context, "%s may not be deterministically ordered", type); + } + + // Avro (currently) asserts that all keys are strings. + // In case that changes, we double check that the key was a string: + Class keyType = type.resolveType(Map.class.getTypeParameters()[0]).getRawType(); + if (!String.class.equals(keyType)) { + reportError(context, "map keys should be Strings, but was %s", keyType); + } + + recurse(context, type.resolveType(Map.class.getTypeParameters()[1]), schema.getValueType()); + } + + private void checkArray(String context, TypeDescriptor type, Schema schema) { + TypeDescriptor elementType = null; + if (type.isArray()) { + // The type is an array (with ordering)-> deterministic iff the element is deterministic. + elementType = type.getComponentType(); + } else if (isSubtypeOf(type, Collection.class)) { + if (isSubtypeOf(type, List.class, SortedSet.class)) { + // Ordered collection -> deterministic iff the element is deterministic + elementType = type.resolveType(Collection.class.getTypeParameters()[0]); + } else { + // Not an ordered collection -> not deterministic + reportError(context, "%s may not be deterministically ordered", type); + return; + } + } else { + // If it was an unknown type encoded as an array, be conservative and assume + // that we don't know anything about the order. + reportError(context, "encoding %s as an ARRAY was unexpected", type); + return; + } + + // If we get here, it's either a deterministically-ordered Collection, or + // an array. 
Either way, the type is deterministic iff the element type is + // deterministic. + recurse(context, elementType, schema.getElementType()); + } + + /** + * Extract a field from a class. We need to look at the declared fields so that we can see + * private fields. We may need to walk up to the parent to get classes from the parent. + */ + private static Field getField(Class originalClazz, String name) { + Class clazz = originalClazz; + while (clazz != null) { + for (Field field : clazz.getDeclaredFields()) { + AvroName avroName = field.getAnnotation(AvroName.class); + if (avroName != null && name.equals(avroName.value())) { + return field; + } else if (avroName == null && name.equals(field.getName())) { + return field; + } + } + clazz = clazz.getSuperclass(); + } + + throw new IllegalArgumentException("Unable to get field " + name + " from " + originalClazz); + } + } + + @Override + public boolean equals(@Nullable Object other) { + if (other == this) { + return true; + } + if (!(other instanceof AvroCoder)) { + return false; + } + AvroCoder that = (AvroCoder) other; + return Objects.equals(this.schemaSupplier.get(), that.schemaSupplier.get()) + && Objects.equals(this.typeDescriptor, that.typeDescriptor) + && this.useReflectApi == that.useReflectApi; + } + + @Override + public int hashCode() { + return Objects.hash(schemaSupplier.get(), typeDescriptor, useReflectApi); + } + + /** + * Conversion for DateTime. + * + *

This is a copy from Avro 1.8's TimestampConversion, which is renamed in Avro 1.9. Defining + * own copy gives flexibility for Beam Java SDK to work with Avro 1.8 and 1.9 at runtime. + * + * @see BEAM-9144: Beam's own Avro + * TimeConversion class in beam-sdk-java-core + */ + public static class JodaTimestampConversion extends Conversion { + @Override + public Class getConvertedType() { + return DateTime.class; + } + + @Override + public String getLogicalTypeName() { + return "timestamp-millis"; + } + + @Override + public DateTime fromLong(Long millisFromEpoch, Schema schema, LogicalType type) { + return new DateTime(millisFromEpoch, DateTimeZone.UTC); + } + + @Override + public Long toLong(DateTime timestamp, Schema schema, LogicalType type) { + return timestamp.getMillis(); + } + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroGenericCoder.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroGenericCoder.java new file mode 100644 index 000000000000..46e0b9715b24 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/AvroGenericCoder.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.coders; + +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericRecord; + +/** AvroCoder specialisation for GenericRecord. */ +public class AvroGenericCoder extends AvroCoder { + AvroGenericCoder(Schema schema) { + super(GenericRecord.class, schema); + } + + public static AvroGenericCoder of(Schema schema) { + return new AvroGenericCoder(schema); + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/package-info.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/package-info.java new file mode 100644 index 000000000000..639856878e57 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/coders/package-info.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Defines {@link org.apache.beam.sdk.coders.Coder Coders} to specify how data is encoded to and + * decoded from byte strings using Apache Avro. + */ +@DefaultAnnotation(NonNull.class) +@Experimental(Kind.EXTENSION) +package org.apache.beam.sdk.extensions.avro.coders; + +import edu.umd.cs.findbugs.annotations.DefaultAnnotation; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.checkerframework.checker.nullness.qual.NonNull; diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java new file mode 100644 index 000000000000..172d4e09d4e4 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java @@ -0,0 +1,2026 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.io; + +import static org.apache.beam.sdk.io.FileIO.ReadMatches.DirectoryTreatment; +import static org.apache.beam.sdk.io.ReadAllViaFileBasedSource.ReadFileRangesFnExceptionHandler; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import java.io.IOException; +import java.io.Serializable; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.Map; +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.IndexedRecord; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumWriter; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.io.DefaultFilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink; +import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSource; +import org.apache.beam.sdk.io.FileIO; +import org.apache.beam.sdk.io.FileIO.MatchConfiguration; +import org.apache.beam.sdk.io.FileIO.ReadableFile; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.ReadAllViaFileBasedSource; +import org.apache.beam.sdk.io.ShardNameTemplate; +import org.apache.beam.sdk.io.WriteFiles; +import org.apache.beam.sdk.io.WriteFilesResult; +import org.apache.beam.sdk.io.fs.EmptyMatchTreatment; +import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; +import org.apache.beam.sdk.transforms.Watch.Growth.TerminationCondition; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.Nullable; 
+import org.joda.time.Duration; + +/** + * {@link PTransform}s for reading and writing Avro files. + * + *

 + * <h3>Reading Avro files</h3>

+ * + *

To read a {@link PCollection} from one or more Avro files with the same schema known at + * pipeline construction time, use {@link #read}, using {@link Read#from} to specify the filename or + * filepattern to read from. If the filepatterns to be read are themselves in a {@link PCollection} + * you can use {@link FileIO} to match them and {@link AvroIO#readFiles} to read them. If the schema + * is unknown at pipeline construction time, use {@link #parseGenericRecords} or {@link + * #parseFilesGenericRecords}. + * + *

Many configuration options below apply to several or all of these transforms. + * + *

See {@link FileSystems} for information on supported file systems and filepatterns. + * + *

 + * <h3>Filepattern expansion and watching</h3>

+ * + *

By default, the filepatterns are expanded only once. {@link Read#watchForNewFiles} or the + * combination of {@link FileIO.Match#continuously(Duration, TerminationCondition)} and {@link + * AvroIO#readFiles(Class)} allows streaming of new files matching the filepattern(s). + * + *
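 + * For example, a sketch of the streaming combination (the watch frequency and termination
 + * condition are illustrative values):
 + *
 + * {@code
 + * // Illustrative watch settings; AvroAutoGenClass is the generated class used in other examples.
 + * PCollection<AvroAutoGenClass> records =
 + *     p.apply(FileIO.match()
 + *             .filepattern("gs://my_bucket/path/to/records-*.avro")
 + *             .continuously(
 + *                 Duration.standardMinutes(1),
 + *                 Watch.Growth.afterTimeSinceNewOutput(Duration.standardHours(1))))
 + *         .apply(FileIO.readMatches())
 + *         .apply(AvroIO.readFiles(AvroAutoGenClass.class));
 + * }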

By default, {@link #read} prohibits filepatterns that match no files, and {@link + * AvroIO#readFiles(Class)} allows them in case the filepattern contains a glob wildcard character. + * Use {@link Read#withEmptyMatchTreatment} or {@link + * FileIO.Match#withEmptyMatchTreatment(EmptyMatchTreatment)} plus {@link AvroIO#readFiles(Class)} + * to configure this behavior. + * + *
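 + * For example, a sketch of tolerating an empty match result:
 + *
 + * {@code
 + * // Illustrative filepattern; AvroAutoGenClass is the generated class used in other examples.
 + * PCollection<AvroAutoGenClass> records =
 + *     p.apply(AvroIO.read(AvroAutoGenClass.class)
 + *         .from("gs://my_bucket/path/to/records-*.avro")
 + *         .withEmptyMatchTreatment(EmptyMatchTreatment.ALLOW));
 + * }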

 + * <h3>Reading records of a known schema</h3>

+ * + *

To read specific records, such as Avro-generated classes, use {@link #read(Class)}. To read + * {@link GenericRecord GenericRecords}, use {@link #readGenericRecords(Schema)} which takes a + * {@link Schema} object, or {@link #readGenericRecords(String)} which takes an Avro schema in a + * JSON-encoded string form. An exception will be thrown if a record doesn't match the specified + * schema. Likewise, to read a {@link PCollection} of filepatterns, apply {@link FileIO} matching + * plus {@link #readFilesGenericRecords}. + * + *

For example: + * + *

{@code
+ * Pipeline p = ...;
+ *
+ * // Read Avro-generated classes from files on GCS
+ * PCollection records =
+ *     p.apply(AvroIO.read(AvroAutoGenClass.class).from("gs://my_bucket/path/to/records-*.avro"));
+ *
+ * // Read GenericRecord's of the given schema from files on GCS
+ * Schema schema = new Schema.Parser().parse(new File("schema.avsc"));
+ * PCollection records =
+ *     p.apply(AvroIO.readGenericRecords(schema)
+ *                .from("gs://my_bucket/path/to/records-*.avro"));
+ * }
+ * + *

 + * <h3>Reading records of an unknown schema</h3>

+ * + *

To read records from files whose schema is unknown at pipeline construction time or differs + * between files, use {@link #parseGenericRecords} - in this case, you will need to specify a + * parsing function for converting each {@link GenericRecord} into a value of your custom type. + * Likewise, to read a {@link PCollection} of filepatterns with unknown schema, use {@link FileIO} + * matching plus {@link #parseFilesGenericRecords(SerializableFunction)}. + * + *

For example: + * + *

{@code
+ * Pipeline p = ...;
+ *
+ * PCollection records =
+ *     p.apply(AvroIO.parseGenericRecords(new SerializableFunction() {
+ *       public Foo apply(GenericRecord record) {
+ *         // If needed, access the schema of the record using record.getSchema()
+ *         return ...;
+ *       }
+ *     }));
+ * }
+ * + *

 + * <h3>Reading from a {@link PCollection} of filepatterns</h3>

+ * + *
{@code
+ * Pipeline p = ...;
+ *
+ * PCollection filepatterns = p.apply(...);
+ * PCollection records =
+ *     filepatterns.apply(AvroIO.readAll(AvroAutoGenClass.class));
+ * PCollection records =
+ *     filepatterns
+ *         .apply(FileIO.matchAll())
+ *         .apply(FileIO.readMatches())
+ *         .apply(AvroIO.readFiles(AvroAutoGenClass.class));
+ * PCollection genericRecords =
+ *     filepatterns.apply(AvroIO.readGenericRecords(schema));
+ * PCollection records =
+ *     filepatterns
+ *         .apply(FileIO.matchAll())
+ *         .apply(FileIO.readMatches())
+ *         .apply(AvroIO.parseFilesGenericRecords(new SerializableFunction...);
+ * }
+ * + *

 + * <h3>Streaming new files matching a filepattern</h3>

+ * + *
{@code
+ * Pipeline p = ...;
+ *
+ * PCollection lines = p.apply(AvroIO
+ *     .read(AvroAutoGenClass.class)
+ *     .from("gs://my_bucket/path/to/records-*.avro")
+ *     .watchForNewFiles(
+ *       // Check for new files every minute
+ *       Duration.standardMinutes(1),
+ *       // Stop watching the filepattern if no new files appear within an hour
+ *       afterTimeSinceNewOutput(Duration.standardHours(1))));
+ * }
+ * + *

 + * <h3>Reading a very large number of files</h3>

+ * + *

If it is known that the filepattern will match a very large number of files (e.g. tens of + * thousands or more), use {@link Read#withHintMatchesManyFiles} for better performance and + * scalability. Note that it may decrease performance if the filepattern matches only a small number + * of files. + * + *
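 + * For example (a sketch, reusing the {@code AvroAutoGenClass} type from the examples above):
 + *
 + * {@code
 + * PCollection<AvroAutoGenClass> records =
 + *     p.apply(AvroIO.read(AvroAutoGenClass.class)
 + *         .from("gs://my_bucket/path/to/records-*.avro")
 + *         .withHintMatchesManyFiles());
 + * }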

 + * <h3>Inferring Beam schemas from Avro files</h3>

+ * + *

If you want to use SQL or schema-based operations on an Avro-based PCollection, you must + * configure the read transform to infer the Beam schema and automatically set up the Beam-related + * coders by doing: + * + *

{@code
+ * PCollection records =
+ *     p.apply(AvroIO.read(...).from(...).withBeamSchemas(true));
+ * }
+ * + *

 + * <h3>Inferring Beam schemas from Avro PCollections</h3>

+ * + *

If you created an Avro-based PCollection by other means, e.g. reading records from Kafka or as + * the output of another PTransform, you may be interested in making your PCollection schema-aware + * so you can use the Schema-based APIs or Beam's SqlTransform. + * + *

If you are using Avro specific records (generated classes from an Avro schema), you can + * register a schema provider for the specific Avro class to make any PCollection of these objects + * schema-aware. + * + *

{@code
+ * pipeline.getSchemaRegistry().registerSchemaProvider(AvroAutoGenClass.class, AvroAutoGenClass.getClassSchema());
+ * }
+ * + * You can also manually set an Avro-backed Schema coder for a PCollection using {@link + * AvroUtils#schemaCoder(Class, Schema)} to make it schema-aware. + * + *
{@code
+ * PCollection records = ...
+ * AvroCoder coder = (AvroCoder) users.getCoder();
+ * records.setCoder(AvroUtils.schemaCoder(coder.getType(), coder.getSchema()));
+ * }
+ * + *

If you are using GenericRecords, you may need to set a specific Beam schema coder for each + * PCollection to match its internal Avro schema. + * + *

{@code
+ * org.apache.avro.Schema avroSchema = ...
+ * PCollection records = ...
+ * records.setCoder(AvroUtils.schemaCoder(avroSchema));
+ * }
+ * + *

 + * <h3>Writing Avro files</h3>

+ * + *

To write a {@link PCollection} to one or more Avro files, use {@link Write}, using {@code + * AvroIO.write().to(String)} to specify the output filename prefix. The default {@link + * DefaultFilenamePolicy} will use this prefix, in conjunction with a {@link ShardNameTemplate} (set + * via {@link Write#withShardNameTemplate(String)}) and an optional filename suffix (set via {@link + * Write#withSuffix(String)}), to generate output filenames in a sharded way. You can override this + * default write filename policy using {@link Write#to(FilenamePolicy)} to specify a custom file + * naming policy. + * + *

By default, {@link Write} produces output files that are compressed using the {@link + * org.apache.avro.file.Codec CodecFactory.snappyCodec()}. This default can be changed or overridden + * using {@link Write#withCodec}. + * + *
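 + * For example, a sketch of overriding the default codec:
 + *
 + * {@code
 + * // "records" is the PCollection from the writing examples; deflate level 9 is illustrative.
 + * records.apply(AvroIO.write(AvroAutoGenClass.class)
 + *     .to("gs://my_bucket/path/to/records")
 + *     .withSuffix(".avro")
 + *     .withCodec(CodecFactory.deflateCodec(9)));
 + * }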

 + * <h3>Writing specific or generic records</h3>

+ * + *

To write specific records, such as Avro-generated classes, use {@link #write(Class)}. To write + * {@link GenericRecord GenericRecords}, use either {@link #writeGenericRecords(Schema)} which takes + * a {@link Schema} object, or {@link #writeGenericRecords(String)} which takes a schema in a + * JSON-encoded string form. An exception will be thrown if a record doesn't match the specified + * schema. + * + *

For example: + * + *

{@code
+ * // A simple Write to a local file (only runs locally):
+ * PCollection records = ...;
+ * records.apply(AvroIO.write(AvroAutoGenClass.class).to("/path/to/file.avro"));
+ *
+ * // A Write to a sharded GCS file (runs locally and using remote execution):
+ * Schema schema = new Schema.Parser().parse(new File("schema.avsc"));
+ * PCollection records = ...;
+ * records.apply("WriteToAvro", AvroIO.writeGenericRecords(schema)
+ *     .to("gs://my_bucket/path/to/numbers")
+ *     .withSuffix(".avro"));
+ * }
+ * + *

 + * <h3>Writing windowed or unbounded data</h3>

+ * + *

By default, all input is put into the global window before writing. If per-window writes are + * desired - for example, when using a streaming runner - {@link Write#withWindowedWrites()} will + * cause windowing and triggering to be preserved. When producing windowed writes with a streaming + * runner that supports triggers, the number of output shards must be set explicitly using {@link + * Write#withNumShards(int)}; some runners may set this for you to a runner-chosen value, so you may + * not need to set it yourself. A {@link FilenamePolicy} must be set, and unique windows and triggers + * must produce unique filenames. + * + *
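 + * For example, a sketch of a windowed write; {@code MyWindowedFilenamePolicy} is a hypothetical
 + * {@link FilenamePolicy} implementation, and {@code events} is a hypothetical unbounded
 + * PCollection of {@code AvroAutoGenClass} elements:
 + *
 + * {@code
 + * // Hypothetical input, filename policy, window size, temp location and shard count.
 + * events
 + *     .apply(Window.into(FixedWindows.of(Duration.standardMinutes(5))))
 + *     .apply(AvroIO.write(AvroAutoGenClass.class)
 + *         .to(new MyWindowedFilenamePolicy())
 + *         .withTempDirectory(FileSystems.matchNewResource("gs://my_bucket/temp", true))
 + *         .withWindowedWrites()
 + *         .withNumShards(3));
 + * }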

Writing data to multiple destinations

+ * + *

The following shows a more-complex example of AvroIO.Write usage, generating dynamic file + * destinations as well as a dynamic Avro schema per file. In this example, a PCollection of user + * events (e.g. actions on a website) is written out to Avro files. Each event contains the user id + * as an integer field. We want events for each user to go into a specific directory for that user, + * and each user's data should be written with a specific schema for that user; a side input is + * used, so the schema can be calculated in a different stage. + * + *

{@code
+ * // This is the user class that controls dynamic destinations for this avro write. The input to
+ * // AvroIO.Write will be UserEvent, and we will be writing GenericRecords to the file (in order
+ * // to have dynamic schemas). Everything is per userid, so we define a dynamic destination type
+ * // of Integer.
+ * class UserDynamicAvroDestinations
+ *     extends DynamicAvroDestinations<UserEvent, Integer, GenericRecord> {
+ *   private final PCollectionView<Map<Integer, String>> userToSchemaMap;
+ *   public UserDynamicAvroDestinations(PCollectionView<Map<Integer, String>> userToSchemaMap) {
+ *     this.userToSchemaMap = userToSchemaMap;
+ *   }
+ *   public GenericRecord formatRecord(UserEvent record) {
+ *     return formatUserRecord(record, getSchema(record.getUserId()));
+ *   }
+ *   public Schema getSchema(Integer userId) {
+ *     return new Schema.Parser().parse(sideInput(userToSchemaMap).get(userId));
+ *   }
+ *   public Integer getDestination(UserEvent record) {
+ *     return record.getUserId();
+ *   }
+ *   public Integer getDefaultDestination() {
+ *     return 0;
+ *   }
+ *   public FilenamePolicy getFilenamePolicy(Integer userId) {
+ *     return DefaultFilenamePolicy.fromParams(new Params().withBaseFilename(baseDir + "/user-"
+ *     + userId + "/events"));
+ *   }
+ *   public List<PCollectionView<?>> getSideInputs() {
+ *     return ImmutableList.<PCollectionView<?>>of(userToSchemaMap);
+ *   }
+ * }
+ * PCollection<UserEvent> events = ...;
+ * PCollectionView<Map<Integer, String>> userToSchemaMap = events.apply(
+ *     "ComputePerUserSchemas", new ComputePerUserSchemas());
+ * events.apply("WriteAvros", AvroIO.<UserEvent>writeCustomTypeToGenericRecords()
+ *     .to(new UserDynamicAvroDestinations(userToSchemaMap)));
+ * }
+ */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class AvroIO { + /** + * Reads records of the given type from an Avro file (or multiple Avro files matching a pattern). + * + *

The schema must be specified using one of the {@code withSchema} functions. + */ + public static Read read(Class recordClass) { + return new AutoValue_AvroIO_Read.Builder() + .setMatchConfiguration(MatchConfiguration.create(EmptyMatchTreatment.DISALLOW)) + .setRecordClass(recordClass) + .setSchema(ReflectData.get().getSchema(recordClass)) + .setInferBeamSchema(false) + .setHintMatchesManyFiles(false) + .build(); + } + + /** + * Like {@link #read}, but reads each file in a {@link PCollection} of {@link ReadableFile}, + * returned by {@link FileIO#readMatches}. + * + *

You can read {@link GenericRecord} by using {@code #readFiles(GenericRecord.class)} or + * {@code #readFiles(new Schema.Parser().parse(schema))} if the schema is a String. + */ + public static ReadFiles readFiles(Class recordClass) { + return new AutoValue_AvroIO_ReadFiles.Builder() + .setRecordClass(recordClass) + .setSchema(ReflectData.get().getSchema(recordClass)) + .setInferBeamSchema(false) + .setDesiredBundleSizeBytes(DEFAULT_BUNDLE_SIZE_BYTES) + .setUsesReshuffle(ReadAllViaFileBasedSource.DEFAULT_USES_RESHUFFLE) + .setFileExceptionHandler(new ReadFileRangesFnExceptionHandler()) + .build(); + } + + /** + * Like {@link #read}, but reads each filepattern in the input {@link PCollection}. + * + * @deprecated You can achieve The functionality of {@link #readAll} using {@link FileIO} matching + * plus {@link #readFiles(Class)}. This is the preferred method to make composition explicit. + * {@link ReadAll} will not receive upgrades and will be removed in a future version of Beam. + */ + @Deprecated + public static ReadAll readAll(Class recordClass) { + return new AutoValue_AvroIO_ReadAll.Builder() + .setMatchConfiguration(MatchConfiguration.create(EmptyMatchTreatment.ALLOW_IF_WILDCARD)) + .setRecordClass(recordClass) + .setSchema(ReflectData.get().getSchema(recordClass)) + .setInferBeamSchema(false) + .setDesiredBundleSizeBytes(DEFAULT_BUNDLE_SIZE_BYTES) + .build(); + } + + /** Reads Avro file(s) containing records of the specified schema. */ + public static Read readGenericRecords(Schema schema) { + return new AutoValue_AvroIO_Read.Builder() + .setMatchConfiguration(MatchConfiguration.create(EmptyMatchTreatment.DISALLOW)) + .setRecordClass(GenericRecord.class) + .setSchema(schema) + .setInferBeamSchema(false) + .setHintMatchesManyFiles(false) + .build(); + } + + /** + * Like {@link #readGenericRecords(Schema)}, but for a {@link PCollection} of {@link + * ReadableFile}, for example, returned by {@link FileIO#readMatches}. + */ + public static ReadFiles readFilesGenericRecords(Schema schema) { + return new AutoValue_AvroIO_ReadFiles.Builder() + .setRecordClass(GenericRecord.class) + .setSchema(schema) + .setInferBeamSchema(false) + .setDesiredBundleSizeBytes(DEFAULT_BUNDLE_SIZE_BYTES) + .setUsesReshuffle(ReadAllViaFileBasedSource.DEFAULT_USES_RESHUFFLE) + .setFileExceptionHandler(new ReadFileRangesFnExceptionHandler()) + .build(); + } + + /** + * Like {@link #readGenericRecords(Schema)}, but for a {@link PCollection} of {@link + * ReadableFile}, for example, returned by {@link FileIO#readMatches}. + * + * @deprecated You can achieve The functionality of {@link #readAllGenericRecords(Schema)} using + * {@link FileIO} matching plus {@link #readFilesGenericRecords(Schema)}. This is the + * preferred method to make composition explicit. {@link ReadAll} will not receive upgrades + * and will be removed in a future version of Beam. + */ + @Deprecated + public static ReadAll readAllGenericRecords(Schema schema) { + return new AutoValue_AvroIO_ReadAll.Builder() + .setMatchConfiguration(MatchConfiguration.create(EmptyMatchTreatment.ALLOW_IF_WILDCARD)) + .setRecordClass(GenericRecord.class) + .setSchema(schema) + .setInferBeamSchema(false) + .setDesiredBundleSizeBytes(DEFAULT_BUNDLE_SIZE_BYTES) + .build(); + } + + /** + * Reads Avro file(s) containing records of the specified schema. The schema is specified as a + * JSON-encoded string. 
+ */ + public static Read readGenericRecords(String schema) { + return readGenericRecords(new Schema.Parser().parse(schema)); + } + + /** Like {@link #readGenericRecords(String)}, but for {@link ReadableFile} collections. */ + public static ReadFiles readFilesGenericRecords(String schema) { + return readFilesGenericRecords(new Schema.Parser().parse(schema)); + } + + /** + * Like {@link #readGenericRecords(String)}, but reads each filepattern in the input {@link + * PCollection}. + * + * @deprecated You can achieve The functionality of {@link #readAllGenericRecords(String)} using + * {@link FileIO} matching plus {@link #readFilesGenericRecords(String)}. This is the + * preferred method to make composition explicit. {@link ReadAll} will not receive upgrades + * and will be removed in a future version of Beam. + */ + @Deprecated + public static ReadAll readAllGenericRecords(String schema) { + return readAllGenericRecords(new Schema.Parser().parse(schema)); + } + + /** + * Reads Avro file(s) containing records of an unspecified schema and converting each record to a + * custom type. + */ + public static Parse parseGenericRecords(SerializableFunction parseFn) { + return new AutoValue_AvroIO_Parse.Builder() + .setMatchConfiguration(MatchConfiguration.create(EmptyMatchTreatment.DISALLOW)) + .setParseFn(parseFn) + .setHintMatchesManyFiles(false) + .build(); + } + + /** + * Like {@link #parseGenericRecords(SerializableFunction)}, but reads each {@link ReadableFile} in + * the input {@link PCollection}. + */ + public static ParseFiles parseFilesGenericRecords( + SerializableFunction parseFn) { + return new AutoValue_AvroIO_ParseFiles.Builder() + .setParseFn(parseFn) + .setDesiredBundleSizeBytes(DEFAULT_BUNDLE_SIZE_BYTES) + .setUsesReshuffle(ReadAllViaFileBasedSource.DEFAULT_USES_RESHUFFLE) + .setFileExceptionHandler(new ReadFileRangesFnExceptionHandler()) + .build(); + } + + /** + * Like {@link #parseGenericRecords(SerializableFunction)}, but reads each filepattern in the + * input {@link PCollection}. + * + * @deprecated You can achieve The functionality of {@link + * #parseAllGenericRecords(SerializableFunction)} using {@link FileIO} matching plus {@link + * #parseFilesGenericRecords(SerializableFunction)} ()}. This is the preferred method to make + * composition explicit. {@link ParseAll} will not receive upgrades and will be removed in a + * future version of Beam. + */ + @Deprecated + public static ParseAll parseAllGenericRecords( + SerializableFunction parseFn) { + return new AutoValue_AvroIO_ParseAll.Builder() + .setMatchConfiguration(MatchConfiguration.create(EmptyMatchTreatment.ALLOW_IF_WILDCARD)) + .setParseFn(parseFn) + .setDesiredBundleSizeBytes(DEFAULT_BUNDLE_SIZE_BYTES) + .build(); + } + + /** + * Writes a {@link PCollection} to an Avro file (or multiple Avro files matching a sharding + * pattern). + */ + public static Write write(Class recordClass) { + return new Write<>( + AvroIO.defaultWriteBuilder() + .setGenericRecords(false) + .setSchema(ReflectData.get().getSchema(recordClass)) + .build()); + } + + /** Writes Avro records of the specified schema. */ + public static Write writeGenericRecords(Schema schema) { + return new Write<>( + AvroIO.defaultWriteBuilder() + .setGenericRecords(true) + .setSchema(schema) + .build()); + } + + /** + * A {@link PTransform} that writes a {@link PCollection} to an avro file (or multiple avro files + * matching a sharding pattern), with each element of the input collection encoded into its own + * record of type OutputT. + * + *

This version allows you to apply {@link AvroIO} writes to a PCollection of a custom type + * {@link UserT}. A format mechanism that converts the input type {@link UserT} to the output type + * that will be written to the file must be specified. If using a custom {@link + * DynamicAvroDestinations} object this is done using {@link + * DynamicAvroDestinations#formatRecord}, otherwise the {@link TypedWrite#withFormatFunction} can + * be used to specify a format function. + * + *

The advantage of using a custom type is that it allows a user-provided {@link + * DynamicAvroDestinations} object, set via {@link Write#to(DynamicAvroDestinations)}, to examine + * the custom type when choosing a destination. + * + *
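For instance (a sketch only; {@code UserEvent}, {@code eventSchema} and {@code toGenericRecord} are placeholders), a custom input type can be written with a single constant schema by combining a format function with the GenericRecord-specialized variant described below:

{@code
events.apply("WriteEvents", AvroIO.<UserEvent>writeCustomTypeToGenericRecords()
    .to("gs://my_bucket/path/to/events")
    .withSchema(eventSchema)
    .withFormatFunction(event -> toGenericRecord(event))
    .withSuffix(".avro"));
}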

If the output type is {@link GenericRecord} use {@link #writeCustomTypeToGenericRecords()} + * instead. + */ + public static TypedWrite writeCustomType() { + return AvroIO.defaultWriteBuilder().setGenericRecords(false).build(); + } + + /** + * Similar to {@link #writeCustomType()}, but specialized for the case where the output type is + * {@link GenericRecord}. A schema must be specified either in {@link + * DynamicAvroDestinations#getSchema} or if not using dynamic destinations, by using {@link + * TypedWrite#withSchema(Schema)}. + */ + public static TypedWrite writeCustomTypeToGenericRecords() { + return AvroIO.defaultWriteBuilder().setGenericRecords(true).build(); + } + + /** + * Writes Avro records of the specified schema. The schema is specified as a JSON-encoded string. + */ + public static Write writeGenericRecords(String schema) { + return writeGenericRecords(new Schema.Parser().parse(schema)); + } + + private static TypedWrite.Builder defaultWriteBuilder() { + return new AutoValue_AvroIO_TypedWrite.Builder() + .setFilenameSuffix(null) + .setShardTemplate(null) + .setNumShards(0) + .setCodec(TypedWrite.DEFAULT_SERIALIZABLE_CODEC) + .setMetadata(ImmutableMap.of()) + .setWindowedWrites(false) + .setNoSpilling(false); + } + + @Experimental(Kind.SCHEMAS) + private static PCollection setBeamSchema( + PCollection pc, Class clazz, @Nullable Schema schema) { + return pc.setCoder(AvroUtils.schemaCoder(clazz, schema)); + } + + /** + * 64MB is a reasonable value that allows to amortize the cost of opening files, but is not so + * large as to exhaust a typical runner's maximum amount of output per ProcessElement call. + */ + private static final long DEFAULT_BUNDLE_SIZE_BYTES = 64 * 1024 * 1024L; + + /** Implementation of {@link #read} and {@link #readGenericRecords}. */ + @AutoValue + public abstract static class Read extends PTransform> { + + abstract @Nullable ValueProvider getFilepattern(); + + abstract MatchConfiguration getMatchConfiguration(); + + abstract @Nullable Class getRecordClass(); + + abstract @Nullable Schema getSchema(); + + abstract boolean getInferBeamSchema(); + + abstract boolean getHintMatchesManyFiles(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setFilepattern(ValueProvider filepattern); + + abstract Builder setMatchConfiguration(MatchConfiguration matchConfiguration); + + abstract Builder setRecordClass(Class recordClass); + + abstract Builder setSchema(Schema schema); + + abstract Builder setInferBeamSchema(boolean infer); + + abstract Builder setHintMatchesManyFiles(boolean hintManyFiles); + + abstract Read build(); + } + + /** + * Reads from the given filename or filepattern. + * + *

If it is known that the filepattern will match a very large number of files (at least tens + * of thousands), use {@link #withHintMatchesManyFiles} for better performance and scalability. + */ + public Read from(ValueProvider filepattern) { + return toBuilder().setFilepattern(filepattern).build(); + } + + /** Like {@link #from(ValueProvider)}. */ + public Read from(String filepattern) { + return from(StaticValueProvider.of(filepattern)); + } + + /** Sets the {@link MatchConfiguration}. */ + public Read withMatchConfiguration(MatchConfiguration matchConfiguration) { + return toBuilder().setMatchConfiguration(matchConfiguration).build(); + } + + /** Configures whether or not a filepattern matching no files is allowed. */ + public Read withEmptyMatchTreatment(EmptyMatchTreatment treatment) { + return withMatchConfiguration(getMatchConfiguration().withEmptyMatchTreatment(treatment)); + } + + /** + * Continuously watches for new files matching the filepattern, polling it at the given + * interval, until the given termination condition is reached. The returned {@link PCollection} + * is unbounded. If {@code matchUpdatedFiles} is set, also watches for files with timestamp + * change. + * + *

This works only in runners supporting splittable {@link + * org.apache.beam.sdk.transforms.DoFn}. + */ + public Read watchForNewFiles( + Duration pollInterval, + TerminationCondition terminationCondition, + boolean matchUpdatedFiles) { + return withMatchConfiguration( + getMatchConfiguration() + .continuously(pollInterval, terminationCondition, matchUpdatedFiles)); + } + + /** + * Same as {@link Read#watchForNewFiles(Duration, TerminationCondition, boolean)} with {@code + * matchUpdatedFiles=false}. + */ + public Read watchForNewFiles( + Duration pollInterval, TerminationCondition terminationCondition) { + return watchForNewFiles(pollInterval, terminationCondition, false); + } + + /** + * Hints that the filepattern specified in {@link #from(String)} matches a very large number of + * files. + * + *

This hint may cause a runner to execute the transform differently, in a way that improves + * performance for this case, but it may worsen performance if the filepattern matches only a + * small number of files (e.g., in a runner that supports dynamic work rebalancing, it will + * happen less efficiently within individual files). + */ + public Read withHintMatchesManyFiles() { + return toBuilder().setHintMatchesManyFiles(true).build(); + } + + /** + * If set to true, a Beam schema will be inferred from the AVRO schema. This allows the output + * to be used by SQL and by the schema-transform library. + */ + @Experimental(Kind.SCHEMAS) + public Read withBeamSchemas(boolean withBeamSchemas) { + return toBuilder().setInferBeamSchema(withBeamSchemas).build(); + } + + @Override + @SuppressWarnings("unchecked") + public PCollection expand(PBegin input) { + checkNotNull(getFilepattern(), "filepattern"); + checkNotNull(getSchema(), "schema"); + + if (getMatchConfiguration().getWatchInterval() == null && !getHintMatchesManyFiles()) { + PCollection read = + input.apply( + "Read", + org.apache.beam.sdk.io.Read.from( + createSource( + getFilepattern(), + getMatchConfiguration().getEmptyMatchTreatment(), + getRecordClass(), + getSchema(), + null))); + return getInferBeamSchema() ? setBeamSchema(read, getRecordClass(), getSchema()) : read; + } + + // All other cases go through FileIO + ReadFiles + ReadFiles readFiles = + (getRecordClass() == GenericRecord.class) + ? (ReadFiles) readFilesGenericRecords(getSchema()) + : readFiles(getRecordClass()); + return input + .apply("Create filepattern", Create.ofProvider(getFilepattern(), StringUtf8Coder.of())) + .apply("Match All", FileIO.matchAll().withConfiguration(getMatchConfiguration())) + .apply( + "Read Matches", + FileIO.readMatches().withDirectoryTreatment(DirectoryTreatment.PROHIBIT)) + .apply("Via ReadFiles", readFiles); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder + .add( + DisplayData.item("inferBeamSchema", getInferBeamSchema()) + .withLabel("Infer Beam Schema")) + .addIfNotNull(DisplayData.item("schema", String.valueOf(getSchema()))) + .addIfNotNull(DisplayData.item("recordClass", getRecordClass()).withLabel("Record Class")) + .addIfNotNull( + DisplayData.item("filePattern", getFilepattern()).withLabel("Input File Pattern")) + .include("matchConfiguration", getMatchConfiguration()); + } + + @SuppressWarnings("unchecked") + private static org.apache.beam.sdk.io.AvroSource createSource( + ValueProvider filepattern, + EmptyMatchTreatment emptyMatchTreatment, + Class recordClass, + Schema schema, + org.apache.beam.sdk.io.AvroSource.@Nullable DatumReaderFactory readerFactory) { + org.apache.beam.sdk.io.AvroSource source = + org.apache.beam.sdk.io.AvroSource.from(filepattern) + .withEmptyMatchTreatment(emptyMatchTreatment); + + if (readerFactory != null) { + source = source.withDatumReaderFactory(readerFactory); + } + return recordClass == GenericRecord.class + ? (org.apache.beam.sdk.io.AvroSource) source.withSchema(schema) + : source.withSchema(recordClass); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** Implementation of {@link #readFiles}. 
*/ + @AutoValue + public abstract static class ReadFiles + extends PTransform, PCollection> { + + abstract @Nullable Class getRecordClass(); + + abstract @Nullable Schema getSchema(); + + abstract boolean getUsesReshuffle(); + + abstract ReadFileRangesFnExceptionHandler getFileExceptionHandler(); + + abstract long getDesiredBundleSizeBytes(); + + abstract boolean getInferBeamSchema(); + + abstract org.apache.beam.sdk.io.AvroSource.@Nullable DatumReaderFactory + getDatumReaderFactory(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setRecordClass(Class recordClass); + + abstract Builder setSchema(Schema schema); + + abstract Builder setUsesReshuffle(boolean usesReshuffle); + + abstract Builder setFileExceptionHandler( + ReadFileRangesFnExceptionHandler exceptionHandler); + + abstract Builder setDesiredBundleSizeBytes(long desiredBundleSizeBytes); + + abstract Builder setInferBeamSchema(boolean infer); + + abstract Builder setDatumReaderFactory( + org.apache.beam.sdk.io.AvroSource.DatumReaderFactory factory); + + abstract ReadFiles build(); + } + + @VisibleForTesting + ReadFiles withDesiredBundleSizeBytes(long desiredBundleSizeBytes) { + return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build(); + } + + /** Specifies if a Reshuffle should run before file reads occur. */ + @Experimental(Kind.FILESYSTEM) + public ReadFiles withUsesReshuffle(boolean usesReshuffle) { + return toBuilder().setUsesReshuffle(usesReshuffle).build(); + } + + /** Specifies if exceptions should be logged only for streaming pipelines. */ + @Experimental(Kind.FILESYSTEM) + public ReadFiles withFileExceptionHandler( + ReadFileRangesFnExceptionHandler exceptionHandler) { + return toBuilder().setFileExceptionHandler(exceptionHandler).build(); + } + + /** + * If set to true, a Beam schema will be inferred from the AVRO schema. This allows the output + * to be used by SQL and by the schema-transform library. + */ + @Experimental(Kind.SCHEMAS) + public ReadFiles withBeamSchemas(boolean withBeamSchemas) { + return toBuilder().setInferBeamSchema(withBeamSchemas).build(); + } + + public ReadFiles withDatumReaderFactory( + org.apache.beam.sdk.io.AvroSource.DatumReaderFactory factory) { + return toBuilder().setDatumReaderFactory(factory).build(); + } + + @Override + public PCollection expand(PCollection input) { + checkNotNull(getSchema(), "schema"); + PCollection read = + input.apply( + "Read all via FileBasedSource", + new ReadAllViaFileBasedSource<>( + getDesiredBundleSizeBytes(), + new CreateSourceFn<>( + getRecordClass(), getSchema().toString(), getDatumReaderFactory()), + AvroCoder.of(getRecordClass(), getSchema()), + getUsesReshuffle(), + getFileExceptionHandler())); + return getInferBeamSchema() ? setBeamSchema(read, getRecordClass(), getSchema()) : read; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder + .add( + DisplayData.item("inferBeamSchema", getInferBeamSchema()) + .withLabel("Infer Beam Schema")) + .addIfNotNull(DisplayData.item("schema", String.valueOf(getSchema()))) + .addIfNotNull( + DisplayData.item("recordClass", getRecordClass()).withLabel("Record Class")); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Implementation of {@link #readAll}. + * + * @deprecated See {@link #readAll(Class)} for details. 
+ */ + @Deprecated + @AutoValue + public abstract static class ReadAll extends PTransform, PCollection> { + abstract MatchConfiguration getMatchConfiguration(); + + abstract @Nullable Class getRecordClass(); + + abstract @Nullable Schema getSchema(); + + abstract long getDesiredBundleSizeBytes(); + + abstract boolean getInferBeamSchema(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setMatchConfiguration(MatchConfiguration matchConfiguration); + + abstract Builder setRecordClass(Class recordClass); + + abstract Builder setSchema(Schema schema); + + abstract Builder setDesiredBundleSizeBytes(long desiredBundleSizeBytes); + + abstract Builder setInferBeamSchema(boolean infer); + + abstract ReadAll build(); + } + + /** Sets the {@link MatchConfiguration}. */ + public ReadAll withMatchConfiguration(MatchConfiguration configuration) { + return toBuilder().setMatchConfiguration(configuration).build(); + } + + /** Like {@link Read#withEmptyMatchTreatment}. */ + public ReadAll withEmptyMatchTreatment(EmptyMatchTreatment treatment) { + return withMatchConfiguration(getMatchConfiguration().withEmptyMatchTreatment(treatment)); + } + + /** Like {@link Read#watchForNewFiles}. */ + public ReadAll watchForNewFiles( + Duration pollInterval, TerminationCondition terminationCondition) { + return withMatchConfiguration( + getMatchConfiguration().continuously(pollInterval, terminationCondition)); + } + + @VisibleForTesting + ReadAll withDesiredBundleSizeBytes(long desiredBundleSizeBytes) { + return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build(); + } + + /** + * If set to true, a Beam schema will be inferred from the AVRO schema. This allows the output + * to be used by SQL and by the schema-transform library. + */ + @Experimental(Kind.SCHEMAS) + public ReadAll withBeamSchemas(boolean withBeamSchemas) { + return toBuilder().setInferBeamSchema(withBeamSchemas).build(); + } + + @Override + public PCollection expand(PCollection input) { + checkNotNull(getSchema(), "schema"); + PCollection read = + input + .apply(FileIO.matchAll().withConfiguration(getMatchConfiguration())) + .apply(FileIO.readMatches().withDirectoryTreatment(DirectoryTreatment.PROHIBIT)) + .apply(readFiles(getRecordClass())); + return getInferBeamSchema() ? 
setBeamSchema(read, getRecordClass(), getSchema()) : read; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder + .add( + DisplayData.item("inferBeamSchema", getInferBeamSchema()) + .withLabel("Infer Beam Schema")) + .addIfNotNull(DisplayData.item("schema", String.valueOf(getSchema()))) + .addIfNotNull(DisplayData.item("recordClass", getRecordClass()).withLabel("Record Class")) + .include("matchConfiguration", getMatchConfiguration()); + } + } + + private static class CreateSourceFn + implements SerializableFunction> { + private final Class recordClass; + private final Supplier schemaSupplier; + private final org.apache.beam.sdk.io.AvroSource.DatumReaderFactory readerFactory; + + CreateSourceFn( + Class recordClass, + String jsonSchema, + org.apache.beam.sdk.io.AvroSource.DatumReaderFactory readerFactory) { + this.recordClass = recordClass; + this.schemaSupplier = + Suppliers.memoize( + Suppliers.compose(new JsonToSchema(), Suppliers.ofInstance(jsonSchema))); + this.readerFactory = readerFactory; + } + + @Override + public FileBasedSource apply(String input) { + return Read.createSource( + StaticValueProvider.of(input), + EmptyMatchTreatment.DISALLOW, + recordClass, + schemaSupplier.get(), + readerFactory); + } + + private static class JsonToSchema implements Function, Serializable { + @Override + public Schema apply(String input) { + return new Schema.Parser().parse(input); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** Implementation of {@link #parseGenericRecords}. */ + @AutoValue + public abstract static class Parse extends PTransform> { + + abstract @Nullable ValueProvider getFilepattern(); + + abstract MatchConfiguration getMatchConfiguration(); + + abstract SerializableFunction getParseFn(); + + abstract @Nullable Coder getCoder(); + + abstract boolean getHintMatchesManyFiles(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setFilepattern(ValueProvider filepattern); + + abstract Builder setMatchConfiguration(MatchConfiguration matchConfiguration); + + abstract Builder setParseFn(SerializableFunction parseFn); + + abstract Builder setCoder(Coder coder); + + abstract Builder setHintMatchesManyFiles(boolean hintMatchesManyFiles); + + abstract Parse build(); + } + + /** Reads from the given filename or filepattern. */ + public Parse from(String filepattern) { + return from(StaticValueProvider.of(filepattern)); + } + + /** Like {@link #from(String)}. */ + public Parse from(ValueProvider filepattern) { + return toBuilder().setFilepattern(filepattern).build(); + } + + /** Sets the {@link MatchConfiguration}. */ + public Parse withMatchConfiguration(MatchConfiguration configuration) { + return toBuilder().setMatchConfiguration(configuration).build(); + } + + /** Like {@link Read#withEmptyMatchTreatment}. */ + public Parse withEmptyMatchTreatment(EmptyMatchTreatment treatment) { + return withMatchConfiguration(getMatchConfiguration().withEmptyMatchTreatment(treatment)); + } + + /** Like {@link Read#watchForNewFiles}. */ + public Parse watchForNewFiles( + Duration pollInterval, TerminationCondition terminationCondition) { + return withMatchConfiguration( + getMatchConfiguration().continuously(pollInterval, terminationCondition)); + } + + /** Sets a coder for the result of the parse function. 
*/ + public Parse withCoder(Coder coder) { + return toBuilder().setCoder(coder).build(); + } + + /** Like {@link Read#withHintMatchesManyFiles()}. */ + public Parse withHintMatchesManyFiles() { + return toBuilder().setHintMatchesManyFiles(true).build(); + } + + @Override + public PCollection expand(PBegin input) { + checkNotNull(getFilepattern(), "filepattern"); + Coder coder = inferCoder(getCoder(), getParseFn(), input.getPipeline().getCoderRegistry()); + + if (getMatchConfiguration().getWatchInterval() == null && !getHintMatchesManyFiles()) { + return input.apply( + org.apache.beam.sdk.io.Read.from( + org.apache.beam.sdk.io.AvroSource.from(getFilepattern()) + .withParseFn(getParseFn(), coder))); + } + + // All other cases go through FileIO + ParseFilesGenericRecords. + return input + .apply("Create filepattern", Create.ofProvider(getFilepattern(), StringUtf8Coder.of())) + .apply("Match All", FileIO.matchAll().withConfiguration(getMatchConfiguration())) + .apply( + "Read Matches", + FileIO.readMatches().withDirectoryTreatment(DirectoryTreatment.PROHIBIT)) + .apply("Via ParseFiles", parseFilesGenericRecords(getParseFn()).withCoder(coder)); + } + + private static Coder inferCoder( + @Nullable Coder explicitCoder, + SerializableFunction parseFn, + CoderRegistry coderRegistry) { + if (explicitCoder != null) { + return explicitCoder; + } + // If a coder was not specified explicitly, infer it from parse fn. + try { + return coderRegistry.getCoder(TypeDescriptors.outputOf(parseFn)); + } catch (CannotProvideCoderException e) { + throw new IllegalArgumentException( + "Unable to infer coder for output of parseFn. Specify it explicitly using withCoder().", + e); + } + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder + .addIfNotNull( + DisplayData.item("filePattern", getFilepattern()).withLabel("Input File Pattern")) + .add(DisplayData.item("parseFn", getParseFn().getClass()).withLabel("Parse function")) + .include("matchConfiguration", getMatchConfiguration()); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** Implementation of {@link #parseFilesGenericRecords}. */ + @AutoValue + public abstract static class ParseFiles + extends PTransform, PCollection> { + abstract SerializableFunction getParseFn(); + + abstract @Nullable Coder getCoder(); + + abstract boolean getUsesReshuffle(); + + abstract ReadFileRangesFnExceptionHandler getFileExceptionHandler(); + + abstract long getDesiredBundleSizeBytes(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setParseFn(SerializableFunction parseFn); + + abstract Builder setCoder(Coder coder); + + abstract Builder setUsesReshuffle(boolean usesReshuffle); + + abstract Builder setFileExceptionHandler( + ReadFileRangesFnExceptionHandler exceptionHandler); + + abstract Builder setDesiredBundleSizeBytes(long desiredBundleSizeBytes); + + abstract ParseFiles build(); + } + + /** Specifies the coder for the result of the {@code parseFn}. */ + public ParseFiles withCoder(Coder coder) { + return toBuilder().setCoder(coder).build(); + } + + /** Specifies if a Reshuffle should run before file reads occur. */ + @Experimental(Kind.FILESYSTEM) + public ParseFiles withUsesReshuffle(boolean usesReshuffle) { + return toBuilder().setUsesReshuffle(usesReshuffle).build(); + } + + /** Specifies if exceptions should be logged only for streaming pipelines. 
*/ + @Experimental(Kind.FILESYSTEM) + public ParseFiles withFileExceptionHandler( + ReadFileRangesFnExceptionHandler exceptionHandler) { + return toBuilder().setFileExceptionHandler(exceptionHandler).build(); + } + + @VisibleForTesting + ParseFiles withDesiredBundleSizeBytes(long desiredBundleSizeBytes) { + return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build(); + } + + @Override + public PCollection expand(PCollection input) { + final Coder coder = + Parse.inferCoder(getCoder(), getParseFn(), input.getPipeline().getCoderRegistry()); + final SerializableFunction parseFn = getParseFn(); + final SerializableFunction> createSource = + new CreateParseSourceFn<>(parseFn, coder); + return input.apply( + "Parse Files via FileBasedSource", + new ReadAllViaFileBasedSource<>( + getDesiredBundleSizeBytes(), + createSource, + coder, + getUsesReshuffle(), + getFileExceptionHandler())); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(DisplayData.item("parseFn", getParseFn().getClass()).withLabel("Parse function")); + } + + private static class CreateParseSourceFn + implements SerializableFunction> { + private final SerializableFunction parseFn; + private final Coder coder; + + CreateParseSourceFn(SerializableFunction parseFn, Coder coder) { + this.parseFn = parseFn; + this.coder = coder; + } + + @Override + public FileBasedSource apply(String input) { + return AvroSource.from(input).withParseFn(parseFn, coder); + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Implementation of {@link #parseAllGenericRecords}. + * + * @deprecated See {@link #parseAllGenericRecords(SerializableFunction)} for details. + */ + @Deprecated + @AutoValue + public abstract static class ParseAll extends PTransform, PCollection> { + abstract MatchConfiguration getMatchConfiguration(); + + abstract SerializableFunction getParseFn(); + + abstract @Nullable Coder getCoder(); + + abstract long getDesiredBundleSizeBytes(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setMatchConfiguration(MatchConfiguration matchConfiguration); + + abstract Builder setParseFn(SerializableFunction parseFn); + + abstract Builder setCoder(Coder coder); + + abstract Builder setDesiredBundleSizeBytes(long desiredBundleSizeBytes); + + abstract ParseAll build(); + } + + /** Sets the {@link MatchConfiguration}. */ + public ParseAll withMatchConfiguration(MatchConfiguration configuration) { + return toBuilder().setMatchConfiguration(configuration).build(); + } + + /** Like {@link Read#withEmptyMatchTreatment}. */ + public ParseAll withEmptyMatchTreatment(EmptyMatchTreatment treatment) { + return withMatchConfiguration(getMatchConfiguration().withEmptyMatchTreatment(treatment)); + } + + /** Like {@link Read#watchForNewFiles(Duration, TerminationCondition, boolean)}. */ + public ParseAll watchForNewFiles( + Duration pollInterval, + TerminationCondition terminationCondition, + boolean matchUpdatedFiles) { + return withMatchConfiguration( + getMatchConfiguration() + .continuously(pollInterval, terminationCondition, matchUpdatedFiles)); + } + + /** Like {@link Read#watchForNewFiles(Duration, TerminationCondition)}. 
*/ + public ParseAll watchForNewFiles( + Duration pollInterval, TerminationCondition terminationCondition) { + return watchForNewFiles(pollInterval, terminationCondition, false); + } + + /** Specifies the coder for the result of the {@code parseFn}. */ + public ParseAll withCoder(Coder coder) { + return toBuilder().setCoder(coder).build(); + } + + @VisibleForTesting + ParseAll withDesiredBundleSizeBytes(long desiredBundleSizeBytes) { + return toBuilder().setDesiredBundleSizeBytes(desiredBundleSizeBytes).build(); + } + + @Override + public PCollection expand(PCollection input) { + return input + .apply(FileIO.matchAll().withConfiguration(getMatchConfiguration())) + .apply(FileIO.readMatches().withDirectoryTreatment(DirectoryTreatment.PROHIBIT)) + .apply( + "Parse all via FileBasedSource", + parseFilesGenericRecords(getParseFn()).withCoder(getCoder())); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder + .add(DisplayData.item("parseFn", getParseFn().getClass()).withLabel("Parse function")) + .include("matchConfiguration", getMatchConfiguration()); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** Implementation of {@link #write}. */ + @AutoValue + public abstract static class TypedWrite + extends PTransform, WriteFilesResult> { + static final CodecFactory DEFAULT_CODEC = CodecFactory.snappyCodec(); + static final SerializableAvroCodecFactory DEFAULT_SERIALIZABLE_CODEC = + new SerializableAvroCodecFactory(DEFAULT_CODEC); + + abstract @Nullable SerializableFunction getFormatFunction(); + + abstract @Nullable ValueProvider getFilenamePrefix(); + + abstract @Nullable String getShardTemplate(); + + abstract @Nullable String getFilenameSuffix(); + + abstract @Nullable ValueProvider getTempDirectory(); + + abstract int getNumShards(); + + abstract boolean getGenericRecords(); + + abstract @Nullable Schema getSchema(); + + abstract boolean getWindowedWrites(); + + abstract boolean getNoSpilling(); + + abstract @Nullable FilenamePolicy getFilenamePolicy(); + + abstract @Nullable DynamicAvroDestinations + getDynamicDestinations(); + + abstract AvroSink.@Nullable DatumWriterFactory getDatumWriterFactory(); + + /** + * The codec used to encode the blocks in the Avro file. String value drawn from those in + * https://avro.apache.org/docs/1.7.7/api/java/org/apache/avro/file/CodecFactory.html + */ + abstract SerializableAvroCodecFactory getCodec(); + /** Avro file metadata. 
*/ + abstract ImmutableMap getMetadata(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setFormatFunction( + @Nullable SerializableFunction formatFunction); + + abstract Builder setFilenamePrefix( + ValueProvider filenamePrefix); + + abstract Builder setFilenameSuffix( + @Nullable String filenameSuffix); + + abstract Builder setTempDirectory( + ValueProvider tempDirectory); + + abstract Builder setNumShards(int numShards); + + abstract Builder setShardTemplate( + @Nullable String shardTemplate); + + abstract Builder setGenericRecords(boolean genericRecords); + + abstract Builder setSchema(Schema schema); + + abstract Builder setWindowedWrites(boolean windowedWrites); + + abstract Builder setNoSpilling(boolean noSpilling); + + abstract Builder setFilenamePolicy( + FilenamePolicy filenamePolicy); + + abstract Builder setCodec(SerializableAvroCodecFactory codec); + + abstract Builder setMetadata( + ImmutableMap metadata); + + abstract Builder setDynamicDestinations( + DynamicAvroDestinations dynamicDestinations); + + abstract Builder setDatumWriterFactory( + AvroSink.DatumWriterFactory datumWriterFactory); + + abstract TypedWrite build(); + } + + /** + * Writes to file(s) with the given output prefix. See {@link FileSystems} for information on + * supported file systems. + * + *

The name of the output files will be determined by the {@link FilenamePolicy} used. + * + *

By default, a {@link DefaultFilenamePolicy} will build output filenames using the + * specified prefix, a shard name template (see {@link #withShardNameTemplate(String)}, and a + * common suffix (if supplied using {@link #withSuffix(String)}). This default can be overridden + * using {@link #to(FilenamePolicy)}. + */ + public TypedWrite to(String outputPrefix) { + return to(FileBasedSink.convertToFileResourceIfPossible(outputPrefix)); + } + + /** + * Writes to file(s) with the given output prefix. See {@link FileSystems} for information on + * supported file systems. This prefix is used by the {@link DefaultFilenamePolicy} to generate + * filenames. + * + *

By default, a {@link DefaultFilenamePolicy} will build output filenames using the + * specified prefix, a shard name template (see {@link #withShardNameTemplate(String)}, and a + * common suffix (if supplied using {@link #withSuffix(String)}). This default can be overridden + * using {@link #to(FilenamePolicy)}. + * + *

This default policy can be overridden using {@link #to(FilenamePolicy)}, in which case + * {@link #withShardNameTemplate(String)} and {@link #withSuffix(String)} should not be set. + * Custom filename policies do not automatically see this prefix - you should explicitly pass + * the prefix into your {@link FilenamePolicy} object if you need this. + * + *

If {@link #withTempDirectory} has not been called, this filename prefix will be used to + * infer a directory for temporary files. + */ + @Experimental(Kind.FILESYSTEM) + public TypedWrite to(ResourceId outputPrefix) { + return toResource(StaticValueProvider.of(outputPrefix)); + } + + private static class OutputPrefixToResourceId + implements SerializableFunction { + @Override + public ResourceId apply(String input) { + return FileBasedSink.convertToFileResourceIfPossible(input); + } + } + + /** Like {@link #to(String)}. */ + public TypedWrite to(ValueProvider outputPrefix) { + return toResource( + NestedValueProvider.of( + outputPrefix, + // The function cannot be created as an anonymous class here since the enclosed class + // may contain unserializable members. + new OutputPrefixToResourceId())); + } + + /** Like {@link #to(ResourceId)}. */ + @Experimental(Kind.FILESYSTEM) + public TypedWrite toResource( + ValueProvider outputPrefix) { + return toBuilder().setFilenamePrefix(outputPrefix).build(); + } + + /** + * Writes to files named according to the given {@link FilenamePolicy}. A directory for + * temporary files must be specified using {@link #withTempDirectory}. + */ + @Experimental(Kind.FILESYSTEM) + public TypedWrite to(FilenamePolicy filenamePolicy) { + return toBuilder().setFilenamePolicy(filenamePolicy).build(); + } + + /** + * Use a {@link DynamicAvroDestinations} object to vend {@link FilenamePolicy} objects. These + * objects can examine the input record when creating a {@link FilenamePolicy}. A directory for + * temporary files must be specified using {@link #withTempDirectory}. + * + * @deprecated Use {@link FileIO#write()} or {@link FileIO#writeDynamic()} instead. + */ + @Experimental(Kind.FILESYSTEM) + @Deprecated + public TypedWrite to( + DynamicAvroDestinations dynamicDestinations) { + return toBuilder() + .setDynamicDestinations((DynamicAvroDestinations) dynamicDestinations) + .build(); + } + + /** + * Sets the output schema. Can only be used when the output type is {@link GenericRecord} and + * when not using {@link #to(DynamicAvroDestinations)}. + */ + public TypedWrite withSchema(Schema schema) { + return toBuilder().setSchema(schema).build(); + } + + /** + * Specifies a format function to convert {@link UserT} to the output type. If {@link + * #to(DynamicAvroDestinations)} is used, {@link DynamicAvroDestinations#formatRecord} must be + * used instead. + */ + public TypedWrite withFormatFunction( + @Nullable SerializableFunction formatFunction) { + return toBuilder().setFormatFunction(formatFunction).build(); + } + + /** Set the base directory used to generate temporary files. */ + @Experimental(Kind.FILESYSTEM) + public TypedWrite withTempDirectory( + ValueProvider tempDirectory) { + return toBuilder().setTempDirectory(tempDirectory).build(); + } + + /** Set the base directory used to generate temporary files. */ + @Experimental(Kind.FILESYSTEM) + public TypedWrite withTempDirectory(ResourceId tempDirectory) { + return withTempDirectory(StaticValueProvider.of(tempDirectory)); + } + + /** + * Uses the given {@link ShardNameTemplate} for naming output files. This option may only be + * used when using one of the default filename-prefix to() overrides. + * + *

See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are + * used. + */ + public TypedWrite withShardNameTemplate(String shardTemplate) { + return toBuilder().setShardTemplate(shardTemplate).build(); + } + + /** + * Configures the filename suffix for written files. This option may only be used when using one + * of the default filename-prefix to() overrides. + * + *

See {@link DefaultFilenamePolicy} for how the prefix, shard name template, and suffix are + * used. + */ + public TypedWrite withSuffix(String filenameSuffix) { + return toBuilder().setFilenameSuffix(filenameSuffix).build(); + } + + /** + * Configures the number of output shards produced overall (when using unwindowed writes) or + * per-window (when using windowed writes). + * + *

For unwindowed writes, constraining the number of shards is likely to reduce the + * performance of a pipeline. Setting this value is not recommended unless you require a + * specific number of output files. + * + * @param numShards the number of shards to use, or 0 to let the system decide. + */ + public TypedWrite withNumShards(int numShards) { + checkArgument(numShards >= 0); + return toBuilder().setNumShards(numShards).build(); + } + + /** + * Forces a single file as output and empty shard name template. This option is only compatible + * with unwindowed writes. + * + *

For unwindowed writes, constraining the number of shards is likely to reduce the + * performance of a pipeline. Setting this value is not recommended unless you require a + * specific number of output files. + * + *

This is equivalent to {@code .withNumShards(1).withShardNameTemplate("")} + */ + public TypedWrite withoutSharding() { + return withNumShards(1).withShardNameTemplate(""); + } + + /** + * Preserves windowing of input elements and writes them to files based on the element's window. + * + *

If using {@link #to(FilenamePolicy)}. Filenames will be generated using {@link + * FilenamePolicy#windowedFilename}. See also {@link WriteFiles#withWindowedWrites()}. + */ + public TypedWrite withWindowedWrites() { + return toBuilder().setWindowedWrites(true).build(); + } + + /** See {@link WriteFiles#withNoSpilling()}. */ + public TypedWrite withNoSpilling() { + return toBuilder().setNoSpilling(true).build(); + } + + /** Writes to Avro file(s) compressed using specified codec. */ + public TypedWrite withCodec(CodecFactory codec) { + return toBuilder().setCodec(new SerializableAvroCodecFactory(codec)).build(); + } + + /** + * Specifies a {@link AvroSink.DatumWriterFactory} to use for creating {@link + * org.apache.avro.io.DatumWriter} instances. + */ + public TypedWrite withDatumWriterFactory( + AvroSink.DatumWriterFactory datumWriterFactory) { + return toBuilder().setDatumWriterFactory(datumWriterFactory).build(); + } + + /** + * Writes to Avro file(s) with the specified metadata. + * + *

Supported value types are String, Long, and byte[]. + */ + public TypedWrite withMetadata(Map metadata) { + Map badKeys = Maps.newLinkedHashMap(); + for (Map.Entry entry : metadata.entrySet()) { + Object v = entry.getValue(); + if (!(v instanceof String || v instanceof Long || v instanceof byte[])) { + badKeys.put(entry.getKey(), v.getClass().getSimpleName()); + } + } + checkArgument( + badKeys.isEmpty(), + "Metadata value type must be one of String, Long, or byte[]. Found {}", + badKeys); + return toBuilder().setMetadata(ImmutableMap.copyOf(metadata)).build(); + } + + DynamicAvroDestinations resolveDynamicDestinations() { + DynamicAvroDestinations dynamicDestinations = + getDynamicDestinations(); + if (dynamicDestinations == null) { + // In this case DestinationT is Void. + FilenamePolicy usedFilenamePolicy = getFilenamePolicy(); + if (usedFilenamePolicy == null) { + usedFilenamePolicy = + DefaultFilenamePolicy.fromStandardParameters( + getFilenamePrefix(), + getShardTemplate(), + getFilenameSuffix(), + getWindowedWrites()); + } + dynamicDestinations = + (DynamicAvroDestinations) + constantDestinations( + usedFilenamePolicy, + getSchema(), + getMetadata(), + getCodec().getCodec(), + getFormatFunction(), + getDatumWriterFactory()); + } + return dynamicDestinations; + } + + @Override + public WriteFilesResult expand(PCollection input) { + checkArgument( + getFilenamePrefix() != null || getTempDirectory() != null, + "Need to set either the filename prefix or the tempDirectory of a AvroIO.Write " + + "transform."); + if (getFilenamePolicy() != null) { + checkArgument( + getShardTemplate() == null && getFilenameSuffix() == null, + "shardTemplate and filenameSuffix should only be used with the default " + + "filename policy"); + } + if (getDynamicDestinations() != null) { + checkArgument( + getFormatFunction() == null, + "A format function should not be specified " + + "with DynamicDestinations. Use DynamicDestinations.formatRecord instead"); + } else { + checkArgument( + getSchema() != null, "Unless using DynamicDestinations, .withSchema() is required."); + } + + ValueProvider tempDirectory = getTempDirectory(); + if (tempDirectory == null) { + tempDirectory = getFilenamePrefix(); + } + WriteFiles write = + WriteFiles.to( + new AvroSink<>(tempDirectory, resolveDynamicDestinations(), getGenericRecords())); + if (getNumShards() > 0) { + write = write.withNumShards(getNumShards()); + } + if (getWindowedWrites()) { + write = write.withWindowedWrites(); + } + if (getNoSpilling()) { + write = write.withNoSpilling(); + } + return input.apply("Write", write); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + resolveDynamicDestinations().populateDisplayData(builder); + builder + .addIfNotDefault( + DisplayData.item("numShards", getNumShards()).withLabel("Maximum Output Shards"), 0) + .addIfNotNull( + DisplayData.item("tempDirectory", getTempDirectory()) + .withLabel("Directory for temporary files")); + } + } + + /** + * This class is used as the default return value of {@link AvroIO#write} + * + *

All methods in this class delegate to the appropriate method of {@link TypedWrite}. This + * class exists for backwards compatibility, and will be removed in Beam 3.0. + */ + public static class Write extends PTransform, PDone> { + @VisibleForTesting final TypedWrite inner; + + Write(TypedWrite inner) { + this.inner = inner; + } + + /** See {@link TypedWrite#to(String)}. */ + public Write to(String outputPrefix) { + return new Write<>( + inner + .to(FileBasedSink.convertToFileResourceIfPossible(outputPrefix)) + .withFormatFunction(SerializableFunctions.identity())); + } + + /** See {@link TypedWrite#to(ResourceId)} . */ + @Experimental(Kind.FILESYSTEM) + public Write to(ResourceId outputPrefix) { + return new Write<>( + inner.to(outputPrefix).withFormatFunction(SerializableFunctions.identity())); + } + + /** See {@link TypedWrite#to(ValueProvider)}. */ + public Write to(ValueProvider outputPrefix) { + return new Write<>( + inner.to(outputPrefix).withFormatFunction(SerializableFunctions.identity())); + } + + /** See {@link TypedWrite#to(ResourceId)}. */ + @Experimental(Kind.FILESYSTEM) + public Write toResource(ValueProvider outputPrefix) { + return new Write<>( + inner.toResource(outputPrefix).withFormatFunction(SerializableFunctions.identity())); + } + + /** See {@link TypedWrite#to(FilenamePolicy)}. */ + public Write to(FilenamePolicy filenamePolicy) { + return new Write<>( + inner.to(filenamePolicy).withFormatFunction(SerializableFunctions.identity())); + } + + /** + * See {@link TypedWrite#to(DynamicAvroDestinations)}. + * + * @deprecated Use {@link FileIO#write()} or {@link FileIO#writeDynamic()} instead. + */ + @Deprecated + public Write to(DynamicAvroDestinations dynamicDestinations) { + return new Write<>(inner.to(dynamicDestinations).withFormatFunction(null)); + } + + /** See {@link TypedWrite#withSchema}. */ + public Write withSchema(Schema schema) { + return new Write<>(inner.withSchema(schema)); + } + /** See {@link TypedWrite#withTempDirectory(ValueProvider)}. */ + @Experimental(Kind.FILESYSTEM) + public Write withTempDirectory(ValueProvider tempDirectory) { + return new Write<>(inner.withTempDirectory(tempDirectory)); + } + + /** See {@link TypedWrite#withTempDirectory(ResourceId)}. */ + public Write withTempDirectory(ResourceId tempDirectory) { + return new Write<>(inner.withTempDirectory(tempDirectory)); + } + + /** See {@link TypedWrite#withShardNameTemplate}. */ + public Write withShardNameTemplate(String shardTemplate) { + return new Write<>(inner.withShardNameTemplate(shardTemplate)); + } + + /** See {@link TypedWrite#withSuffix}. */ + public Write withSuffix(String filenameSuffix) { + return new Write<>(inner.withSuffix(filenameSuffix)); + } + + /** See {@link TypedWrite#withNumShards}. */ + public Write withNumShards(int numShards) { + return new Write<>(inner.withNumShards(numShards)); + } + + /** See {@link TypedWrite#withoutSharding}. */ + public Write withoutSharding() { + return new Write<>(inner.withoutSharding()); + } + + /** See {@link TypedWrite#withWindowedWrites}. */ + public Write withWindowedWrites() { + return new Write<>(inner.withWindowedWrites()); + } + + /** See {@link TypedWrite#withCodec}. */ + public Write withCodec(CodecFactory codec) { + return new Write<>(inner.withCodec(codec)); + } + + /** See {@link TypedWrite#withDatumWriterFactory}. 
*/ + public Write withDatumWriterFactory(AvroSink.DatumWriterFactory datumWriterFactory) { + return new Write<>(inner.withDatumWriterFactory(datumWriterFactory)); + } + + /** + * Specify that output filenames are wanted. + * + *

The nested {@link TypedWrite} transform always has access to output filenames; however, due + * to backwards-compatibility concerns, {@link Write} cannot return them. This method simply + * returns the inner {@link TypedWrite} transform which has {@link WriteFilesResult} as its + * output type, allowing access to output files. + * + *

The supplied {@code DestinationT} type must be: the same as that supplied in {@link + * #to(DynamicAvroDestinations)} if that method was used, or {@code Void} otherwise. + */ + public TypedWrite withOutputFilenames() { + return (TypedWrite) inner; + } + + /** See {@link TypedWrite#withMetadata} . */ + public Write withMetadata(Map metadata) { + return new Write<>(inner.withMetadata(metadata)); + } + + @Override + public PDone expand(PCollection input) { + input.apply(inner); + return PDone.in(input.getPipeline()); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + inner.populateDisplayData(builder); + } + } + + /** + * Returns a {@link DynamicAvroDestinations} that always returns the same {@link FilenamePolicy}, + * schema, metadata, and codec. + */ + public static DynamicAvroDestinations constantDestinations( + FilenamePolicy filenamePolicy, + Schema schema, + Map metadata, + CodecFactory codec, + SerializableFunction formatFunction) { + return constantDestinations(filenamePolicy, schema, metadata, codec, formatFunction, null); + } + + /** + * Returns a {@link DynamicAvroDestinations} that always returns the same {@link FilenamePolicy}, + * schema, metadata, and codec. + */ + public static DynamicAvroDestinations constantDestinations( + FilenamePolicy filenamePolicy, + Schema schema, + Map metadata, + CodecFactory codec, + SerializableFunction formatFunction, + AvroSink.@Nullable DatumWriterFactory datumWriterFactory) { + return new ConstantAvroDestination<>( + filenamePolicy, schema, metadata, codec, formatFunction, datumWriterFactory); + } + ///////////////////////////////////////////////////////////////////////////// + + /** + * Formats an element of a user type into a record with the given schema. + * + * @deprecated Users can achieve the same by providing this transform in a {@link + * org.apache.beam.sdk.transforms.ParDo} before using write in AvroIO {@link #write(Class)}. + */ + @Deprecated + public interface RecordFormatter extends Serializable { + GenericRecord formatRecord(ElementT element, Schema schema); + } + + /** + * A {@link Sink} for use with {@link FileIO#write} and {@link FileIO#writeDynamic}, writing + * elements of the given generated class, like {@link #write(Class)}. + */ + public static Sink sink(final Class clazz) { + return new AutoValue_AvroIO_Sink.Builder() + .setJsonSchema(ReflectData.get().getSchema(clazz).toString()) + .setMetadata(ImmutableMap.of()) + .setCodec(TypedWrite.DEFAULT_SERIALIZABLE_CODEC) + .build(); + } + + /** + * A {@link Sink} for use with {@link FileIO#write} and {@link FileIO#writeDynamic}, writing + * elements with a given (common) schema, like {@link #writeGenericRecords(Schema)}. + */ + @Experimental(Kind.SOURCE_SINK) + public static Sink sink(Schema schema) { + return sink(schema.toString()); + } + + /** + * A {@link Sink} for use with {@link FileIO#write} and {@link FileIO#writeDynamic}, writing + * elements with a given (common) schema, like {@link #writeGenericRecords(String)}. + */ + @Experimental(Kind.SOURCE_SINK) + public static Sink sink(String jsonSchema) { + return new AutoValue_AvroIO_Sink.Builder() + .setJsonSchema(jsonSchema) + .setMetadata(ImmutableMap.of()) + .setCodec(TypedWrite.DEFAULT_SERIALIZABLE_CODEC) + .build(); + } + + /** + * A {@link Sink} for use with {@link FileIO#write} and {@link FileIO#writeDynamic}, writing + * elements by converting each one to a {@link GenericRecord} with a given (common) schema, like + * {@link #writeCustomTypeToGenericRecords()}. 
+ * + * @deprecated RecordFormatter will be removed in future versions. + */ + @Deprecated + public static Sink sinkViaGenericRecords( + Schema schema, RecordFormatter formatter) { + return new AutoValue_AvroIO_Sink.Builder() + .setRecordFormatter(formatter) + .setJsonSchema(schema.toString()) + .setMetadata(ImmutableMap.of()) + .setCodec(TypedWrite.DEFAULT_SERIALIZABLE_CODEC) + .build(); + } + + /** Implementation of {@link #sink} and {@link #sinkViaGenericRecords}. */ + @AutoValue + public abstract static class Sink implements FileIO.Sink { + /** @deprecated RecordFormatter will be removed in future versions. */ + @Deprecated + abstract @Nullable RecordFormatter getRecordFormatter(); + + abstract @Nullable String getJsonSchema(); + + abstract Map getMetadata(); + + abstract SerializableAvroCodecFactory getCodec(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + /** @deprecated RecordFormatter will be removed in future versions. */ + @Deprecated + abstract Builder setRecordFormatter(RecordFormatter formatter); + + abstract Builder setJsonSchema(String jsonSchema); + + abstract Builder setMetadata(Map metadata); + + abstract Builder setCodec(SerializableAvroCodecFactory codec); + + abstract Sink build(); + } + + /** Specifies to put the given metadata into each generated file. By default, empty. */ + public Sink withMetadata(Map metadata) { + return toBuilder().setMetadata(metadata).build(); + } + + /** + * Specifies to use the given {@link CodecFactory} for each generated file. By default, {@code + * CodecFactory.snappyCodec()}. + */ + public Sink withCodec(CodecFactory codec) { + return toBuilder().setCodec(new SerializableAvroCodecFactory(codec)).build(); + } + + private transient @Nullable Schema schema; + private transient @Nullable DataFileWriter reflectWriter; + private transient @Nullable DataFileWriter genericWriter; + + @Override + public void open(WritableByteChannel channel) throws IOException { + this.schema = new Schema.Parser().parse(getJsonSchema()); + DataFileWriter writer; + if (getRecordFormatter() == null) { + writer = reflectWriter = new DataFileWriter<>(new ReflectDatumWriter<>(schema)); + } else { + writer = genericWriter = new DataFileWriter<>(new GenericDatumWriter<>(schema)); + } + writer.setCodec(getCodec().getCodec()); + for (Map.Entry entry : getMetadata().entrySet()) { + Object v = entry.getValue(); + if (v instanceof String) { + writer.setMeta(entry.getKey(), (String) v); + } else if (v instanceof Long) { + writer.setMeta(entry.getKey(), (Long) v); + } else if (v instanceof byte[]) { + writer.setMeta(entry.getKey(), (byte[]) v); + } else { + throw new IllegalStateException( + "Metadata value type must be one of String, Long, or byte[]. Found " + + v.getClass().getSimpleName()); + } + } + writer.create(schema, Channels.newOutputStream(channel)); + } + + @Override + public void write(ElementT element) throws IOException { + if (getRecordFormatter() == null) { + reflectWriter.append(element); + } else { + genericWriter.append(getRecordFormatter().formatRecord(element, schema)); + } + } + + @Override + public void flush() throws IOException { + MoreObjects.firstNonNull(reflectWriter, genericWriter).flush(); + } + } + + /** Disallow construction of utility class. 
*/ + private AvroIO() {} +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProvider.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProvider.java new file mode 100644 index 000000000000..08a9f3a2946b --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProvider.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.io; + +import com.google.auto.service.AutoService; +import java.io.Serializable; +import org.apache.avro.generic.GenericRecord; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.io.SchemaIO; +import org.apache.beam.sdk.schemas.io.SchemaIOProvider; +import org.apache.beam.sdk.schemas.transforms.Convert; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; + +/** + * An implementation of {@link SchemaIOProvider} for reading and writing Avro files with {@link + * AvroIO}. + */ +@Internal +@AutoService(SchemaIOProvider.class) +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class AvroSchemaIOProvider implements SchemaIOProvider { + /** Returns an id that uniquely represents this IO. */ + @Override + public String identifier() { + return "avro"; + } + + /** + * Returns the expected schema of the configuration object. Note this is distinct from the schema + * of the data source itself. No configuration expected for Avro. + */ + @Override + public Schema configurationSchema() { + return Schema.builder().addNullableField("writeWindowSizeSeconds", FieldType.INT64).build(); + } + + /** + * Produce a SchemaIO given a String representing the data's location, the schema of the data that + * resides there, and some IO-specific configuration object. 
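A minimal sketch of driving AvroSchemaIOProvider directly (the location pattern and the Beam Schema named dataSchema are illustrative assumptions; writeWindowSizeSeconds is the single optional configuration field defined above and may be left null):

    AvroSchemaIOProvider provider = new AvroSchemaIOProvider();
    Row config =
        Row.withSchema(provider.configurationSchema())
            .withFieldValue("writeWindowSizeSeconds", 60L)   // optional; null means default windowing
            .build();
    SchemaIO avro = provider.from("/tmp/data-*.avro", config, dataSchema);
    // avro.buildReader() yields a PTransform<PBegin, PCollection<Row>>;
    // avro.buildWriter() yields a PTransform<PCollection<Row>, POutput>.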
+ */ + @Override + public AvroSchemaIO from(String location, Row configuration, Schema dataSchema) { + return new AvroSchemaIO(location, dataSchema, configuration); + } + + @Override + public boolean requiresDataSchema() { + return true; + } + + @Override + public IsBounded isBounded() { + // This supports streaming now as well but there's no option for this. The move to + // SchemaTransform will remove the need to provide this. + return IsBounded.BOUNDED; + } + + /** An abstraction to create schema aware IOs. */ + private static class AvroSchemaIO implements SchemaIO, Serializable { + protected final Schema dataSchema; + protected final String location; + protected final @Nullable Duration windowSize; + + private AvroSchemaIO(String location, Schema dataSchema, Row configuration) { + this.dataSchema = dataSchema; + this.location = location; + if (configuration.getInt64("writeWindowSizeSeconds") != null) { + windowSize = Duration.standardSeconds(configuration.getInt64("writeWindowSizeSeconds")); + } else { + windowSize = null; + } + } + + @Override + public Schema schema() { + return dataSchema; + } + + @Override + public PTransform> buildReader() { + return new PTransform>() { + @Override + public PCollection expand(PBegin begin) { + return begin + .apply( + "AvroIORead", + AvroIO.readGenericRecords(AvroUtils.toAvroSchema(dataSchema, null, null)) + .withBeamSchemas(true) + .from(location)) + .apply("ToRows", Convert.toRows()); + } + }; + } + + @Override + public PTransform, POutput> buildWriter() { + return new PTransform, POutput>() { + @Override + public PDone expand(PCollection input) { + PCollection asRecords = + input.apply("ToGenericRecords", Convert.to(GenericRecord.class)); + AvroIO.Write avroWrite = + AvroIO.writeGenericRecords(AvroUtils.toAvroSchema(dataSchema, null, null)) + .to(location); + if (input.isBounded() == IsBounded.UNBOUNDED || windowSize != null) { + asRecords = + asRecords.apply( + Window.into( + FixedWindows.of( + windowSize == null ? Duration.standardMinutes(1) : windowSize))); + avroWrite = avroWrite.withWindowedWrites().withNumShards(1); + } else { + avroWrite = avroWrite.withoutSharding(); + } + return asRecords.apply("AvroIOWrite", avroWrite); + } + }; + } + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSink.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSink.java new file mode 100644 index 000000000000..7870bf786e92 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSink.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.io; + +import java.io.Serializable; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.Map; +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.reflect.ReflectDatumWriter; +import org.apache.beam.sdk.io.Compression; +import org.apache.beam.sdk.io.FileBasedSink; +import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.util.MimeTypes; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** A {@link FileBasedSink} for Avro files. */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class AvroSink + extends FileBasedSink { + private final boolean genericRecords; + + @FunctionalInterface + public interface DatumWriterFactory extends Serializable { + DatumWriter apply(Schema writer); + } + + AvroSink( + ValueProvider outputPrefix, + DynamicAvroDestinations dynamicDestinations, + boolean genericRecords) { + // Avro handles compression internally using the codec. + super(outputPrefix, dynamicDestinations, Compression.UNCOMPRESSED); + this.genericRecords = genericRecords; + } + + @Override + public DynamicAvroDestinations getDynamicDestinations() { + return (DynamicAvroDestinations) super.getDynamicDestinations(); + } + + @Override + public WriteOperation createWriteOperation() { + return new AvroWriteOperation<>(this, genericRecords); + } + + /** A {@link WriteOperation WriteOperation} for Avro files. */ + private static class AvroWriteOperation + extends WriteOperation { + private final DynamicAvroDestinations dynamicDestinations; + private final boolean genericRecords; + + private AvroWriteOperation(AvroSink sink, boolean genericRecords) { + super(sink); + this.dynamicDestinations = sink.getDynamicDestinations(); + this.genericRecords = genericRecords; + } + + @Override + public Writer createWriter() throws Exception { + return new AvroWriter<>(this, dynamicDestinations, genericRecords); + } + } + + /** A {@link Writer Writer} for Avro files. */ + private static class AvroWriter extends Writer { + + // Initialized in prepareWrite + private @Nullable DataFileWriter dataFileWriter; + + private final DynamicAvroDestinations dynamicDestinations; + private final boolean genericRecords; + + public AvroWriter( + WriteOperation writeOperation, + DynamicAvroDestinations dynamicDestinations, + boolean genericRecords) { + super(writeOperation, MimeTypes.BINARY); + this.dynamicDestinations = dynamicDestinations; + this.genericRecords = genericRecords; + } + + @SuppressWarnings("deprecation") // uses internal test functionality. + @Override + protected void prepareWrite(WritableByteChannel channel) throws Exception { + DestinationT destination = getDestination(); + CodecFactory codec = dynamicDestinations.getCodec(destination); + Schema schema = dynamicDestinations.getSchema(destination); + Map metadata = dynamicDestinations.getMetadata(destination); + DatumWriter datumWriter; + DatumWriterFactory datumWriterFactory = + dynamicDestinations.getDatumWriterFactory(destination); + + if (datumWriterFactory == null) { + datumWriter = + genericRecords ? 
new GenericDatumWriter<>(schema) : new ReflectDatumWriter<>(schema); + } else { + datumWriter = datumWriterFactory.apply(schema); + } + + dataFileWriter = new DataFileWriter<>(datumWriter).setCodec(codec); + for (Map.Entry entry : metadata.entrySet()) { + Object v = entry.getValue(); + if (v instanceof String) { + dataFileWriter.setMeta(entry.getKey(), (String) v); + } else if (v instanceof Long) { + dataFileWriter.setMeta(entry.getKey(), (Long) v); + } else if (v instanceof byte[]) { + dataFileWriter.setMeta(entry.getKey(), (byte[]) v); + } else { + throw new IllegalStateException( + "Metadata value type must be one of String, Long, or byte[]. Found " + + v.getClass().getSimpleName()); + } + } + dataFileWriter.create(schema, Channels.newOutputStream(channel)); + } + + @Override + public void write(OutputT value) throws Exception { + dataFileWriter.append(value); + } + + @Override + protected void finishWrite() throws Exception { + dataFileWriter.flush(); + } + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java new file mode 100644 index 000000000000..aaa05bdc1739 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java @@ -0,0 +1,777 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.io; + +import static org.apache.beam.sdk.io.FileBasedSource.Mode.SINGLE_FILE_OR_SUBRANGE; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InvalidObjectException; +import java.io.ObjectInputStream; +import java.io.ObjectStreamException; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.WeakHashMap; +import javax.annotation.concurrent.GuardedBy; +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileConstants; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.SeekableInput; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.beam.sdk.PipelineRunner; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.io.BlockBasedSource; +import org.apache.beam.sdk.io.FileBasedSource; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.OffsetBasedSource; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.io.fs.EmptyMatchTreatment; +import org.apache.beam.sdk.io.fs.MatchResult.Metadata; +import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.VarInt; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.checkerframework.checker.nullness.qual.Nullable; + +// CHECKSTYLE.OFF: JavadocStyle +/** + * Do not use in pipelines directly: most users should use {@link AvroIO.Read}. + * + *

A {@link FileBasedSource} for reading Avro files. + * + *

To read a {@link PCollection} of objects from one or more Avro files, use {@link + * AvroSource#from} to specify the path(s) of the files to read. The {@link AvroSource} that is + * returned will read objects of type {@link GenericRecord} with the schema(s) that were written at + * file creation. To further configure the {@link AvroSource} to read with a user-defined schema, or + * to return records of a type other than {@link GenericRecord}, use {@link + * AvroSource#withSchema(Schema)} (using an Avro {@link Schema}), {@link + * AvroSource#withSchema(String)} (using a JSON schema), or {@link AvroSource#withSchema(Class)} (to + * return objects of the Avro-generated class specified). + * + *

An {@link AvroSource} can be read from using the {@link Read} transform. For example: + * + *

<pre>{@code
+ * AvroSource<MyType> source = AvroSource.from(file.toPath().toString()).withSchema(MyType.class);
+ * PCollection<MyType> records = Read.from(source);
+ * }</pre>
+ * + *

This class's implementation is based on the Avro 1.7.7 specification and implements + * parsing of some parts of Avro Object Container Files. The rationale for doing so is that the Avro + * API does not provide efficient ways of computing the precise offsets of blocks within a file, + * which is necessary to support dynamic work rebalancing. However, whenever it is possible to use + * the Avro API in a way that supports maintaining precise offsets, this class uses the Avro API. + * + *

Avro Object Container files store records in blocks. Each block contains a collection of + * records. Blocks may be encoded (e.g., with bzip2, deflate, snappy, etc.). Blocks are delineated + * from one another by a 16-byte sync marker. + * + *

An {@link AvroSource} for a subrange of a single file contains records in the blocks such that + * the start offset of the block is greater than or equal to the start offset of the source and less + * than the end offset of the source. + * + *

To use XZ-encoded Avro files, please include an explicit dependency on {@code xz-1.8.jar}, + * which has been marked as optional in the Maven {@code sdk/pom.xml}. + * + *

<pre>{@code
+ * <dependency>
+ *   <groupId>org.tukaani</groupId>
+ *   <artifactId>xz</artifactId>
+ *   <version>1.8</version>
+ * </dependency>
+ * }</pre>
+ * + *

<h3>Permissions</h3>

+ * + *

Permission requirements depend on the {@link PipelineRunner} that is used to execute the + * pipeline. Please refer to the documentation of corresponding {@link PipelineRunner}s for more + * details. + * + * @param The type of records to be read from the source. + */ +// CHECKSTYLE.ON: JavadocStyle +@Experimental(Kind.SOURCE_SINK) +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public class AvroSource extends BlockBasedSource { + // Default minimum bundle size (chosen as two default-size Avro blocks to attempt to + // ensure that every source has at least one block of records). + // The default sync interval is 64k. + private static final long DEFAULT_MIN_BUNDLE_SIZE = 2L * DataFileConstants.DEFAULT_SYNC_INTERVAL; + + @FunctionalInterface + public interface DatumReaderFactory extends Serializable { + DatumReader apply(Schema writer, Schema reader); + } + + private static final DatumReaderFactory GENERIC_DATUM_READER_FACTORY = GenericDatumReader::new; + + private static final DatumReaderFactory REFLECT_DATUM_READER_FACTORY = ReflectDatumReader::new; + + // Use cases of AvroSource are: + // 1) AvroSource Reading GenericRecord records with a specified schema. + // 2) AvroSource Reading records of a generated Avro class Foo. + // 3) AvroSource Reading GenericRecord records with an unspecified schema + // and converting them to type T. + // | Case 1 | Case 2 | Case 3 | + // type | GenericRecord | Foo | GenericRecord | + // readerSchemaString | non-null | non-null | null | + // parseFn | null | null | non-null | + // outputCoder | null | null | non-null | + // readerFactory | either | either | either | + private static class Mode implements Serializable { + private final Class type; + + // The JSON schema used to decode records. + private @Nullable String readerSchemaString; + + private final @Nullable SerializableFunction parseFn; + + private final @Nullable Coder outputCoder; + + private final @Nullable DatumReaderFactory readerFactory; + + private Mode( + Class type, + @Nullable String readerSchemaString, + @Nullable SerializableFunction parseFn, + @Nullable Coder outputCoder, + @Nullable DatumReaderFactory readerFactory) { + this.type = type; + this.readerSchemaString = internSchemaString(readerSchemaString); + this.parseFn = parseFn; + this.outputCoder = outputCoder; + this.readerFactory = readerFactory; + } + + private void readObject(ObjectInputStream is) throws IOException, ClassNotFoundException { + is.defaultReadObject(); + readerSchemaString = internSchemaString(readerSchemaString); + } + + private Coder getOutputCoder() { + if (parseFn == null) { + return AvroCoder.of((Class) type, internOrParseSchemaString(readerSchemaString)); + } else { + return outputCoder; + } + } + + private void validate() { + if (parseFn == null) { + checkArgument( + readerSchemaString != null, + "schema must be specified using withSchema() when not using a parse fn"); + } + } + + private Mode withReaderFactory(DatumReaderFactory factory) { + return new Mode<>(type, readerSchemaString, parseFn, outputCoder, factory); + } + + private DatumReader createReader(Schema writerSchema, Schema readerSchema) { + DatumReaderFactory factory = this.readerFactory; + if (factory == null) { + factory = + (type == GenericRecord.class) + ? 
GENERIC_DATUM_READER_FACTORY + : REFLECT_DATUM_READER_FACTORY; + } + return factory.apply(writerSchema, readerSchema); + } + } + + private static Mode readGenericRecordsWithSchema( + String schema, @Nullable DatumReaderFactory factory) { + return new Mode<>(GenericRecord.class, schema, null, null, factory); + } + + private static Mode readGeneratedClasses( + Class clazz, @Nullable DatumReaderFactory factory) { + return new Mode<>(clazz, ReflectData.get().getSchema(clazz).toString(), null, null, factory); + } + + private static Mode parseGenericRecords( + SerializableFunction parseFn, + Coder outputCoder, + @Nullable DatumReaderFactory factory) { + return new Mode<>(GenericRecord.class, null, parseFn, outputCoder, factory); + } + + private final Mode mode; + + /** + * Reads from the given file name or pattern ("glob"). The returned source needs to be further + * configured by calling {@link #withSchema} to return a type other than {@link GenericRecord}. + */ + public static AvroSource from(ValueProvider fileNameOrPattern) { + return new AvroSource<>( + fileNameOrPattern, + EmptyMatchTreatment.DISALLOW, + DEFAULT_MIN_BUNDLE_SIZE, + readGenericRecordsWithSchema(null /* will need to be specified in withSchema */, null)); + } + + public static AvroSource from(Metadata metadata) { + return new AvroSource<>( + metadata, + DEFAULT_MIN_BUNDLE_SIZE, + 0, + metadata.sizeBytes(), + readGenericRecordsWithSchema(null /* will need to be specified in withSchema */, null)); + } + + /** Like {@link #from(ValueProvider)}. */ + public static AvroSource from(String fileNameOrPattern) { + return from(ValueProvider.StaticValueProvider.of(fileNameOrPattern)); + } + + public AvroSource withEmptyMatchTreatment(EmptyMatchTreatment emptyMatchTreatment) { + return new AvroSource<>( + getFileOrPatternSpecProvider(), emptyMatchTreatment, getMinBundleSize(), mode); + } + + /** Reads files containing records that conform to the given schema. */ + public AvroSource withSchema(String schema) { + checkArgument(schema != null, "schema can not be null"); + return new AvroSource<>( + getFileOrPatternSpecProvider(), + getEmptyMatchTreatment(), + getMinBundleSize(), + readGenericRecordsWithSchema(schema, mode.readerFactory)); + } + + /** Like {@link #withSchema(String)}. */ + public AvroSource withSchema(Schema schema) { + checkArgument(schema != null, "schema can not be null"); + return withSchema(schema.toString()); + } + + /** Reads files containing records of the given class. */ + public AvroSource withSchema(Class clazz) { + checkArgument(clazz != null, "clazz can not be null"); + if (getMode() == SINGLE_FILE_OR_SUBRANGE) { + return new AvroSource<>( + getSingleFileMetadata(), + getMinBundleSize(), + getStartOffset(), + getEndOffset(), + readGeneratedClasses(clazz, mode.readerFactory)); + } + return new AvroSource<>( + getFileOrPatternSpecProvider(), + getEmptyMatchTreatment(), + getMinBundleSize(), + readGeneratedClasses(clazz, mode.readerFactory)); + } + + /** + * Reads {@link GenericRecord} of unspecified schema and maps them to instances of a custom type + * using the given {@code parseFn} and encoded using the given coder. 
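A rough sketch of the parse-function mode described above (the file pattern, the MyEvent type and its constructor are illustrative assumptions; SerializableCoder is one possible coder choice):

    AvroSource<MyEvent> source =
        AvroSource.from("/tmp/events-*.avro")   // illustrative file pattern
            .withParseFn(
                (GenericRecord r) -> new MyEvent(r.get("name").toString()),
                SerializableCoder.of(MyEvent.class));   // requires MyEvent to implement Serializable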
+ */ + public AvroSource withParseFn( + SerializableFunction parseFn, Coder coder) { + checkArgument(parseFn != null, "parseFn can not be null"); + checkArgument(coder != null, "coder can not be null"); + if (getMode() == SINGLE_FILE_OR_SUBRANGE) { + return new AvroSource<>( + getSingleFileMetadata(), + getMinBundleSize(), + getStartOffset(), + getEndOffset(), + parseGenericRecords(parseFn, coder, mode.readerFactory)); + } + return new AvroSource<>( + getFileOrPatternSpecProvider(), + getEmptyMatchTreatment(), + getMinBundleSize(), + parseGenericRecords(parseFn, coder, mode.readerFactory)); + } + + /** + * Sets the minimum bundle size. Refer to {@link OffsetBasedSource} for a description of {@code + * minBundleSize} and its use. + */ + public AvroSource withMinBundleSize(long minBundleSize) { + if (getMode() == SINGLE_FILE_OR_SUBRANGE) { + return new AvroSource<>( + getSingleFileMetadata(), minBundleSize, getStartOffset(), getEndOffset(), mode); + } + return new AvroSource<>( + getFileOrPatternSpecProvider(), getEmptyMatchTreatment(), minBundleSize, mode); + } + + public AvroSource withDatumReaderFactory(DatumReaderFactory factory) { + Mode newMode = mode.withReaderFactory(factory); + if (getMode() == SINGLE_FILE_OR_SUBRANGE) { + return new AvroSource<>( + getSingleFileMetadata(), getMinBundleSize(), getStartOffset(), getEndOffset(), newMode); + } + return new AvroSource<>( + getFileOrPatternSpecProvider(), getEmptyMatchTreatment(), getMinBundleSize(), newMode); + } + + /** Constructor for FILEPATTERN mode. */ + private AvroSource( + ValueProvider fileNameOrPattern, + EmptyMatchTreatment emptyMatchTreatment, + long minBundleSize, + Mode mode) { + super(fileNameOrPattern, emptyMatchTreatment, minBundleSize); + this.mode = mode; + } + + /** Constructor for SINGLE_FILE_OR_SUBRANGE mode. */ + private AvroSource( + Metadata metadata, long minBundleSize, long startOffset, long endOffset, Mode mode) { + super(metadata, minBundleSize, startOffset, endOffset); + this.mode = mode; + } + + @Override + public void validate() { + super.validate(); + mode.validate(); + } + + /** + * Used by the Dataflow worker. Do not introduce new usages. Do not delete without confirming that + * Dataflow ValidatesRunner tests pass. + * + * @deprecated Used by Dataflow worker + */ + @Deprecated + public BlockBasedSource createForSubrangeOfFile(String fileName, long start, long end) + throws IOException { + return createForSubrangeOfFile(FileSystems.matchSingleFileSpec(fileName), start, end); + } + + @Override + public BlockBasedSource createForSubrangeOfFile(Metadata fileMetadata, long start, long end) { + return new AvroSource<>(fileMetadata, getMinBundleSize(), start, end, mode); + } + + @Override + protected BlockBasedReader createSingleFileReader(PipelineOptions options) { + return new AvroReader<>(this); + } + + @Override + public Coder getOutputCoder() { + return mode.getOutputCoder(); + } + + @VisibleForTesting + @Nullable + String getReaderSchemaString() { + return mode.readerSchemaString; + } + + /** Avro file metadata. */ + @VisibleForTesting + static class AvroMetadata { + private final byte[] syncMarker; + private final String codec; + private final String schemaString; + + AvroMetadata(byte[] syncMarker, String codec, String schemaString) { + this.syncMarker = checkNotNull(syncMarker, "syncMarker"); + this.codec = checkNotNull(codec, "codec"); + this.schemaString = internSchemaString(checkNotNull(schemaString, "schemaString")); + } + + /** + * The JSON-encoded schema + * string for the file. 
+ */ + public String getSchemaString() { + return schemaString; + } + + /** + * The codec of the + * file. + */ + public String getCodec() { + return codec; + } + + /** + * The 16-byte sync marker for the file. See the documentation for Object Container + * File for more information. + */ + public byte[] getSyncMarker() { + return syncMarker; + } + } + + /** + * Reads the {@link AvroMetadata} from the header of an Avro file. + * + *

This method parses the header of an Avro Object Container + * File. + * + * @throws IOException if the file is an invalid format. + */ + @VisibleForTesting + static AvroMetadata readMetadataFromFile(ResourceId fileResource) throws IOException { + String codec = null; + String schemaString = null; + byte[] syncMarker; + try (InputStream stream = Channels.newInputStream(FileSystems.open(fileResource))) { + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(stream, null); + + // The header of an object container file begins with a four-byte magic number, followed + // by the file metadata (including the schema and codec), encoded as a map. Finally, the + // header ends with the file's 16-byte sync marker. + // See https://avro.apache.org/docs/1.7.7/spec.html#Object+Container+Files for details on + // the encoding of container files. + + // Read the magic number. + byte[] magic = new byte[DataFileConstants.MAGIC.length]; + decoder.readFixed(magic); + if (!Arrays.equals(magic, DataFileConstants.MAGIC)) { + throw new IOException("Missing Avro file signature: " + fileResource); + } + + // Read the metadata to find the codec and schema. + ByteBuffer valueBuffer = ByteBuffer.allocate(512); + long numRecords = decoder.readMapStart(); + while (numRecords > 0) { + for (long recordIndex = 0; recordIndex < numRecords; recordIndex++) { + String key = decoder.readString(); + // readBytes() clears the buffer and returns a buffer where: + // - position is the start of the bytes read + // - limit is the end of the bytes read + valueBuffer = decoder.readBytes(valueBuffer); + byte[] bytes = new byte[valueBuffer.remaining()]; + valueBuffer.get(bytes); + if (key.equals(DataFileConstants.CODEC)) { + codec = new String(bytes, StandardCharsets.UTF_8); + } else if (key.equals(DataFileConstants.SCHEMA)) { + schemaString = new String(bytes, StandardCharsets.UTF_8); + } + } + numRecords = decoder.mapNext(); + } + if (codec == null) { + codec = DataFileConstants.NULL_CODEC; + } + + // Finally, read the sync marker. + syncMarker = new byte[DataFileConstants.SYNC_SIZE]; + decoder.readFixed(syncMarker); + } + checkState(schemaString != null, "No schema present in Avro file metadata %s", fileResource); + return new AvroMetadata(syncMarker, codec, schemaString); + } + + // A logical reference cache used to store schemas and schema strings to allow us to + // "intern" values and reduce the number of copies of equivalent objects. + private static final Map schemaLogicalReferenceCache = new WeakHashMap<>(); + private static final Map schemaStringLogicalReferenceCache = new WeakHashMap<>(); + + // We avoid String.intern() because depending on the JVM, these may be added to the PermGenSpace + // which we want to avoid otherwise we could run out of PermGenSpace. 
+ private static synchronized String internSchemaString(String schema) { + String internSchema = schemaStringLogicalReferenceCache.get(schema); + if (internSchema != null) { + return internSchema; + } + schemaStringLogicalReferenceCache.put(schema, schema); + return schema; + } + + static synchronized Schema internOrParseSchemaString(String schemaString) { + Schema schema = schemaLogicalReferenceCache.get(schemaString); + if (schema != null) { + return schema; + } + Schema.Parser parser = new Schema.Parser(); + schema = parser.parse(schemaString); + schemaLogicalReferenceCache.put(schemaString, schema); + return schema; + } + + // Reading the object from Java serialization typically does not go through the constructor, + // we use readResolve to replace the constructed instance with one which uses the constructor + // allowing us to intern any schemas. + @SuppressWarnings("unused") + private Object readResolve() throws ObjectStreamException { + switch (getMode()) { + case SINGLE_FILE_OR_SUBRANGE: + return new AvroSource<>( + getSingleFileMetadata(), getMinBundleSize(), getStartOffset(), getEndOffset(), mode); + case FILEPATTERN: + return new AvroSource<>( + getFileOrPatternSpecProvider(), getEmptyMatchTreatment(), getMinBundleSize(), mode); + default: + throw new InvalidObjectException( + String.format("Unknown mode %s for AvroSource %s", getMode(), this)); + } + } + + /** + * A {@link Block} of Avro records. + * + * @param The type of records stored in the block. + */ + @Experimental(Kind.SOURCE_SINK) + static class AvroBlock extends Block { + + // The current record in the block. Initialized in readNextRecord. + private @Nullable T currentRecord; + + // The index of the current record in the block. + private long currentRecordIndex = 0; + + private final Iterator iterator; + + private final SerializableFunction parseFn; + + private final long numRecordsInBlock; + + AvroBlock( + Iterator iter, SerializableFunction parseFn, long numRecordsInBlock) { + this.iterator = iter; + this.parseFn = parseFn; + this.numRecordsInBlock = numRecordsInBlock; + } + + @Override + public T getCurrentRecord() { + return currentRecord; + } + + @Override + public boolean readNextRecord() { + if (currentRecordIndex >= numRecordsInBlock) { + return false; + } + + Object record = iterator.next(); + currentRecord = (parseFn == null) ? ((T) record) : parseFn.apply((GenericRecord) record); + currentRecordIndex++; + return true; + } + + @Override + public double getFractionOfBlockConsumed() { + return ((double) currentRecordIndex) / numRecordsInBlock; + } + } + + /** + * A {@link BlockBasedReader} for reading blocks from Avro files. + * + *

An Avro Object Container File consists of a header followed by a 16-bit sync marker and then + * a sequence of blocks, where each block begins with two encoded longs representing the total + * number of records in the block and the block's size in bytes, followed by the block's + * (optionally-encoded) records. Each block is terminated by a 16-bit sync marker. + * + * @param The type of records contained in the block. + */ + @Experimental(Kind.SOURCE_SINK) + public static class AvroReader extends BlockBasedReader { + + private static class SeekableChannelInput implements SeekableInput { + + private final SeekableByteChannel channel; + private final InputStream input; + + SeekableChannelInput(SeekableByteChannel channel) { + this.channel = channel; + this.input = Channels.newInputStream(channel); + } + + @Override + public void seek(long p) throws IOException { + channel.position(p); + } + + @Override + public long tell() throws IOException { + return channel.position(); + } + + @Override + public long length() throws IOException { + return channel.size(); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return input.read(b, off, len); + } + + @Override + public void close() throws IOException { + channel.close(); + } + } + + // The current block. + // Initialized in readNextRecord. + private @Nullable AvroBlock currentBlock; + + private @Nullable DataFileReader dataFileReader; + + // A lock used to synchronize block offsets for getRemainingParallelism + private final Object progressLock = new Object(); + + // Offset of the current block. + @GuardedBy("progressLock") + private long currentBlockOffset = 0; + + // Size of the current block. + @GuardedBy("progressLock") + private long currentBlockSizeBytes = 0; + + /** Reads Avro records of type {@code T} from the specified source. */ + public AvroReader(AvroSource source) { + super(source); + } + + @Override + public synchronized AvroSource getCurrentSource() { + return (AvroSource) super.getCurrentSource(); + } + + // Precondition: the stream is positioned after the sync marker in the current (about to be + // previous) block. currentBlockSize equals the size of the current block, or zero if this + // reader was just started. + // + // Postcondition: same as above, but for the new current (formerly next) block. + @Override + public boolean readNextBlock() { + if (!dataFileReader.hasNext()) { + return false; + } + + long headerLength = + (long) VarInt.getLength(dataFileReader.getBlockCount()) + + VarInt.getLength(dataFileReader.getBlockSize()) + + DataFileConstants.SYNC_SIZE; + + currentBlock = + new AvroBlock<>( + dataFileReader, getCurrentSource().mode.parseFn, dataFileReader.getBlockCount()); + + // Atomically update both the position and offset of the new block. + synchronized (progressLock) { + currentBlockOffset = dataFileReader.previousSync(); + // Total block size includes the header, block content, and trailing sync marker. 
+ currentBlockSizeBytes = dataFileReader.getBlockSize() + headerLength; + } + + return true; + } + + @Override + public AvroBlock getCurrentBlock() { + return currentBlock; + } + + @Override + public long getCurrentBlockOffset() { + synchronized (progressLock) { + return currentBlockOffset; + } + } + + @Override + public long getCurrentBlockSize() { + synchronized (progressLock) { + return currentBlockSizeBytes; + } + } + + @Override + public long getSplitPointsRemaining() { + if (isDone()) { + return 0; + } + synchronized (progressLock) { + if (currentBlockOffset + currentBlockSizeBytes >= getCurrentSource().getEndOffset()) { + // This block is known to be the last block in the range. + return 1; + } + } + return super.getSplitPointsRemaining(); + } + + // Postcondition: the stream is positioned at the beginning of the first block after the start + // of the current source, and currentBlockOffset is that position. Additionally, + // currentBlockSizeBytes will be set to 0 indicating that the previous block was empty. + @Override + protected void startReading(ReadableByteChannel channel) throws IOException { + SeekableChannelInput seekableChannelInput = + new SeekableChannelInput((SeekableByteChannel) channel); + // the channel needs to be at the beginning of the file in order for the DataFileReader to + // read the header, etc, we'll seek it back to where it should be after creating the DFR. + seekableChannelInput.seek(0); + + Schema readerSchema = null; + String readerSchemaString = this.getCurrentSource().getReaderSchemaString(); + if (readerSchemaString != null) { + readerSchema = AvroSource.internOrParseSchemaString(readerSchemaString); + } + // the DataFileReader will call setSchema with the writer schema when created. + DatumReader reader = this.getCurrentSource().mode.createReader(readerSchema, readerSchema); + + dataFileReader = new DataFileReader<>(seekableChannelInput, reader); + + long startOffset = getCurrentSource().getStartOffset(); + if (startOffset != 0) { + // the start offset may be in the middle of a sync marker, by rewinding SYNC_SIZE bytes we + // ensure that we won't miss the block if so. + dataFileReader.sync(Math.max(0, startOffset - DataFileConstants.SYNC_SIZE)); + } + + synchronized (progressLock) { + currentBlockOffset = dataFileReader.previousSync(); + currentBlockSizeBytes = 0; + } + } + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/ConstantAvroDestination.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/ConstantAvroDestination.java new file mode 100644 index 000000000000..601c65935bec --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/ConstantAvroDestination.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.io; + +import java.io.Serializable; +import java.util.Map; +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.display.HasDisplayData; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.BaseEncoding; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Always returns a constant {@link FilenamePolicy}, {@link Schema}, metadata, and codec. */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +class ConstantAvroDestination + extends DynamicAvroDestinations { + private static class SchemaFunction implements Serializable, Function { + @Override + public Schema apply(String input) { + return new Schema.Parser().parse(input); + } + } + + // This should be a multiple of 4 to not get a partial encoded byte. + private static final int METADATA_BYTES_MAX_LENGTH = 40; + private final FilenamePolicy filenamePolicy; + private final Supplier schema; + private final Map metadata; + private final SerializableAvroCodecFactory codec; + private final SerializableFunction formatFunction; + private final AvroSink.DatumWriterFactory datumWriterFactory; + + private class Metadata implements HasDisplayData { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + for (Map.Entry entry : metadata.entrySet()) { + DisplayData.Type type = DisplayData.inferType(entry.getValue()); + if (type != null) { + builder.add(DisplayData.item(entry.getKey(), type, entry.getValue())); + } else { + String base64 = BaseEncoding.base64().encode((byte[]) entry.getValue()); + String repr = + base64.length() <= METADATA_BYTES_MAX_LENGTH + ? 
base64 + : base64.substring(0, METADATA_BYTES_MAX_LENGTH) + "..."; + builder.add(DisplayData.item(entry.getKey(), repr)); + } + } + } + } + + public ConstantAvroDestination( + FilenamePolicy filenamePolicy, + Schema schema, + Map metadata, + CodecFactory codec, + SerializableFunction formatFunction) { + this(filenamePolicy, schema, metadata, codec, formatFunction, null); + } + + public ConstantAvroDestination( + FilenamePolicy filenamePolicy, + Schema schema, + Map metadata, + CodecFactory codec, + SerializableFunction formatFunction, + AvroSink.@Nullable DatumWriterFactory datumWriterFactory) { + this.filenamePolicy = filenamePolicy; + this.schema = Suppliers.compose(new SchemaFunction(), Suppliers.ofInstance(schema.toString())); + this.metadata = metadata; + this.codec = new SerializableAvroCodecFactory(codec); + this.formatFunction = formatFunction; + this.datumWriterFactory = datumWriterFactory; + } + + @Override + public OutputT formatRecord(UserT record) { + return formatFunction.apply(record); + } + + @Override + public @Nullable Void getDestination(UserT element) { + return (Void) null; + } + + @Override + public @Nullable Void getDefaultDestination() { + return (Void) null; + } + + @Override + public FilenamePolicy getFilenamePolicy(Void destination) { + return filenamePolicy; + } + + @Override + public Schema getSchema(Void destination) { + return schema.get(); + } + + @Override + public Map getMetadata(Void destination) { + return metadata; + } + + @Override + public CodecFactory getCodec(Void destination) { + return codec.getCodec(); + } + + @Override + public AvroSink.@Nullable DatumWriterFactory getDatumWriterFactory(Void destination) { + return datumWriterFactory; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + filenamePolicy.populateDisplayData(builder); + builder.add(DisplayData.item("schema", schema.get().toString()).withLabel("Record Schema")); + builder.addIfNotDefault( + DisplayData.item("codec", codec.getCodec().toString()).withLabel("Avro Compression Codec"), + AvroIO.TypedWrite.DEFAULT_SERIALIZABLE_CODEC.toString()); + builder.include("Metadata", new Metadata()); + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/DynamicAvroDestinations.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/DynamicAvroDestinations.java new file mode 100644 index 000000000000..c261ee47d4e7 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/DynamicAvroDestinations.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.io; + +import java.util.Map; +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.beam.sdk.io.FileBasedSink.DynamicDestinations; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A specialization of {@link DynamicDestinations} for {@link org.apache.beam.sdk.io.AvroIO}. In + * addition to dynamic file destinations, this allows specifying other AVRO properties (schema, + * metadata, codec, datum writer) per destination. + */ +public abstract class DynamicAvroDestinations + extends DynamicDestinations { + /** Return an AVRO schema for a given destination. */ + public abstract Schema getSchema(DestinationT destination); + + /** Return AVRO file metadata for a given destination. */ + public Map getMetadata(DestinationT destination) { + return ImmutableMap.of(); + } + + /** Return an AVRO codec for a given destination. */ + public CodecFactory getCodec(DestinationT destination) { + return AvroIO.TypedWrite.DEFAULT_CODEC; + } + + /** + * Return a {@link org.apache.beam.sdk.io.AvroSink.DatumWriterFactory} for a given destination. If + * provided, it will be used to created {@link org.apache.avro.io.DatumWriter} instances as + * required. + */ + public AvroSink.@Nullable DatumWriterFactory getDatumWriterFactory( + DestinationT destinationT) { + return null; + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactory.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactory.java new file mode 100644 index 000000000000..8a82ffcbcd42 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactory.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.io; + +import static org.apache.avro.file.DataFileConstants.BZIP2_CODEC; +import static org.apache.avro.file.DataFileConstants.DEFLATE_CODEC; +import static org.apache.avro.file.DataFileConstants.NULL_CODEC; +import static org.apache.avro.file.DataFileConstants.SNAPPY_CODEC; +import static org.apache.avro.file.DataFileConstants.XZ_CODEC; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.avro.file.CodecFactory; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A wrapper that allows {@link CodecFactory}s to be serialized using Java's standard serialization + * mechanisms. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +class SerializableAvroCodecFactory implements Externalizable { + private static final long serialVersionUID = 7445324844109564303L; + private static final List noOptAvroCodecs = + Arrays.asList(NULL_CODEC, SNAPPY_CODEC, BZIP2_CODEC); + private static final Pattern deflatePattern = Pattern.compile(DEFLATE_CODEC + "-(?-?\\d)"); + private static final Pattern xzPattern = Pattern.compile(XZ_CODEC + "-(?\\d)"); + + private @Nullable CodecFactory codecFactory; + + // For java.io.Externalizable + public SerializableAvroCodecFactory() {} + + public SerializableAvroCodecFactory(CodecFactory codecFactory) { + checkNotNull(codecFactory, "Codec can't be null"); + checkState(checkIsSupportedCodec(codecFactory), "%s is not supported", codecFactory); + this.codecFactory = codecFactory; + } + + private boolean checkIsSupportedCodec(CodecFactory codecFactory) { + final String codecStr = codecFactory.toString(); + return noOptAvroCodecs.contains(codecStr) + || deflatePattern.matcher(codecStr).matches() + || xzPattern.matcher(codecStr).matches(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeUTF(codecFactory.toString()); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + final String codecStr = in.readUTF(); + + switch (codecStr) { + case NULL_CODEC: + case SNAPPY_CODEC: + case BZIP2_CODEC: + codecFactory = CodecFactory.fromString(codecStr); + return; + } + + Matcher deflateMatcher = deflatePattern.matcher(codecStr); + if (deflateMatcher.find()) { + codecFactory = CodecFactory.deflateCodec(Integer.parseInt(deflateMatcher.group("level"))); + return; + } + + Matcher xzMatcher = xzPattern.matcher(codecStr); + if (xzMatcher.find()) { + codecFactory = CodecFactory.xzCodec(Integer.parseInt(xzMatcher.group("level"))); + return; + } + + throw new IllegalStateException(codecStr + " is not supported"); + } + + public CodecFactory getCodec() { + return codecFactory; + } + + @Override + public String toString() { + checkNotNull(codecFactory, "Inner CodecFactory is null, please use non default constructor"); + return codecFactory.toString(); + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/package-info.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/package-info.java new file mode 100644 
index 000000000000..8d6938347a44 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/package-info.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** Defines transforms for reading and writing Avro storage format. */ +@DefaultAnnotation(NonNull.class) +@Experimental(Kind.EXTENSION) +package org.apache.beam.sdk.extensions.avro.io; + +import edu.umd.cs.findbugs.annotations.DefaultAnnotation; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.checkerframework.checker.nullness.qual.NonNull; diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java new file mode 100644 index 000000000000..12b81be54c13 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.schemas; + +import static org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils.toBeamSchema; + +import java.util.List; +import org.apache.avro.reflect.ReflectData; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; +import org.apache.beam.sdk.schemas.GetterBasedSchemaProvider; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaProvider; +import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; +import org.apache.beam.sdk.values.TypeDescriptor; + +/** + * A {@link SchemaProvider} for AVRO generated SpecificRecords and POJOs. + * + *

This provider infers a schema from generated SpecificRecord objects, and creates schemas and + * rows that bind to the appropriate fields. This provider also infers schemas from Java POJO + * objects, creating a schema that matches that inferred by the AVRO libraries. + */ +@SuppressWarnings({ + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) +}) +public class AvroRecordSchema extends GetterBasedSchemaProvider { + @Override + public Schema schemaFor(TypeDescriptor typeDescriptor) { + return toBeamSchema(ReflectData.get().getSchema(typeDescriptor.getRawType())); + } + + @Override + public List fieldValueGetters(Class targetClass, Schema schema) { + return AvroUtils.getGetters(targetClass, schema); + } + + @Override + public List fieldValueTypeInformations( + Class targetClass, Schema schema) { + return AvroUtils.getFieldTypes(targetClass, schema); + } + + @Override + public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { + return AvroUtils.getCreator(targetClass, schema); + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/AvroPayloadSerializerProvider.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/AvroPayloadSerializerProvider.java new file mode 100644 index 000000000000..7245d4a75e0d --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/AvroPayloadSerializerProvider.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.schemas.io.payloads; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer; +import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializerProvider; + +@Internal +@Experimental(Kind.SCHEMAS) +@AutoService(PayloadSerializerProvider.class) +public class AvroPayloadSerializerProvider implements PayloadSerializerProvider { + @Override + public String identifier() { + return "avro"; + } + + @Override + public PayloadSerializer getSerializer(Schema schema, Map tableParams) { + return PayloadSerializer.of( + AvroUtils.getRowToAvroBytesFunction(schema), AvroUtils.getAvroBytesToRowFunction(schema)); + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/package-info.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/package-info.java new file mode 100644 index 000000000000..01d48f89ec72 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/io/payloads/package-info.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Provides abstractions for schema-aware AvroIO. */ +@DefaultAnnotation(NonNull.class) +@Experimental(Kind.EXTENSION) +package org.apache.beam.sdk.extensions.avro.schemas.io.payloads; + +import edu.umd.cs.findbugs.annotations.DefaultAnnotation; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.checkerframework.checker.nullness.qual.NonNull; diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/package-info.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/package-info.java new file mode 100644 index 000000000000..6428c686400e --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/package-info.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Defines {@link org.apache.beam.sdk.schemas.Schema} and other classes for representing schema'd + * data in a {@link org.apache.beam.sdk.Pipeline} using Apache Avro. + */ +@DefaultAnnotation(NonNull.class) +@Experimental(Kind.SCHEMAS) +package org.apache.beam.sdk.extensions.avro.schemas; + +import edu.umd.cs.findbugs.annotations.DefaultAnnotation; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.checkerframework.checker.nullness.qual.NonNull; diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java new file mode 100644 index 000000000000..a7ff6a581e68 --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.schemas.utils; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Type; +import java.util.Map; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.asm.AsmVisitorWrapper; +import net.bytebuddy.description.type.TypeDescription.ForLoadedType; +import net.bytebuddy.dynamic.DynamicType; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.MethodCall; +import net.bytebuddy.implementation.bytecode.StackManipulation; +import net.bytebuddy.implementation.bytecode.assign.TypeCasting; +import net.bytebuddy.implementation.bytecode.collection.ArrayAccess; +import net.bytebuddy.implementation.bytecode.constant.IntegerConstant; +import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; +import net.bytebuddy.jar.asm.ClassWriter; +import net.bytebuddy.matcher.ElementMatchers; +import org.apache.avro.specific.SpecificRecord; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.InjectPackageStrategy; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversion; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; +import org.apache.beam.sdk.schemas.utils.ReflectUtils.ClassWithSchema; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; + +@Experimental(Kind.SCHEMAS) +@SuppressWarnings({ + "nullness", // TODO(https://github.com/apache/beam/issues/20497) + "rawtypes" +}) +class AvroByteBuddyUtils { + private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); + + // Cache the generated constructors. + private static final Map CACHED_CREATORS = + Maps.newConcurrentMap(); + + static SchemaUserTypeCreator getCreator( + Class clazz, Schema schema) { + return CACHED_CREATORS.computeIfAbsent( + ClassWithSchema.create(clazz, schema), c -> createCreator(clazz, schema)); + } + + private static SchemaUserTypeCreator createCreator(Class clazz, Schema schema) { + Constructor baseConstructor = null; + Constructor[] constructors = clazz.getDeclaredConstructors(); + for (Constructor constructor : constructors) { + // TODO: This assumes that Avro only generates one constructor with this many fields. + if (constructor.getParameterCount() == schema.getFieldCount()) { + baseConstructor = constructor; + } + } + if (baseConstructor == null) { + throw new RuntimeException("No matching constructor found for class " + clazz); + } + + // Generate a method call to create and invoke the SpecificRecord's constructor. . 
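+    // As an illustration (not a guarantee of the generated signature): for the AvroGeneratedUser
+    // record in this module's test schema (fields: name, favorite_number, favorite_color), the
+    // matched constructor is roughly AvroGeneratedUser(CharSequence, Integer, CharSequence), and
+    // the generated creator effectively performs
+    //   new AvroGeneratedUser((CharSequence) params[0], (Integer) params[1], (CharSequence) params[2]).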
+ MethodCall construct = MethodCall.construct(baseConstructor); + for (int i = 0; i < baseConstructor.getParameterTypes().length; ++i) { + Class baseType = baseConstructor.getParameterTypes()[i]; + construct = construct.with(readAndConvertParameter(baseType, i), baseType); + } + + try { + DynamicType.Builder builder = + BYTE_BUDDY + .with(new InjectPackageStrategy(clazz)) + .subclass(SchemaUserTypeCreator.class) + .method(ElementMatchers.named("create")) + .intercept(construct); + + return builder + .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) + .make() + .load( + ReflectHelpers.findClassLoader(clazz.getClassLoader()), + ClassLoadingStrategy.Default.INJECTION) + .getLoaded() + .getDeclaredConstructor() + .newInstance(); + } catch (InstantiationException + | IllegalAccessException + | NoSuchMethodException + | InvocationTargetException e) { + throw new RuntimeException( + "Unable to generate a getter for class " + clazz + " with schema " + schema); + } + } + + private static StackManipulation readAndConvertParameter( + Class constructorParameterType, int index) { + TypeConversionsFactory typeConversionsFactory = new AvroUtils.AvroTypeConversionFactory(); + + // The types in the AVRO-generated constructor might be the types returned by Beam's Row class, + // so we have to convert the types used by Beam's Row class. + // We know that AVRO generates constructor parameters in the same order as fields + // in the schema, so we can just add the parameters sequentially. + TypeConversion convertType = typeConversionsFactory.createTypeConversion(true); + + // Map the AVRO-generated type to the one Beam will use. + ForLoadedType convertedType = + new ForLoadedType((Class) convertType.convert(TypeDescriptor.of(constructorParameterType))); + + // This will run inside the generated creator. Read the parameter and convert it to the + // type required by the SpecificRecord constructor. + StackManipulation readParameter = + new StackManipulation.Compound( + MethodVariableAccess.REFERENCE.loadFrom(1), + IntegerConstant.forValue(index), + ArrayAccess.REFERENCE.load(), + TypeCasting.to(convertedType)); + + // Convert to the parameter accepted by the SpecificRecord constructor. + return typeConversionsFactory + .createSetterConversions(readParameter) + .convert(TypeDescriptor.of(constructorParameterType)); + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java new file mode 100644 index 000000000000..60c3465a3c7b --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -0,0 +1,1341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.schemas.utils; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import net.bytebuddy.description.type.TypeDescription.ForLoadedType; +import net.bytebuddy.implementation.bytecode.Duplication; +import net.bytebuddy.implementation.bytecode.StackManipulation; +import net.bytebuddy.implementation.bytecode.StackManipulation.Compound; +import net.bytebuddy.implementation.bytecode.TypeCreation; +import net.bytebuddy.implementation.bytecode.assign.TypeCasting; +import net.bytebuddy.implementation.bytecode.member.MethodInvocation; +import net.bytebuddy.matcher.ElementMatchers; +import org.apache.avro.AvroRuntimeException; +import org.apache.avro.Conversions; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema.Type; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericFixed; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.reflect.AvroIgnore; +import org.apache.avro.reflect.AvroName; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.specific.SpecificData; +import org.apache.avro.specific.SpecificRecord; +import org.apache.avro.util.Utf8; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.schemas.AvroRecordSchema; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes; +import org.apache.beam.sdk.schemas.logicaltypes.FixedString; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.schemas.logicaltypes.VariableBytes; +import org.apache.beam.sdk.schemas.logicaltypes.VariableString; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertType; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertValueForGetter; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertValueForSetter; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversion; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; +import 
org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier; +import org.apache.beam.sdk.schemas.utils.JavaBeanUtils; +import org.apache.beam.sdk.schemas.utils.POJOUtils; +import org.apache.beam.sdk.schemas.utils.ReflectUtils; +import org.apache.beam.sdk.schemas.utils.StaticSchemaInference; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.CaseFormat; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Days; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.joda.time.ReadableInstant; + +/** Utils to convert AVRO records to Beam rows. */ +@Experimental(Kind.SCHEMAS) +@SuppressWarnings({ + "nullness", // TODO(https://github.com/apache/beam/issues/20497) + "rawtypes" +}) +public class AvroUtils { + static { + // This works around a bug in the Avro library (AVRO-1891) around SpecificRecord's handling + // of DateTime types. + SpecificData.get().addLogicalTypeConversion(new AvroCoder.JodaTimestampConversion()); + GenericData.get().addLogicalTypeConversion(new AvroCoder.JodaTimestampConversion()); + } + + // Unwrap an AVRO schema into the base type an whether it is nullable. + static class TypeWithNullability { + public final org.apache.avro.Schema type; + public final boolean nullable; + + TypeWithNullability(org.apache.avro.Schema avroSchema) { + if (avroSchema.getType() == Type.UNION) { + List types = avroSchema.getTypes(); + + // optional fields in AVRO have form of: + // {"name": "foo", "type": ["null", "something"]} + + // don't need recursion because nested unions aren't supported in AVRO + List nonNullTypes = + types.stream().filter(x -> x.getType() != Type.NULL).collect(Collectors.toList()); + + if (nonNullTypes.size() == types.size() || nonNullTypes.isEmpty()) { + // union without `null` or all 'null' union, keep as is. + type = avroSchema; + nullable = false; + } else if (nonNullTypes.size() > 1) { + type = org.apache.avro.Schema.createUnion(nonNullTypes); + nullable = true; + } else { + // One non-null type. + type = nonNullTypes.get(0); + nullable = true; + } + } else { + type = avroSchema; + nullable = false; + } + } + } + + /** Wrapper for fixed byte fields. */ + public static class FixedBytesField { + private final int size; + + private FixedBytesField(int size) { + this.size = size; + } + + /** Create a {@link FixedBytesField} with the specified size. */ + public static FixedBytesField withSize(int size) { + return new FixedBytesField(size); + } + + /** Create a {@link FixedBytesField} from a Beam {@link FieldType}. */ + public static @Nullable FixedBytesField fromBeamFieldType(FieldType fieldType) { + if (fieldType.getTypeName().isLogicalType() + && fieldType.getLogicalType().getIdentifier().equals(FixedBytes.IDENTIFIER)) { + int length = fieldType.getLogicalType(FixedBytes.class).getLength(); + return new FixedBytesField(length); + } else { + return null; + } + } + + /** Create a {@link FixedBytesField} from an AVRO type. 
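+ *
+ * <p>For example (an illustrative sketch; the {@code "MD5"} name and {@code "example"} namespace
+ * are arbitrary):
+ *
+ * <pre>{@code
+ * org.apache.avro.Schema md5 = org.apache.avro.Schema.createFixed("MD5", null, "example", 16);
+ * FixedBytesField field = FixedBytesField.fromAvroType(md5); // field.getSize() == 16
+ * FieldType beamType = field.toBeamType(); // FixedBytes logical type of length 16
+ * }</pre>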
*/ + public static @Nullable FixedBytesField fromAvroType(org.apache.avro.Schema type) { + if (type.getType().equals(Type.FIXED)) { + return new FixedBytesField(type.getFixedSize()); + } else { + return null; + } + } + + /** Get the size. */ + public int getSize() { + return size; + } + + /** Convert to a Beam type. */ + public FieldType toBeamType() { + return FieldType.logicalType(FixedBytes.of(size)); + } + + /** Convert to an AVRO type. */ + public org.apache.avro.Schema toAvroType(String name, String namespace) { + return org.apache.avro.Schema.createFixed(name, null, namespace, size); + } + } + + public static class AvroConvertType extends ConvertType { + public AvroConvertType(boolean returnRawType) { + super(returnRawType); + } + + @Override + protected java.lang.reflect.Type convertDefault(TypeDescriptor type) { + if (type.isSubtypeOf(TypeDescriptor.of(GenericFixed.class))) { + return byte[].class; + } else { + return super.convertDefault(type); + } + } + } + + public static class AvroConvertValueForGetter extends ConvertValueForGetter { + AvroConvertValueForGetter(StackManipulation readValue) { + super(readValue); + } + + @Override + protected TypeConversionsFactory getFactory() { + return new AvroTypeConversionFactory(); + } + + @Override + protected StackManipulation convertDefault(TypeDescriptor type) { + if (type.isSubtypeOf(TypeDescriptor.of(GenericFixed.class))) { + // Generate the following code: + // return value.bytes(); + return new Compound( + readValue, + MethodInvocation.invoke( + new ForLoadedType(GenericFixed.class) + .getDeclaredMethods() + .filter( + ElementMatchers.named("bytes") + .and(ElementMatchers.returns(new ForLoadedType(byte[].class)))) + .getOnly())); + } + return super.convertDefault(type); + } + } + + public static class AvroConvertValueForSetter extends ConvertValueForSetter { + AvroConvertValueForSetter(StackManipulation readValue) { + super(readValue); + } + + @Override + protected TypeConversionsFactory getFactory() { + return new AvroTypeConversionFactory(); + } + + @Override + protected StackManipulation convertDefault(TypeDescriptor type) { + final ForLoadedType byteArrayType = new ForLoadedType(byte[].class); + if (type.isSubtypeOf(TypeDescriptor.of(GenericFixed.class))) { + // Generate the following code: + // return new T((byte[]) value); + ForLoadedType loadedType = new ForLoadedType(type.getRawType()); + return new Compound( + TypeCreation.of(loadedType), + Duplication.SINGLE, + // Load the parameter and cast it to a byte[]. + readValue, + TypeCasting.to(byteArrayType), + // Create a new instance that wraps this byte[]. + MethodInvocation.invoke( + loadedType + .getDeclaredMethods() + .filter( + ElementMatchers.isConstructor() + .and(ElementMatchers.takesArguments(byteArrayType))) + .getOnly())); + } + return super.convertDefault(type); + } + } + + static class AvroTypeConversionFactory implements TypeConversionsFactory { + + @Override + public TypeConversion createTypeConversion(boolean returnRawTypes) { + return new AvroConvertType(returnRawTypes); + } + + @Override + public TypeConversion createGetterConversions(StackManipulation readValue) { + return new AvroConvertValueForGetter(readValue); + } + + @Override + public TypeConversion createSetterConversions(StackManipulation readValue) { + return new AvroConvertValueForSetter(readValue); + } + } + + /** Get Beam Field from avro Field. 
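+ *
+ * <p>A minimal sketch (the record and field names are arbitrary):
+ *
+ * <pre>{@code
+ * org.apache.avro.Schema record =
+ *     new org.apache.avro.Schema.Parser()
+ *         .parse(
+ *             "{\"type\": \"record\", \"name\": \"R\", \"fields\": ["
+ *                 + "{\"name\": \"age\", \"type\": [\"null\", \"int\"]}]}");
+ * Field beamField = AvroUtils.toBeamField(record.getField("age"));
+ * // yields Field.of("age", FieldType.INT32.withNullable(true))
+ * }</pre>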
*/ + public static Field toBeamField(org.apache.avro.Schema.Field field) { + TypeWithNullability nullableType = new TypeWithNullability(field.schema()); + FieldType beamFieldType = toFieldType(nullableType); + return Field.of(field.name(), beamFieldType); + } + + /** Get Avro Field from Beam Field. */ + public static org.apache.avro.Schema.Field toAvroField(Field field, String namespace) { + org.apache.avro.Schema fieldSchema = + getFieldSchema(field.getType(), field.getName(), namespace); + return new org.apache.avro.Schema.Field( + field.getName(), fieldSchema, field.getDescription(), (Object) null); + } + + private AvroUtils() {} + + /** + * Converts AVRO schema to Beam row schema. + * + * @param schema schema of type RECORD + */ + public static Schema toBeamSchema(org.apache.avro.Schema schema) { + Schema.Builder builder = Schema.builder(); + + for (org.apache.avro.Schema.Field field : schema.getFields()) { + Field beamField = toBeamField(field); + if (field.doc() != null) { + beamField = beamField.withDescription(field.doc()); + } + builder.addField(beamField); + } + + return builder.build(); + } + + /** Converts a Beam Schema into an AVRO schema. */ + public static org.apache.avro.Schema toAvroSchema( + Schema beamSchema, @Nullable String name, @Nullable String namespace) { + final String schemaName = Strings.isNullOrEmpty(name) ? "topLevelRecord" : name; + final String schemaNamespace = namespace == null ? "" : namespace; + String childNamespace = + !"".equals(schemaNamespace) ? schemaNamespace + "." + schemaName : schemaName; + List fields = Lists.newArrayList(); + for (Field field : beamSchema.getFields()) { + org.apache.avro.Schema.Field recordField = toAvroField(field, childNamespace); + fields.add(recordField); + } + return org.apache.avro.Schema.createRecord(schemaName, null, schemaNamespace, false, fields); + } + + public static org.apache.avro.Schema toAvroSchema(Schema beamSchema) { + return toAvroSchema(beamSchema, null, null); + } + + /** + * Strict conversion from AVRO to Beam, strict because it doesn't do widening or narrowing during + * conversion. If Schema is not provided, one is inferred from the AVRO schema. + */ + public static Row toBeamRowStrict(GenericRecord record, @Nullable Schema schema) { + if (schema == null) { + schema = toBeamSchema(record.getSchema()); + } + + Row.Builder builder = Row.withSchema(schema); + org.apache.avro.Schema avroSchema = record.getSchema(); + + for (Field field : schema.getFields()) { + Object value = record.get(field.getName()); + org.apache.avro.Schema fieldAvroSchema = avroSchema.getField(field.getName()).schema(); + builder.addValue(convertAvroFieldStrict(value, fieldAvroSchema, field.getType())); + } + + return builder.build(); + } + + /** + * Convert from a Beam Row to an AVRO GenericRecord. The Avro Schema is inferred from the Beam + * schema on the row. + */ + public static GenericRecord toGenericRecord(Row row) { + return toGenericRecord(row, null); + } + + /** + * Convert from a Beam Row to an AVRO GenericRecord. If a Schema is not provided, one is inferred + * from the Beam schema on the row. + */ + public static GenericRecord toGenericRecord( + Row row, org.apache.avro.@Nullable Schema avroSchema) { + Schema beamSchema = row.getSchema(); + // Use the provided AVRO schema if present, otherwise infer an AVRO schema from the row + // schema. + if (avroSchema != null && avroSchema.getFields().size() != beamSchema.getFieldCount()) { + throw new IllegalArgumentException( + "AVRO schema doesn't match row schema. 
Row schema " + + beamSchema + + ". AVRO schema + " + + avroSchema); + } + if (avroSchema == null) { + avroSchema = toAvroSchema(beamSchema); + } + + GenericRecordBuilder builder = new GenericRecordBuilder(avroSchema); + for (int i = 0; i < beamSchema.getFieldCount(); ++i) { + Field field = beamSchema.getField(i); + builder.set( + field.getName(), + genericFromBeamField( + field.getType(), avroSchema.getField(field.getName()).schema(), row.getValue(i))); + } + return builder.build(); + } + + @SuppressWarnings("unchecked") + public static SerializableFunction getToRowFunction( + Class clazz, org.apache.avro.@Nullable Schema schema) { + if (GenericRecord.class.equals(clazz)) { + Schema beamSchema = toBeamSchema(schema); + return (SerializableFunction) getGenericRecordToRowFunction(beamSchema); + } else { + return new AvroRecordSchema().toRowFunction(TypeDescriptor.of(clazz)); + } + } + + @SuppressWarnings("unchecked") + public static SerializableFunction getFromRowFunction(Class clazz) { + return GenericRecord.class.equals(clazz) + ? (SerializableFunction) getRowToGenericRecordFunction(null) + : new AvroRecordSchema().fromRowFunction(TypeDescriptor.of(clazz)); + } + + public static @Nullable Schema getSchema( + Class clazz, org.apache.avro.@Nullable Schema schema) { + if (schema != null) { + return schema.getType().equals(Type.RECORD) ? toBeamSchema(schema) : null; + } + if (GenericRecord.class.equals(clazz)) { + throw new IllegalArgumentException("No schema provided for getSchema(GenericRecord)"); + } + return new AvroRecordSchema().schemaFor(TypeDescriptor.of(clazz)); + } + + /** Returns a function mapping encoded AVRO {@link GenericRecord}s to Beam {@link Row}s. */ + public static SimpleFunction getAvroBytesToRowFunction(Schema beamSchema) { + return new AvroBytesToRowFn(beamSchema); + } + + private static class AvroBytesToRowFn extends SimpleFunction { + private final AvroCoder coder; + private final Schema beamSchema; + + AvroBytesToRowFn(Schema beamSchema) { + org.apache.avro.Schema avroSchema = toAvroSchema(beamSchema); + coder = AvroCoder.of(avroSchema); + this.beamSchema = beamSchema; + } + + @Override + public Row apply(byte[] bytes) { + try { + ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); + GenericRecord record = coder.decode(inputStream); + return AvroUtils.toBeamRowStrict(record, beamSchema); + } catch (Exception e) { + throw new AvroRuntimeException( + "Could not decode avro record from given bytes " + + new String(bytes, StandardCharsets.UTF_8), + e); + } + } + } + + /** Returns a function mapping Beam {@link Row}s to encoded AVRO {@link GenericRecord}s. 
*/ + public static SimpleFunction getRowToAvroBytesFunction(Schema beamSchema) { + return new RowToAvroBytesFn(beamSchema); + } + + private static class RowToAvroBytesFn extends SimpleFunction { + private final transient org.apache.avro.Schema avroSchema; + private final AvroCoder coder; + + RowToAvroBytesFn(Schema beamSchema) { + avroSchema = toAvroSchema(beamSchema); + coder = AvroCoder.of(avroSchema); + } + + @Override + public byte[] apply(Row row) { + try { + GenericRecord record = toGenericRecord(row, avroSchema); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + coder.encode(record, outputStream); + return outputStream.toByteArray(); + } catch (Exception e) { + throw new AvroRuntimeException( + String.format("Could not encode avro from given row: %s", row), e); + } + } + } + + /** + * Returns a function mapping AVRO {@link GenericRecord}s to Beam {@link Row}s for use in {@link + * org.apache.beam.sdk.values.PCollection#setSchema}. + */ + public static SerializableFunction getGenericRecordToRowFunction( + @Nullable Schema schema) { + return new GenericRecordToRowFn(schema); + } + + private static class GenericRecordToRowFn implements SerializableFunction { + private final Schema schema; + + GenericRecordToRowFn(Schema schema) { + this.schema = schema; + } + + @Override + public Row apply(GenericRecord input) { + return toBeamRowStrict(input, schema); + } + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + GenericRecordToRowFn that = (GenericRecordToRowFn) other; + return Objects.equals(this.schema, that.schema); + } + + @Override + public int hashCode() { + return Objects.hash(schema); + } + } + + /** + * Returns a function mapping Beam {@link Row}s to AVRO {@link GenericRecord}s for use in {@link + * org.apache.beam.sdk.values.PCollection#setSchema}. + */ + public static SerializableFunction getRowToGenericRecordFunction( + org.apache.avro.@Nullable Schema avroSchema) { + return new RowToGenericRecordFn(avroSchema); + } + + private static class RowToGenericRecordFn implements SerializableFunction { + private transient org.apache.avro.Schema avroSchema; + + RowToGenericRecordFn(org.apache.avro.@Nullable Schema avroSchema) { + this.avroSchema = avroSchema; + } + + @Override + public GenericRecord apply(Row input) { + return toGenericRecord(input, avroSchema); + } + + @Override + public boolean equals(@Nullable Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + RowToGenericRecordFn that = (RowToGenericRecordFn) other; + return Objects.equals(this.avroSchema, that.avroSchema); + } + + @Override + public int hashCode() { + return Objects.hash(avroSchema); + } + + private void writeObject(ObjectOutputStream out) throws IOException { + final String avroSchemaAsString = (avroSchema == null) ? null : avroSchema.toString(); + out.writeObject(avroSchemaAsString); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + final String avroSchemaAsString = (String) in.readObject(); + avroSchema = + (avroSchemaAsString == null) + ? null + : new org.apache.avro.Schema.Parser().parse(avroSchemaAsString); + } + } + + /** + * Returns an {@code SchemaCoder} instance for the provided element type. 
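+ *
+ * <p>A typical use, assuming a hypothetical POJO class {@code MyUser}:
+ *
+ * <pre>{@code
+ * SchemaCoder<MyUser> coder = AvroUtils.schemaCoder(TypeDescriptor.of(MyUser.class));
+ * users.setCoder(coder); // users is a PCollection<MyUser>
+ * }</pre>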
+ * + * @param the element type + */ + public static SchemaCoder schemaCoder(TypeDescriptor type) { + @SuppressWarnings("unchecked") + Class clazz = (Class) type.getRawType(); + org.apache.avro.Schema avroSchema = new ReflectData(clazz.getClassLoader()).getSchema(clazz); + Schema beamSchema = toBeamSchema(avroSchema); + return SchemaCoder.of( + beamSchema, type, getToRowFunction(clazz, avroSchema), getFromRowFunction(clazz)); + } + + /** + * Returns an {@code SchemaCoder} instance for the provided element class. + * + * @param the element type + */ + public static SchemaCoder schemaCoder(Class clazz) { + return schemaCoder(TypeDescriptor.of(clazz)); + } + + /** + * Returns an {@code SchemaCoder} instance for the Avro schema. The implicit type is + * GenericRecord. + */ + public static SchemaCoder schemaCoder(org.apache.avro.Schema schema) { + Schema beamSchema = toBeamSchema(schema); + return SchemaCoder.of( + beamSchema, + TypeDescriptor.of(GenericRecord.class), + getGenericRecordToRowFunction(beamSchema), + getRowToGenericRecordFunction(schema)); + } + + /** + * Returns an {@code SchemaCoder} instance for the provided element type using the provided Avro + * schema. + * + *
<p>
If the type argument is GenericRecord, the schema may be arbitrary. Otherwise, the schema + * must correspond to the type provided. + * + * @param the element type + */ + public static SchemaCoder schemaCoder(Class clazz, org.apache.avro.Schema schema) { + return SchemaCoder.of( + getSchema(clazz, schema), + TypeDescriptor.of(clazz), + getToRowFunction(clazz, schema), + getFromRowFunction(clazz)); + } + + /** + * Returns an {@code SchemaCoder} instance based on the provided AvroCoder for the element type. + * + * @param the element type + */ + public static SchemaCoder schemaCoder(AvroCoder avroCoder) { + return schemaCoder(avroCoder.getType(), avroCoder.getSchema()); + } + + private static final class AvroSpecificRecordFieldValueTypeSupplier + implements FieldValueTypeSupplier { + @Override + public List get(Class clazz) { + throw new RuntimeException("Unexpected call."); + } + + @Override + public List get(Class clazz, Schema schema) { + Map mapping = getMapping(schema); + List methods = ReflectUtils.getMethods(clazz); + List types = Lists.newArrayList(); + for (int i = 0; i < methods.size(); ++i) { + Method method = methods.get(i); + if (ReflectUtils.isGetter(method)) { + FieldValueTypeInformation fieldValueTypeInformation = + FieldValueTypeInformation.forGetter(method, i); + String name = mapping.get(fieldValueTypeInformation.getName()); + if (name != null) { + types.add(fieldValueTypeInformation.withName(name)); + } + } + } + + // Return the list ordered by the schema fields. + return StaticSchemaInference.sortBySchema(types, schema); + } + + private Map getMapping(Schema schema) { + Map mapping = Maps.newHashMap(); + for (Field field : schema.getFields()) { + String fieldName = field.getName(); + String getter; + if (fieldName.contains("_")) { + if (Character.isLowerCase(fieldName.charAt(0))) { + // field_name -> fieldName + getter = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, fieldName); + } else { + // FIELD_NAME -> fIELDNAME + // must remove underscore and then convert to match compiled Avro schema getter name + getter = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, fieldName.replace("_", "")); + } + } else if (Character.isUpperCase(fieldName.charAt(0))) { + // FieldName -> fieldName + getter = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, fieldName); + } else { + // If the field is in camel case already, then it's the identity mapping. + getter = fieldName; + } + mapping.put(getter, fieldName); + // The Avro compiler might add a $ at the end of a getter to disambiguate. + mapping.put(getter + "$", fieldName); + } + return mapping; + } + } + + private static final class AvroPojoFieldValueTypeSupplier implements FieldValueTypeSupplier { + @Override + public List get(Class clazz) { + List classFields = ReflectUtils.getFields(clazz); + Map types = Maps.newHashMap(); + for (int i = 0; i < classFields.size(); ++i) { + java.lang.reflect.Field f = classFields.get(i); + if (!f.isAnnotationPresent(AvroIgnore.class)) { + FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); + AvroName avroname = f.getAnnotation(AvroName.class); + if (avroname != null) { + typeInformation = typeInformation.withName(avroname.value()); + } + types.put(typeInformation.getName(), typeInformation); + } + } + return Lists.newArrayList(types.values()); + } + } + + /** Get field types for an AVRO-generated SpecificRecord or a POJO. 
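+ *
+ * <p>For example, given a hypothetical POJO {@code MyUser} and its inferred Beam schema {@code
+ * beamSchema}:
+ *
+ * <pre>{@code
+ * List<FieldValueTypeInformation> types = AvroUtils.getFieldTypes(MyUser.class, beamSchema);
+ * }</pre>
+ *
+ * <p>For POJOs, fields annotated with {@code @AvroIgnore} are skipped and {@code @AvroName} is
+ * honored when naming the resulting fields.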
*/ + public static List getFieldTypes(Class clazz, Schema schema) { + if (TypeDescriptor.of(clazz).isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { + return JavaBeanUtils.getFieldTypes( + clazz, schema, new AvroSpecificRecordFieldValueTypeSupplier()); + } else { + return POJOUtils.getFieldTypes(clazz, schema, new AvroPojoFieldValueTypeSupplier()); + } + } + + /** Get generated getters for an AVRO-generated SpecificRecord or a POJO. */ + public static List getGetters(Class clazz, Schema schema) { + if (TypeDescriptor.of(clazz).isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { + return JavaBeanUtils.getGetters( + clazz, + schema, + new AvroSpecificRecordFieldValueTypeSupplier(), + new AvroTypeConversionFactory()); + } else { + return POJOUtils.getGetters( + clazz, schema, new AvroPojoFieldValueTypeSupplier(), new AvroTypeConversionFactory()); + } + } + + /** Get an object creator for an AVRO-generated SpecificRecord. */ + public static SchemaUserTypeCreator getCreator(Class clazz, Schema schema) { + if (TypeDescriptor.of(clazz).isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { + return AvroByteBuddyUtils.getCreator((Class) clazz, schema); + } else { + return POJOUtils.getSetFieldCreator( + clazz, schema, new AvroPojoFieldValueTypeSupplier(), new AvroTypeConversionFactory()); + } + } + + /** Converts AVRO schema to Beam field. */ + private static FieldType toFieldType(TypeWithNullability type) { + FieldType fieldType = null; + org.apache.avro.Schema avroSchema = type.type; + + LogicalType logicalType = LogicalTypes.fromSchema(avroSchema); + if (logicalType != null) { + if (logicalType instanceof LogicalTypes.Decimal) { + fieldType = FieldType.DECIMAL; + } else if (logicalType instanceof LogicalTypes.TimestampMillis) { + // TODO: There is a desire to move Beam schema DATETIME to a micros representation. When + // this is done, this logical type needs to be changed. 
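+          // For now, both Avro timestamp-millis and date map onto Beam's millisecond-precision
+          // DATETIME field type.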
+ fieldType = FieldType.DATETIME; + } else if (logicalType instanceof LogicalTypes.Date) { + fieldType = FieldType.DATETIME; + } + } + + if (fieldType == null) { + switch (type.type.getType()) { + case RECORD: + fieldType = FieldType.row(toBeamSchema(avroSchema)); + break; + + case ENUM: + fieldType = FieldType.logicalType(EnumerationType.create(type.type.getEnumSymbols())); + break; + + case ARRAY: + FieldType elementType = toFieldType(new TypeWithNullability(avroSchema.getElementType())); + fieldType = FieldType.array(elementType); + break; + + case MAP: + fieldType = + FieldType.map( + FieldType.STRING, + toFieldType(new TypeWithNullability(avroSchema.getValueType()))); + break; + + case FIXED: + fieldType = FixedBytesField.fromAvroType(type.type).toBeamType(); + break; + + case STRING: + fieldType = FieldType.STRING; + break; + + case BYTES: + fieldType = FieldType.BYTES; + break; + + case INT: + fieldType = FieldType.INT32; + break; + + case LONG: + fieldType = FieldType.INT64; + break; + + case FLOAT: + fieldType = FieldType.FLOAT; + break; + + case DOUBLE: + fieldType = FieldType.DOUBLE; + break; + + case BOOLEAN: + fieldType = FieldType.BOOLEAN; + break; + + case UNION: + fieldType = + FieldType.logicalType( + OneOfType.create( + avroSchema.getTypes().stream() + .map(x -> Field.of(x.getName(), toFieldType(new TypeWithNullability(x)))) + .collect(Collectors.toList()))); + break; + case NULL: + throw new IllegalArgumentException("Can't convert 'null' to FieldType"); + + default: + throw new AssertionError("Unexpected AVRO Schema.Type: " + avroSchema.getType()); + } + } + fieldType = fieldType.withNullable(type.nullable); + return fieldType; + } + + private static org.apache.avro.Schema getFieldSchema( + FieldType fieldType, String fieldName, String namespace) { + org.apache.avro.Schema baseType; + switch (fieldType.getTypeName()) { + case BYTE: + case INT16: + case INT32: + baseType = org.apache.avro.Schema.create(Type.INT); + break; + + case INT64: + baseType = org.apache.avro.Schema.create(Type.LONG); + break; + + case DECIMAL: + baseType = + LogicalTypes.decimal(Integer.MAX_VALUE) + .addToSchema(org.apache.avro.Schema.create(Type.BYTES)); + break; + + case FLOAT: + baseType = org.apache.avro.Schema.create(Type.FLOAT); + break; + + case DOUBLE: + baseType = org.apache.avro.Schema.create(Type.DOUBLE); + break; + + case STRING: + baseType = org.apache.avro.Schema.create(Type.STRING); + break; + + case DATETIME: + // TODO: There is a desire to move Beam schema DATETIME to a micros representation. When + // this is done, this logical type needs to be changed. + baseType = + LogicalTypes.timestampMillis().addToSchema(org.apache.avro.Schema.create(Type.LONG)); + break; + + case BOOLEAN: + baseType = org.apache.avro.Schema.create(Type.BOOLEAN); + break; + + case BYTES: + baseType = org.apache.avro.Schema.create(Type.BYTES); + break; + + case LOGICAL_TYPE: + String identifier = fieldType.getLogicalType().getIdentifier(); + if (FixedBytes.IDENTIFIER.equals(identifier)) { + FixedBytesField fixedBytesField = + checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); + baseType = fixedBytesField.toAvroType("fixed", namespace + "." 
+ fieldName); + } else if (VariableBytes.IDENTIFIER.equals(identifier)) { + // treat VARBINARY as bytes as that is what avro supports + baseType = org.apache.avro.Schema.create(Type.BYTES); + } else if (FixedString.IDENTIFIER.equals(identifier) + || "CHAR".equals(identifier) + || "NCHAR".equals(identifier)) { + baseType = + buildHiveLogicalTypeSchema("char", (int) fieldType.getLogicalType().getArgument()); + } else if (VariableString.IDENTIFIER.equals(identifier) + || "NVARCHAR".equals(identifier) + || "VARCHAR".equals(identifier) + || "LONGNVARCHAR".equals(identifier) + || "LONGVARCHAR".equals(identifier)) { + baseType = + buildHiveLogicalTypeSchema("varchar", (int) fieldType.getLogicalType().getArgument()); + } else if (EnumerationType.IDENTIFIER.equals(identifier)) { + EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); + baseType = + org.apache.avro.Schema.createEnum(fieldName, "", "", enumerationType.getValues()); + } else if (OneOfType.IDENTIFIER.equals(identifier)) { + OneOfType oneOfType = fieldType.getLogicalType(OneOfType.class); + baseType = + org.apache.avro.Schema.createUnion( + oneOfType.getOneOfSchema().getFields().stream() + .map(x -> getFieldSchema(x.getType(), x.getName(), namespace)) + .collect(Collectors.toList())); + } else if ("DATE".equals(identifier)) { + baseType = LogicalTypes.date().addToSchema(org.apache.avro.Schema.create(Type.INT)); + } else if ("TIME".equals(identifier)) { + baseType = LogicalTypes.timeMillis().addToSchema(org.apache.avro.Schema.create(Type.INT)); + } else { + throw new RuntimeException( + "Unhandled logical type " + fieldType.getLogicalType().getIdentifier()); + } + break; + + case ARRAY: + case ITERABLE: + baseType = + org.apache.avro.Schema.createArray( + getFieldSchema(fieldType.getCollectionElementType(), fieldName, namespace)); + break; + + case MAP: + if (fieldType.getMapKeyType().getTypeName().isStringType()) { + // Avro only supports string keys in maps. + baseType = + org.apache.avro.Schema.createMap( + getFieldSchema(fieldType.getMapValueType(), fieldName, namespace)); + } else { + throw new IllegalArgumentException("Avro only supports maps with string keys"); + } + break; + + case ROW: + baseType = toAvroSchema(fieldType.getRowSchema(), fieldName, namespace); + break; + + default: + throw new IllegalArgumentException("Unexpected type " + fieldType); + } + return fieldType.getNullable() ? 
ReflectData.makeNullable(baseType) : baseType; + } + + private static @Nullable Object genericFromBeamField( + FieldType fieldType, org.apache.avro.Schema avroSchema, @Nullable Object value) { + TypeWithNullability typeWithNullability = new TypeWithNullability(avroSchema); + if (!fieldType.getNullable().equals(typeWithNullability.nullable)) { + throw new IllegalArgumentException( + "FieldType " + + fieldType + + " and AVRO schema " + + avroSchema + + " don't have matching nullability"); + } + + if (value == null) { + return value; + } + + switch (fieldType.getTypeName()) { + case BYTE: + case INT16: + case INT32: + case INT64: + case FLOAT: + case DOUBLE: + case BOOLEAN: + return value; + + case STRING: + return new Utf8((String) value); + + case DECIMAL: + BigDecimal decimal = (BigDecimal) value; + LogicalType logicalType = typeWithNullability.type.getLogicalType(); + return new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + + case DATETIME: + if (typeWithNullability.type.getType() == Type.INT) { + ReadableInstant instant = (ReadableInstant) value; + return (int) Days.daysBetween(Instant.EPOCH, instant).getDays(); + } else if (typeWithNullability.type.getType() == Type.LONG) { + ReadableInstant instant = (ReadableInstant) value; + return (long) instant.getMillis(); + } else { + throw new IllegalArgumentException( + "Can't represent " + fieldType + " as " + typeWithNullability.type.getType()); + } + + case BYTES: + return ByteBuffer.wrap((byte[]) value); + + case LOGICAL_TYPE: + String identifier = fieldType.getLogicalType().getIdentifier(); + if (FixedBytes.IDENTIFIER.equals(identifier)) { + FixedBytesField fixedBytesField = + checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); + byte[] byteArray = (byte[]) value; + if (byteArray.length != fixedBytesField.getSize()) { + throw new IllegalArgumentException("Incorrectly sized byte array."); + } + return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + } else if (VariableBytes.IDENTIFIER.equals(identifier)) { + return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + } else if (FixedString.IDENTIFIER.equals(identifier) + || "CHAR".equals(identifier) + || "NCHAR".equals(identifier)) { + return new Utf8((String) value); + } else if (VariableString.IDENTIFIER.equals(identifier) + || "NVARCHAR".equals(identifier) + || "VARCHAR".equals(identifier) + || "LONGNVARCHAR".equals(identifier) + || "LONGVARCHAR".equals(identifier)) { + return new Utf8((String) value); + } else if (EnumerationType.IDENTIFIER.equals(identifier)) { + EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); + return GenericData.get() + .createEnum( + enumerationType.toString((EnumerationType.Value) value), + typeWithNullability.type); + } else if (OneOfType.IDENTIFIER.equals(identifier)) { + OneOfType oneOfType = fieldType.getLogicalType(OneOfType.class); + OneOfType.Value oneOfValue = (OneOfType.Value) value; + FieldType innerFieldType = oneOfType.getFieldType(oneOfValue); + if (typeWithNullability.nullable && oneOfValue.getValue() == null) { + return null; + } else { + return genericFromBeamField( + innerFieldType.withNullable(false), + typeWithNullability.type.getTypes().get(oneOfValue.getCaseType().getValue()), + oneOfValue.getValue()); + } + } else if ("DATE".equals(identifier)) { + return Days.daysBetween(Instant.EPOCH, (Instant) value).getDays(); + } else if ("TIME".equals(identifier)) { + return (int) ((Instant) value).getMillis(); + } else { + 
throw new RuntimeException("Unhandled logical type " + identifier); + } + + case ARRAY: + case ITERABLE: + Iterable iterable = (Iterable) value; + List translatedArray = Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); + + for (Object arrayElement : iterable) { + translatedArray.add( + genericFromBeamField( + fieldType.getCollectionElementType(), + typeWithNullability.type.getElementType(), + arrayElement)); + } + return translatedArray; + + case MAP: + Map map = Maps.newHashMap(); + Map valueMap = (Map) value; + for (Map.Entry entry : valueMap.entrySet()) { + Utf8 key = new Utf8((String) entry.getKey()); + map.put( + key, + genericFromBeamField( + fieldType.getMapValueType(), + typeWithNullability.type.getValueType(), + entry.getValue())); + } + return map; + + case ROW: + return toGenericRecord((Row) value, typeWithNullability.type); + + default: + throw new IllegalArgumentException("Unsupported type " + fieldType); + } + } + + /** + * Strict conversion from AVRO to Beam, strict because it doesn't do widening or narrowing during + * conversion. + * + * @param value {@link GenericRecord} or any nested value + * @param avroSchema schema for value + * @param fieldType target beam field type + * @return value converted for {@link Row} + */ + @SuppressWarnings("unchecked") + public static @Nullable Object convertAvroFieldStrict( + @Nullable Object value, + @Nonnull org.apache.avro.Schema avroSchema, + @Nonnull FieldType fieldType) { + if (value == null) { + return null; + } + + TypeWithNullability type = new TypeWithNullability(avroSchema); + LogicalType logicalType = LogicalTypes.fromSchema(type.type); + if (logicalType != null) { + if (logicalType instanceof LogicalTypes.Decimal) { + ByteBuffer byteBuffer = (ByteBuffer) value; + BigDecimal bigDecimal = + new Conversions.DecimalConversion() + .fromBytes(byteBuffer.duplicate(), type.type, logicalType); + return convertDecimal(bigDecimal, fieldType); + } else if (logicalType instanceof LogicalTypes.TimestampMillis) { + if (value instanceof ReadableInstant) { + return convertDateTimeStrict(((ReadableInstant) value).getMillis(), fieldType); + } else { + return convertDateTimeStrict((Long) value, fieldType); + } + } else if (logicalType instanceof LogicalTypes.Date) { + if (value instanceof ReadableInstant) { + int epochDays = Days.daysBetween(Instant.EPOCH, (ReadableInstant) value).getDays(); + return convertDateStrict(epochDays, fieldType); + } else { + return convertDateStrict((Integer) value, fieldType); + } + } + } + + switch (type.type.getType()) { + case FIXED: + return convertFixedStrict((GenericFixed) value, fieldType); + + case BYTES: + return convertBytesStrict((ByteBuffer) value, fieldType); + + case STRING: + return convertStringStrict((CharSequence) value, fieldType); + + case INT: + return convertIntStrict((Integer) value, fieldType); + + case LONG: + return convertLongStrict((Long) value, fieldType); + + case FLOAT: + return convertFloatStrict((Float) value, fieldType); + + case DOUBLE: + return convertDoubleStrict((Double) value, fieldType); + + case BOOLEAN: + return convertBooleanStrict((Boolean) value, fieldType); + + case RECORD: + return convertRecordStrict((GenericRecord) value, fieldType); + + case ENUM: + // enums are either Java enums, or GenericEnumSymbol, + // they don't share common interface, but override toString() + return convertEnumStrict(value, fieldType); + + case ARRAY: + return convertArrayStrict((List) value, type.type.getElementType(), fieldType); + + case MAP: + return convertMapStrict( + 
(Map) value, type.type.getValueType(), fieldType); + + case UNION: + return convertUnionStrict(value, type.type, fieldType); + + case NULL: + throw new IllegalArgumentException("Can't convert 'null' to non-nullable field"); + + default: + throw new AssertionError("Unexpected AVRO Schema.Type: " + type.type.getType()); + } + } + + private static Object convertRecordStrict(GenericRecord record, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.ROW, "record"); + return toBeamRowStrict(record, fieldType.getRowSchema()); + } + + private static Object convertBytesStrict(ByteBuffer bb, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.BYTES, "bytes"); + + byte[] bytes = new byte[bb.remaining()]; + bb.duplicate().get(bytes); + return bytes; + } + + private static Object convertFixedStrict(GenericFixed fixed, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "fixed"); + checkArgument(FixedBytes.IDENTIFIER.equals(fieldType.getLogicalType().getIdentifier())); + return fixed.bytes().clone(); // clone because GenericFixed is mutable + } + + private static Object convertStringStrict(CharSequence value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.STRING, "string"); + return value.toString(); + } + + private static Object convertIntStrict(Integer value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.INT32, "int"); + return value; + } + + private static Object convertLongStrict(Long value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.INT64, "long"); + return value; + } + + private static Object convertDecimal(BigDecimal value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.DECIMAL, "decimal"); + return value; + } + + private static Object convertDateStrict(Integer epochDays, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.DATETIME, "date"); + return Instant.EPOCH.plus(Duration.standardDays(epochDays)); + } + + private static Object convertDateTimeStrict(Long value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.DATETIME, "dateTime"); + return new Instant(value); + } + + private static Object convertFloatStrict(Float value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.FLOAT, "float"); + return value; + } + + private static Object convertDoubleStrict(Double value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.DOUBLE, "double"); + return value; + } + + private static Object convertBooleanStrict(Boolean value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.BOOLEAN, "boolean"); + return value; + } + + private static Object convertEnumStrict(Object value, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "enum"); + checkArgument(fieldType.getLogicalType().getIdentifier().equals(EnumerationType.IDENTIFIER)); + EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); + return enumerationType.valueOf(value.toString()); + } + + private static Object convertUnionStrict( + Object value, org.apache.avro.Schema unionAvroSchema, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "oneOfType"); + checkArgument(fieldType.getLogicalType().getIdentifier().equals(OneOfType.IDENTIFIER)); + OneOfType oneOfType = fieldType.getLogicalType(OneOfType.class); + int fieldNumber = 
GenericData.get().resolveUnion(unionAvroSchema, value); + FieldType baseFieldType = oneOfType.getOneOfSchema().getField(fieldNumber).getType(); + Object convertedValue = + convertAvroFieldStrict(value, unionAvroSchema.getTypes().get(fieldNumber), baseFieldType); + return oneOfType.createValue(fieldNumber, convertedValue); + } + + private static Object convertArrayStrict( + List values, org.apache.avro.Schema elemAvroSchema, FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.ARRAY, "array"); + + List ret = new ArrayList<>(values.size()); + FieldType elemFieldType = fieldType.getCollectionElementType(); + + for (Object value : values) { + ret.add(convertAvroFieldStrict(value, elemAvroSchema, elemFieldType)); + } + + return ret; + } + + private static Object convertMapStrict( + Map values, + org.apache.avro.Schema valueAvroSchema, + FieldType fieldType) { + checkTypeName(fieldType.getTypeName(), TypeName.MAP, "map"); + checkNotNull(fieldType.getMapKeyType()); + checkNotNull(fieldType.getMapValueType()); + + if (!fieldType.getMapKeyType().equals(FieldType.STRING)) { + throw new IllegalArgumentException( + "Can't convert 'string' map keys to " + fieldType.getMapKeyType()); + } + + Map ret = new HashMap<>(); + + for (Map.Entry value : values.entrySet()) { + ret.put( + convertStringStrict(value.getKey(), fieldType.getMapKeyType()), + convertAvroFieldStrict(value.getValue(), valueAvroSchema, fieldType.getMapValueType())); + } + + return ret; + } + + private static void checkTypeName(TypeName got, TypeName expected, String label) { + checkArgument( + got.equals(expected), "Can't convert '%s' to %s, expected: %s", label, got, expected); + } + + /** + * Helper factory to build Avro Logical types schemas for SQL *CHAR types. This method represents + * the logical as Hive does. + */ + private static org.apache.avro.Schema buildHiveLogicalTypeSchema( + String hiveLogicalType, int size) { + String schemaJson = + String.format( + "{\"type\": \"string\", \"logicalType\": \"%s\", \"maxLength\": %s}", + hiveLogicalType, size); + return new org.apache.avro.Schema.Parser().parse(schemaJson); + } +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/package-info.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/package-info.java new file mode 100644 index 000000000000..df84a556c28c --- /dev/null +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/package-info.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** Defines utilities for deailing with schemas using Apache Avro. 
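+ *
+ * <p>The main entry point is {@link org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils},
+ * for example {@code AvroUtils.toBeamSchema(avroSchema)} to derive a Beam schema from an Avro
+ * schema, and {@code AvroUtils.toGenericRecord(row)} to convert a Beam row back to an Avro record.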
*/ +@DefaultAnnotation(NonNull.class) +@Experimental(Kind.EXTENSION) +package org.apache.beam.sdk.extensions.avro.schemas.utils; + +import edu.umd.cs.findbugs.annotations.DefaultAnnotation; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.checkerframework.checker.nullness.qual.NonNull; diff --git a/sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/io/user.avsc b/sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/io/user.avsc new file mode 100644 index 000000000000..134829746e49 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/io/user.avsc @@ -0,0 +1,10 @@ +{ + "namespace": "org.apache.beam.sdk.extensions.avro.io", + "type": "record", + "name": "AvroGeneratedUser", + "fields": [ + { "name": "name", "type": "string"}, + { "name": "favorite_number", "type": ["int", "null"]}, + { "name": "favorite_color", "type": ["string", "null"]} + ] +} diff --git a/sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/schemas/test.avsc b/sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/schemas/test.avsc new file mode 100644 index 000000000000..a7d13e4ce451 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/avro/org/apache/beam/sdk/extensions/avro/schemas/test.avsc @@ -0,0 +1,30 @@ +{ + "namespace": "org.apache.beam.sdk.extensions.avro.schemas", + "type": "record", + "name": "TestAvro", + "fields": [ + { "name": "bool_non_nullable", "type": "boolean"}, + { "name": "int", "type": ["int", "null"]}, + { "name": "long", "type": ["long", "null"]}, + { "name": "float", "type": ["float", "null"]}, + { "name": "double", "type": ["double", "null"]}, + { "name": "string", "type": ["string", "null"]}, + { "name": "bytes", "type": ["bytes", "null"]}, + { "name": "fixed", "type": {"type": "fixed", "size": 4, "name": "fixed4"} }, + { "name": "date", "type": {"type": "int", "logicalType": "date"} }, + { "name": "timestampMillis", "type": {"type": "long", "logicalType": "timestamp-millis"} }, + { "name": "TestEnum", "type": {"name": "TestEnum", "type": "enum", "symbols": ["abc","cde"] } }, + { "name": "row", "type": ["null", { + "type": "record", + "name": "TestAvroNested", + "fields": [ + { "name": "BOOL_NON_NULLABLE", "type": "boolean"}, + { "name": "int", "type": ["int", "null"]} + ] + }] + }, + { "name": "array", "type":["null", {"type": "array", "items": ["null", "TestAvroNested"] }]}, + { "name": "map", "type": ["null", {"type": "map", "values": ["null", "TestAvroNested"]}]} + ] +} + diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTest.java new file mode 100644 index 000000000000..730ccf60e0b9 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTest.java @@ -0,0 +1,1108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.coders; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import org.apache.avro.AvroRuntimeException; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.reflect.AvroName; +import org.apache.avro.reflect.AvroSchema; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.Stringable; +import org.apache.avro.reflect.Union; +import org.apache.avro.specific.SpecificData; +import org.apache.avro.specific.SpecificRecord; +import org.apache.avro.util.Utf8; +import org.apache.beam.sdk.coders.Coder.Context; +import org.apache.beam.sdk.coders.Coder.NonDeterministicException; +import org.apache.beam.sdk.coders.DefaultCoder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.extensions.avro.schemas.TestAvro; +import org.apache.beam.sdk.extensions.avro.schemas.TestAvroNested; +import org.apache.beam.sdk.extensions.avro.schemas.TestEnum; +import org.apache.beam.sdk.extensions.avro.schemas.fixed4; +import org.apache.beam.sdk.testing.CoderProperties; +import org.apache.beam.sdk.testing.InterceptingUrlClassLoader; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.InstanceBuilder; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.Matchers; +import org.hamcrest.TypeSafeMatcher; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import 
org.joda.time.LocalDate; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.objenesis.strategy.StdInstantiatorStrategy; + +/** Tests for {@link AvroCoder}. */ +@RunWith(JUnit4.class) +public class AvroCoderTest { + + public static final DateTime DATETIME_A = + new DateTime().withDate(1994, 10, 31).withZone(DateTimeZone.UTC); + public static final DateTime DATETIME_B = + new DateTime().withDate(1997, 4, 25).withZone(DateTimeZone.UTC); + private static final TestAvroNested AVRO_NESTED_SPECIFIC_RECORD = new TestAvroNested(true, 42); + private static final TestAvro AVRO_SPECIFIC_RECORD = + new TestAvro( + true, + 43, + 44L, + 44.1f, + 44.2d, + "mystring", + ByteBuffer.wrap(new byte[] {1, 2, 3, 4}), + new fixed4(new byte[] {1, 2, 3, 4}), + new LocalDate(1979, 3, 14), + new DateTime().withDate(1979, 3, 14).withTime(1, 2, 3, 4), + TestEnum.abc, + AVRO_NESTED_SPECIFIC_RECORD, + ImmutableList.of(AVRO_NESTED_SPECIFIC_RECORD, AVRO_NESTED_SPECIFIC_RECORD), + ImmutableMap.of("k1", AVRO_NESTED_SPECIFIC_RECORD, "k2", AVRO_NESTED_SPECIFIC_RECORD)); + + @DefaultCoder(AvroCoder.class) + private static class Pojo { + public String text; + public int count; + + @AvroSchema("{\"type\": \"long\", \"logicalType\": \"timestamp-millis\"}") + public DateTime timestamp; + + // Empty constructor required for Avro decoding. + @SuppressWarnings("unused") + public Pojo() {} + + public Pojo(String text, int count, DateTime timestamp) { + this.text = text; + this.count = count; + this.timestamp = timestamp; + } + + // auto-generated + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Pojo pojo = (Pojo) o; + + if (count != pojo.count) { + return false; + } + if (text != null ? !text.equals(pojo.text) : pojo.text != null) { + return false; + } + if (timestamp != null ? !timestamp.equals(pojo.timestamp) : pojo.timestamp != null) { + return false; + } + + return true; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public String toString() { + return "Pojo{" + + "text='" + + text + + '\'' + + ", count=" + + count + + ", timestamp=" + + timestamp + + '}'; + } + } + + private static class GetTextFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().text); + } + } + + @Rule public TestPipeline pipeline = TestPipeline.create(); + + @Test + public void testAvroCoderEncoding() throws Exception { + AvroCoder coder = AvroCoder.of(Pojo.class); + CoderProperties.coderSerializable(coder); + AvroCoder copy = SerializableUtils.clone(coder); + + Pojo pojo = new Pojo("foo", 3, DATETIME_A); + Pojo equalPojo = new Pojo("foo", 3, DATETIME_A); + Pojo otherPojo = new Pojo("bar", -19, DATETIME_B); + CoderProperties.coderConsistentWithEquals(coder, pojo, equalPojo); + CoderProperties.coderConsistentWithEquals(copy, pojo, equalPojo); + CoderProperties.coderConsistentWithEquals(coder, pojo, otherPojo); + CoderProperties.coderConsistentWithEquals(copy, pojo, otherPojo); + } + + /** + * Tests that {@link AvroCoder} works around issues in Avro where cache classes might be from the + * wrong ClassLoader, causing confusing "Cannot cast X to X" error messages. 
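 * The test loads the same POJO class through two independent ClassLoaders; if Avro's per-class caches (such as the SpecificData cache that the first coder populates) leaked across loaders, the second coder would decode into the class from the wrong loader and the final cast would fail.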
+ */ + @SuppressWarnings("ReturnValueIgnored") + @Test + public void testTwoClassLoaders() throws Exception { + ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + ClassLoader loader1 = + new InterceptingUrlClassLoader(contextClassLoader, AvroCoderTestPojo.class.getName()); + ClassLoader loader2 = + new InterceptingUrlClassLoader(contextClassLoader, AvroCoderTestPojo.class.getName()); + + Class pojoClass1 = loader1.loadClass(AvroCoderTestPojo.class.getName()); + Class pojoClass2 = loader2.loadClass(AvroCoderTestPojo.class.getName()); + + Object pojo1 = InstanceBuilder.ofType(pojoClass1).withArg(String.class, "hello").build(); + Object pojo2 = InstanceBuilder.ofType(pojoClass2).withArg(String.class, "goodbye").build(); + + // Confirm incompatibility + try { + pojoClass2.cast(pojo1); + fail("Expected ClassCastException; without it, this test is vacuous"); + } catch (ClassCastException e) { + // g2g + } + + // The first coder is expected to populate the Avro SpecificData cache + // The second coder is expected to be corrupted if the caching is done wrong. + AvroCoder avroCoder1 = (AvroCoder) AvroCoder.of(pojoClass1); + AvroCoder avroCoder2 = (AvroCoder) AvroCoder.of(pojoClass2); + + Object cloned1 = CoderUtils.clone(avroCoder1, pojo1); + Object cloned2 = CoderUtils.clone(avroCoder2, pojo2); + + // Confirming that the uncorrupted coder is fine + pojoClass1.cast(cloned1); + + // Confirmed to fail prior to the fix + pojoClass2.cast(cloned2); + } + + /** + * Confirm that we can serialize and deserialize an AvroCoder object and still decode after. + * (https://github.com/apache/beam/issues/18022). + * + * @throws Exception + */ + @Test + public void testTransientFieldInitialization() throws Exception { + Pojo value = new Pojo("Hello", 42, DATETIME_A); + AvroCoder coder = AvroCoder.of(Pojo.class); + + // Serialization of object + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos); + out.writeObject(coder); + + // De-serialization of object + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + ObjectInputStream in = new ObjectInputStream(bis); + AvroCoder copied = (AvroCoder) in.readObject(); + + CoderProperties.coderDecodeEncodeEqual(copied, value); + } + + /** + * Confirm that we can serialize and deserialize an AvroCoder object using Kryo. (BEAM-626). + * + * @throws Exception + */ + @Test + public void testKryoSerialization() throws Exception { + Pojo value = new Pojo("Hello", 42, DATETIME_A); + AvroCoder coder = AvroCoder.of(Pojo.class); + + // Kryo instantiation + Kryo kryo = new Kryo(); + kryo.setInstantiatorStrategy(new StdInstantiatorStrategy()); + + // Serialization of object without any memoization + ByteArrayOutputStream coderWithoutMemoizationBos = new ByteArrayOutputStream(); + try (Output output = new Output(coderWithoutMemoizationBos)) { + kryo.writeObject(output, coder); + } + + // Force thread local memoization to store values. 
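// coderDecodeEncodeEqual performs a full encode/decode round trip, which is what lazily initializes the coder's internal (thread-local) reader/writer state before the second Kryo serialization below.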
+ CoderProperties.coderDecodeEncodeEqual(coder, value); + + // Serialization of object with memoized fields + ByteArrayOutputStream coderWithMemoizationBos = new ByteArrayOutputStream(); + try (Output output = new Output(coderWithMemoizationBos)) { + kryo.writeObject(output, coder); + } + + // Copy empty and memoized variants of the Coder + ByteArrayInputStream bisWithoutMemoization = + new ByteArrayInputStream(coderWithoutMemoizationBos.toByteArray()); + AvroCoder copiedWithoutMemoization = + (AvroCoder) kryo.readObject(new Input(bisWithoutMemoization), AvroCoder.class); + ByteArrayInputStream bisWithMemoization = + new ByteArrayInputStream(coderWithMemoizationBos.toByteArray()); + AvroCoder copiedWithMemoization = + (AvroCoder) kryo.readObject(new Input(bisWithMemoization), AvroCoder.class); + + CoderProperties.coderDecodeEncodeEqual(copiedWithoutMemoization, value); + CoderProperties.coderDecodeEncodeEqual(copiedWithMemoization, value); + } + + @Test + public void testPojoEncoding() throws Exception { + Pojo value = new Pojo("Hello", 42, DATETIME_A); + AvroCoder coder = AvroCoder.of(Pojo.class); + + CoderProperties.coderDecodeEncodeEqual(coder, value); + } + + @Test + public void testSpecificRecordEncoding() throws Exception { + AvroCoder coder = + AvroCoder.of(TestAvro.class, AVRO_SPECIFIC_RECORD.getSchema(), false); + + assertTrue(SpecificRecord.class.isAssignableFrom(coder.getType())); + CoderProperties.coderDecodeEncodeEqual(coder, AVRO_SPECIFIC_RECORD); + } + + @Test + public void testReflectRecordEncoding() throws Exception { + AvroCoder coder = AvroCoder.of(TestAvro.class, true); + AvroCoder coderWithSchema = + AvroCoder.of(TestAvro.class, AVRO_SPECIFIC_RECORD.getSchema(), true); + + assertTrue(SpecificRecord.class.isAssignableFrom(coder.getType())); + assertTrue(SpecificRecord.class.isAssignableFrom(coderWithSchema.getType())); + + CoderProperties.coderDecodeEncodeEqual(coder, AVRO_SPECIFIC_RECORD); + CoderProperties.coderDecodeEncodeEqual(coderWithSchema, AVRO_SPECIFIC_RECORD); + } + + @Test + public void testDisableReflectionEncoding() { + try { + AvroCoder.of(Pojo.class, false); + fail("When useReflectApi is disabled, schema should not be generated through reflection"); + } catch (AvroRuntimeException e) { + String message = + "avro.shaded.com.google.common.util.concurrent.UncheckedExecutionException: " + + "org.apache.avro.AvroRuntimeException: " + + "Not a Specific class: class org.apache.beam.sdk.extensions.avro.coders.AvroCoderTest$Pojo"; + assertEquals(message, e.getMessage()); + } + } + + @Test + public void testGenericRecordEncoding() throws Exception { + String schemaString = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"User\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},\n" + + " {\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}\n" + + " ]\n" + + "}"; + Schema schema = new Schema.Parser().parse(schemaString); + + GenericRecord before = new GenericData.Record(schema); + before.put("name", "Bob"); + before.put("favorite_number", 256); + // Leave favorite_color null + + AvroCoder coder = AvroCoder.of(GenericRecord.class, schema); + + CoderProperties.coderDecodeEncodeEqual(coder, before); + assertEquals(schema, coder.getSchema()); + } + + @Test + public void testEncodingNotBuffered() throws Exception { + // This test ensures that the coder doesn't read ahead and buffer data.
+ // Reading ahead causes a problem if the stream consists of records of different + // types. + Pojo before = new Pojo("Hello", 42, DATETIME_A); + + AvroCoder coder = AvroCoder.of(Pojo.class); + SerializableCoder intCoder = SerializableCoder.of(Integer.class); + + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + + Context context = Context.NESTED; + coder.encode(before, outStream, context); + intCoder.encode(10, outStream, context); + + ByteArrayInputStream inStream = new ByteArrayInputStream(outStream.toByteArray()); + + Pojo after = coder.decode(inStream, context); + assertEquals(before, after); + + Integer intAfter = intCoder.decode(inStream, context); + assertEquals(Integer.valueOf(10), intAfter); + } + + @Test + @Category(NeedsRunner.class) + public void testDefaultCoder() throws Exception { + // Use MyRecord as input and output types without explicitly specifying + // a coder (this uses the default coders, which may not be AvroCoder). + PCollection output = + pipeline + .apply(Create.of(new Pojo("hello", 1, DATETIME_A), new Pojo("world", 2, DATETIME_B))) + .apply(ParDo.of(new GetTextFn())); + + PAssert.that(output).containsInAnyOrder("hello", "world"); + pipeline.run(); + } + + @Test + public void testAvroCoderIsSerializable() throws Exception { + AvroCoder coder = AvroCoder.of(Pojo.class); + + // Check that the coder is serializable using the regular JSON approach. + SerializableUtils.ensureSerializable(coder); + } + + @Test + public void testAvroSpecificCoderIsSerializable() throws Exception { + AvroCoder coder = AvroCoder.of(TestAvro.class, false); + + // Check that the coder is serializable using the regular JSON approach. + SerializableUtils.ensureSerializable(coder); + } + + private void assertDeterministic(AvroCoder coder) { + try { + coder.verifyDeterministic(); + } catch (NonDeterministicException e) { + fail("Expected " + coder + " to be deterministic, but got:\n" + e); + } + } + + private void assertNonDeterministic(AvroCoder coder, Matcher reason1) { + try { + coder.verifyDeterministic(); + fail("Expected " + coder + " to be non-deterministic."); + } catch (NonDeterministicException e) { + assertThat(e.getReasons(), Matchers.iterableWithSize(1)); + assertThat(e.getReasons(), Matchers.contains(reason1)); + } + } + + @Test + public void testDeterministicInteger() { + assertDeterministic(AvroCoder.of(Integer.class)); + } + + @Test + public void testDeterministicInt() { + assertDeterministic(AvroCoder.of(int.class)); + } + + private static class SimpleDeterministicClass { + @SuppressWarnings("unused") + private Integer intField; + + @SuppressWarnings("unused") + private char charField; + + @SuppressWarnings("unused") + private Integer[] intArray; + + @SuppressWarnings("unused") + private Utf8 utf8field; + } + + @Test + public void testDeterministicSimple() { + assertDeterministic(AvroCoder.of(SimpleDeterministicClass.class)); + } + + private static class UnorderedMapClass { + @SuppressWarnings("unused") + private Map mapField; + } + + private Matcher reason(final String prefix, final String messagePart) { + return new TypeSafeMatcher(String.class) { + @Override + public void describeTo(Description description) { + description.appendText( + String.format("Reason starting with '%s:' containing '%s'", prefix, messagePart)); + } + + @Override + protected boolean matchesSafely(String item) { + return item.startsWith(prefix + ":") && item.contains(messagePart); + } + }; + } + + private Matcher reasonClass(Class clazz, String message) { + return 
reason(clazz.getName(), message); + } + + private Matcher reasonField(Class clazz, String field, String message) { + return reason(clazz.getName() + "#" + field, message); + } + + @Test + public void testDeterministicUnorderedMap() { + assertNonDeterministic( + AvroCoder.of(UnorderedMapClass.class), + reasonField( + UnorderedMapClass.class, + "mapField", + "java.util.Map " + + "may not be deterministically ordered")); + } + + private static class NonDeterministicArray { + @SuppressWarnings("unused") + private UnorderedMapClass[] arrayField; + } + + @Test + public void testDeterministicNonDeterministicArray() { + assertNonDeterministic( + AvroCoder.of(NonDeterministicArray.class), + reasonField( + UnorderedMapClass.class, + "mapField", + "java.util.Map" + + " may not be deterministically ordered")); + } + + private static class SubclassOfUnorderedMapClass extends UnorderedMapClass {} + + @Test + public void testDeterministicNonDeterministicChild() { + // Super class has non deterministic fields. + assertNonDeterministic( + AvroCoder.of(SubclassOfUnorderedMapClass.class), + reasonField(UnorderedMapClass.class, "mapField", "may not be deterministically ordered")); + } + + private static class SubclassHidingParent extends UnorderedMapClass { + @SuppressWarnings("unused") + @AvroName("mapField2") // AvroName is not enough + private int mapField; + } + + @Test + public void testAvroProhibitsShadowing() { + // This test verifies that Avro won't serialize a class with two fields of + // the same name. This is important for our error reporting, and also how + // we lookup a field. + try { + ReflectData.get().getSchema(SubclassHidingParent.class); + fail("Expected AvroTypeException"); + } catch (AvroRuntimeException e) { + assertThat(e.getMessage(), containsString("mapField")); + assertThat(e.getMessage(), containsString("two fields named")); + } + } + + private static class FieldWithAvroName { + @AvroName("name") + @SuppressWarnings("unused") + private int someField; + } + + @Test + public void testDeterministicWithAvroName() { + assertDeterministic(AvroCoder.of(FieldWithAvroName.class)); + } + + @Test + public void testDeterminismSortedMap() { + assertDeterministic(AvroCoder.of(StringSortedMapField.class)); + } + + private static class StringSortedMapField { + @SuppressWarnings("unused") + SortedMap sortedMapField; + } + + @Test + public void testDeterminismTreeMapValue() { + // The value is non-deterministic, so we should fail. + assertNonDeterministic( + AvroCoder.of(TreeMapNonDetValue.class), + reasonField( + UnorderedMapClass.class, + "mapField", + "java.util.Map " + + "may not be deterministically ordered")); + } + + private static class TreeMapNonDetValue { + @SuppressWarnings("unused") + TreeMap nonDeterministicField; + } + + @Test + public void testDeterminismUnorderedMap() { + // LinkedHashMap is not deterministically ordered, so we should fail. 
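// Two LinkedHashMaps can be equal() yet iterate in different insertion orders, so equal values could encode to different bytes.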
+ assertNonDeterministic( + AvroCoder.of(LinkedHashMapField.class), + reasonField( + LinkedHashMapField.class, + "nonDeterministicMap", + "java.util.LinkedHashMap " + + "may not be deterministically ordered")); + } + + private static class LinkedHashMapField { + @SuppressWarnings("unused") + LinkedHashMap nonDeterministicMap; + } + + @Test + public void testDeterminismCollection() { + assertNonDeterministic( + AvroCoder.of(StringCollection.class), + reasonField( + StringCollection.class, + "stringCollection", + "java.util.Collection may not be deterministically ordered")); + } + + private static class StringCollection { + @SuppressWarnings("unused") + Collection stringCollection; + } + + @Test + public void testDeterminismList() { + assertDeterministic(AvroCoder.of(StringList.class)); + assertDeterministic(AvroCoder.of(StringArrayList.class)); + } + + private static class StringList { + @SuppressWarnings("unused") + List stringCollection; + } + + private static class StringArrayList { + @SuppressWarnings("unused") + ArrayList stringCollection; + } + + @Test + public void testDeterminismSet() { + assertDeterministic(AvroCoder.of(StringSortedSet.class)); + assertDeterministic(AvroCoder.of(StringTreeSet.class)); + assertNonDeterministic( + AvroCoder.of(StringHashSet.class), + reasonField( + StringHashSet.class, + "stringCollection", + "java.util.HashSet may not be deterministically ordered")); + } + + private static class StringSortedSet { + @SuppressWarnings("unused") + SortedSet stringCollection; + } + + private static class StringTreeSet { + @SuppressWarnings("unused") + TreeSet stringCollection; + } + + private static class StringHashSet { + @SuppressWarnings("unused") + HashSet stringCollection; + } + + @Test + public void testDeterminismCollectionValue() { + assertNonDeterministic( + AvroCoder.of(OrderedSetOfNonDetValues.class), + reasonField(UnorderedMapClass.class, "mapField", "may not be deterministically ordered")); + assertNonDeterministic( + AvroCoder.of(ListOfNonDetValues.class), + reasonField(UnorderedMapClass.class, "mapField", "may not be deterministically ordered")); + } + + private static class OrderedSetOfNonDetValues { + @SuppressWarnings("unused") + SortedSet set; + } + + private static class ListOfNonDetValues { + @SuppressWarnings("unused") + List set; + } + + @Test + public void testDeterminismUnion() { + assertDeterministic(AvroCoder.of(DeterministicUnionBase.class)); + assertNonDeterministic( + AvroCoder.of(NonDeterministicUnionBase.class), + reasonField(UnionCase3.class, "mapField", "may not be deterministically ordered")); + } + + @Test + public void testDeterminismStringable() { + assertDeterministic(AvroCoder.of(String.class)); + assertNonDeterministic( + AvroCoder.of(StringableClass.class), + reasonClass(StringableClass.class, "may not have deterministic #toString()")); + } + + @Stringable + private static class StringableClass {} + + @Test + public void testDeterminismCyclicClass() { + assertNonDeterministic( + AvroCoder.of(Cyclic.class), + reasonField(Cyclic.class, "cyclicField", "appears recursively")); + assertNonDeterministic( + AvroCoder.of(CyclicField.class), + reasonField(Cyclic.class, "cyclicField", Cyclic.class.getName() + " appears recursively")); + assertNonDeterministic( + AvroCoder.of(IndirectCycle1.class), + reasonField( + IndirectCycle2.class, + "field2", + IndirectCycle1.class.getName() + " appears recursively")); + } + + private static class Cyclic { + @SuppressWarnings("unused") + int intField; + + @SuppressWarnings("unused") + Cyclic 
cyclicField; + } + + private static class CyclicField { + @SuppressWarnings("unused") + Cyclic cyclicField2; + } + + private static class IndirectCycle1 { + @SuppressWarnings("unused") + IndirectCycle2 field1; + } + + private static class IndirectCycle2 { + @SuppressWarnings("unused") + IndirectCycle1 field2; + } + + @Test + public void testDeterminismHasGenericRecord() { + assertDeterministic(AvroCoder.of(HasGenericRecord.class)); + } + + private static class HasGenericRecord { + @AvroSchema( + "{\"name\": \"bar\", \"type\": \"record\", \"fields\": [" + + "{\"name\": \"foo\", \"type\": \"int\"}]}") + GenericRecord genericRecord; + } + + @Test + public void testDeterminismHasCustomSchema() { + assertNonDeterministic( + AvroCoder.of(HasCustomSchema.class), + reasonField( + HasCustomSchema.class, + "withCustomSchema", + "Custom schemas are only supported for subtypes of IndexedRecord.")); + } + + private static class HasCustomSchema { + @AvroSchema( + "{\"name\": \"bar\", \"type\": \"record\", \"fields\": [" + + "{\"name\": \"foo\", \"type\": \"int\"}]}") + int withCustomSchema; + } + + @Test + public void testAvroCoderTreeMapDeterminism() throws Exception, NonDeterministicException { + TreeMapField size1 = new TreeMapField(); + TreeMapField size2 = new TreeMapField(); + + // Different order for entries + size1.field.put("hello", "world"); + size1.field.put("another", "entry"); + + size2.field.put("another", "entry"); + size2.field.put("hello", "world"); + + AvroCoder coder = AvroCoder.of(TreeMapField.class); + coder.verifyDeterministic(); + + ByteArrayOutputStream outStream1 = new ByteArrayOutputStream(); + ByteArrayOutputStream outStream2 = new ByteArrayOutputStream(); + + Context context = Context.NESTED; + coder.encode(size1, outStream1, context); + coder.encode(size2, outStream2, context); + + assertArrayEquals(outStream1.toByteArray(), outStream2.toByteArray()); + } + + private static class TreeMapField { + private TreeMap field = new TreeMap<>(); + } + + @Union({UnionCase1.class, UnionCase2.class}) + private abstract static class DeterministicUnionBase {} + + @Union({UnionCase1.class, UnionCase2.class, UnionCase3.class}) + private abstract static class NonDeterministicUnionBase {} + + private static class UnionCase1 extends DeterministicUnionBase {} + + private static class UnionCase2 extends DeterministicUnionBase { + @SuppressWarnings("unused") + String field; + } + + private static class UnionCase3 extends NonDeterministicUnionBase { + @SuppressWarnings("unused") + private Map mapField; + } + + @Test + public void testAvroCoderSimpleSchemaDeterminism() { + assertDeterministic(AvroCoder.of(SchemaBuilder.record("someRecord").fields().endRecord())); + assertDeterministic( + AvroCoder.of( + SchemaBuilder.record("someRecord") + .fields() + .name("int") + .type() + .intType() + .noDefault() + .endRecord())); + assertDeterministic( + AvroCoder.of( + SchemaBuilder.record("someRecord") + .fields() + .name("string") + .type() + .stringType() + .noDefault() + .endRecord())); + + assertNonDeterministic( + AvroCoder.of( + SchemaBuilder.record("someRecord") + .fields() + .name("map") + .type() + .map() + .values() + .stringType() + .noDefault() + .endRecord()), + reason("someRecord.map", "HashMap to represent MAPs")); + + assertDeterministic( + AvroCoder.of( + SchemaBuilder.record("someRecord") + .fields() + .name("array") + .type() + .array() + .items() + .stringType() + .noDefault() + .endRecord())); + + assertDeterministic( + AvroCoder.of( + SchemaBuilder.record("someRecord") + .fields() + 
.name("enum") + .type() + .enumeration("anEnum") + .symbols("s1", "s2") + .enumDefault("s1") + .endRecord())); + + assertDeterministic( + AvroCoder.of( + SchemaBuilder.unionOf() + .intType() + .and() + .record("someRecord") + .fields() + .nullableString("someField", "") + .endRecord() + .endUnion())); + } + + @Test + public void testAvroCoderStrings() { + // Custom Strings in Records + assertDeterministic( + AvroCoder.of( + SchemaBuilder.record("someRecord") + .fields() + .name("string") + .prop(SpecificData.CLASS_PROP, "java.lang.String") + .type() + .stringType() + .noDefault() + .endRecord())); + assertNonDeterministic( + AvroCoder.of( + SchemaBuilder.record("someRecord") + .fields() + .name("string") + .prop(SpecificData.CLASS_PROP, "unknownString") + .type() + .stringType() + .noDefault() + .endRecord()), + reason("someRecord.string", "unknownString is not known to be deterministic")); + + // Custom Strings in Unions + assertNonDeterministic( + AvroCoder.of( + SchemaBuilder.unionOf() + .intType() + .and() + .record("someRecord") + .fields() + .name("someField") + .prop(SpecificData.CLASS_PROP, "unknownString") + .type() + .stringType() + .noDefault() + .endRecord() + .endUnion()), + reason("someRecord.someField", "unknownString is not known to be deterministic")); + } + + @Test + public void testAvroCoderNestedRecords() { + // Nested Record + assertDeterministic( + AvroCoder.of( + SchemaBuilder.record("nestedRecord") + .fields() + .name("subRecord") + .type() + .record("subRecord") + .fields() + .name("innerField") + .type() + .stringType() + .noDefault() + .endRecord() + .noDefault() + .endRecord())); + } + + @Test + public void testAvroCoderCyclicRecords() { + // Recursive record + assertNonDeterministic( + AvroCoder.of( + SchemaBuilder.record("cyclicRecord") + .fields() + .name("cycle") + .type("cyclicRecord") + .noDefault() + .endRecord()), + reason("cyclicRecord.cycle", "cyclicRecord appears recursively")); + } + + private static class NullableField { + @SuppressWarnings("unused") + private @Nullable String nullable; + } + + @Test + public void testNullableField() { + assertDeterministic(AvroCoder.of(NullableField.class)); + } + + private static class NullableNonDeterministicField { + @SuppressWarnings("unused") + private @Nullable NonDeterministicArray nullableNonDetArray; + } + + private static class NullableCyclic { + @SuppressWarnings("unused") + private @Nullable NullableCyclic nullableNullableCyclicField; + } + + private static class NullableCyclicField { + @SuppressWarnings("unused") + private @Nullable Cyclic nullableCyclicField; + } + + @Test + public void testNullableNonDeterministicField() { + assertNonDeterministic( + AvroCoder.of(NullableCyclic.class), + reasonField( + NullableCyclic.class, + "nullableNullableCyclicField", + NullableCyclic.class.getName() + " appears recursively")); + assertNonDeterministic( + AvroCoder.of(NullableCyclicField.class), + reasonField(Cyclic.class, "cyclicField", Cyclic.class.getName() + " appears recursively")); + assertNonDeterministic( + AvroCoder.of(NullableNonDeterministicField.class), + reasonField(UnorderedMapClass.class, "mapField", " may not be deterministically ordered")); + } + + /** + * Tests that a parameterized class can have an automatically generated schema if the generic + * field is annotated with a union tag. + */ + @Test + public void testGenericClassWithUnionAnnotation() throws Exception { + // Cast is safe as long as the same coder is used for encoding and decoding. 
+ @SuppressWarnings({"unchecked", "rawtypes"}) + AvroCoder> coder = + (AvroCoder) AvroCoder.of(GenericWithAnnotation.class); + + assertThat( + coder.getSchema().getField("onlySomeTypesAllowed").schema().getType(), + equalTo(Schema.Type.UNION)); + + CoderProperties.coderDecodeEncodeEqual(coder, new GenericWithAnnotation<>("hello")); + } + + private static class GenericWithAnnotation { + @AvroSchema("[\"string\", \"int\"]") + private T onlySomeTypesAllowed; + + public GenericWithAnnotation(T value) { + onlySomeTypesAllowed = value; + } + + // For deserialization only + @SuppressWarnings("unused") + protected GenericWithAnnotation() {} + + @Override + public boolean equals(@Nullable Object other) { + return other instanceof GenericWithAnnotation + && onlySomeTypesAllowed.equals(((GenericWithAnnotation) other).onlySomeTypesAllowed); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), onlySomeTypesAllowed); + } + } + + @Test + public void testAvroCoderForGenerics() throws Exception { + Schema fooSchema = AvroCoder.of(Foo.class).getSchema(); + Schema schema = + new Schema.Parser() + .parse( + "{" + + "\"type\":\"record\"," + + "\"name\":\"SomeGeneric\"," + + "\"namespace\":\"ns\"," + + "\"fields\":[" + + " {\"name\":\"foo\", \"type\":" + + fooSchema.toString() + + "}" + + "]}"); + @SuppressWarnings("rawtypes") + AvroCoder coder = AvroCoder.of(SomeGeneric.class, schema); + + assertNonDeterministic(coder, reasonField(SomeGeneric.class, "foo", "erasure")); + } + + @Test + public void testEncodedTypeDescriptor() throws Exception { + AvroCoder coder = AvroCoder.of(Pojo.class); + assertThat(coder.getEncodedTypeDescriptor(), equalTo(TypeDescriptor.of(Pojo.class))); + } + + private static class SomeGeneric { + @SuppressWarnings("unused") + private T foo; + } + + private static class Foo { + @SuppressWarnings("unused") + String id; + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTestPojo.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTestPojo.java new file mode 100644 index 000000000000..9d1700313dfa --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/coders/AvroCoderTestPojo.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.coders; + +import java.util.Objects; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** A Pojo at the top level for use in tests. */ +class AvroCoderTestPojo { + + public String text; + + // Empty constructor required for Avro decoding. 
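// Avro's reflect-based datum reader typically instantiates the class through this no-arg constructor and then fills in the fields afterwards.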
+ @SuppressWarnings("unused") + public AvroCoderTestPojo() {} + + public AvroCoderTestPojo(String text) { + this.text = text; + } + + @Override + public boolean equals(@Nullable Object other) { + return (other instanceof AvroCoderTestPojo) && ((AvroCoderTestPojo) other).text.equals(text); + } + + @Override + public int hashCode() { + return Objects.hash(AvroCoderTestPojo.class, text); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("text", text).toString(); + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroIOTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroIOTest.java new file mode 100644 index 000000000000..5bf753a11f8b --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroIOTest.java @@ -0,0 +1,1587 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.io; + +import static org.apache.avro.file.DataFileConstants.SNAPPY_CODEC; +import static org.apache.beam.sdk.io.Compression.AUTO; +import static org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_FILE; +import static org.apache.beam.sdk.transforms.Contextful.fn; +import static org.apache.beam.sdk.transforms.Requirements.requiresSideInputs; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects.firstNonNull; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.CodecFactory; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.DataFileStream; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.Encoder; 
+import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.DefaultCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.io.DefaultFilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink; +import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; +import org.apache.beam.sdk.io.FileBasedSink.OutputFileHints; +import org.apache.beam.sdk.io.FileIO; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.io.WriteFilesResult; +import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.testing.UsesTestStream; +import org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.Watch; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.windowing.AfterPane; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; + +/** Tests for AvroIO Read and Write transforms. 
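 * The suite is split into {@link SimpleTests}, which need no runner, and the parameterized {@link NeedsRunnerTests}.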
*/ +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class AvroIOTest implements Serializable { + /** Unit tests. */ + @RunWith(JUnit4.class) + public static class SimpleTests implements Serializable { + @Test + public void testAvroIOGetName() { + assertEquals("AvroIO.Read", AvroIO.read(String.class).from("/tmp/foo*/baz").getName()); + assertEquals("AvroIO.Write", AvroIO.write(String.class).to("/tmp/foo/baz").getName()); + } + + @Test + public void testWriteWithDefaultCodec() { + AvroIO.Write write = AvroIO.write(String.class).to("/tmp/foo/baz"); + assertEquals(CodecFactory.snappyCodec().toString(), write.inner.getCodec().toString()); + } + + @Test + public void testWriteWithCustomCodec() { + AvroIO.Write write = + AvroIO.write(String.class).to("/tmp/foo/baz").withCodec(CodecFactory.snappyCodec()); + assertEquals(SNAPPY_CODEC, write.inner.getCodec().toString()); + } + + @Test + public void testWriteWithSerDeCustomDeflateCodec() { + AvroIO.Write write = + AvroIO.write(String.class).to("/tmp/foo/baz").withCodec(CodecFactory.deflateCodec(9)); + + assertEquals( + CodecFactory.deflateCodec(9).toString(), + SerializableUtils.clone(write.inner.getCodec()).getCodec().toString()); + } + + @Test + public void testWriteWithSerDeCustomXZCodec() { + AvroIO.Write write = + AvroIO.write(String.class).to("/tmp/foo/baz").withCodec(CodecFactory.xzCodec(9)); + + assertEquals( + CodecFactory.xzCodec(9).toString(), + SerializableUtils.clone(write.inner.getCodec()).getCodec().toString()); + } + + @Test + public void testReadDisplayData() { + AvroIO.Read read = AvroIO.read(String.class).from("/foo.*"); + + DisplayData displayData = DisplayData.from(read); + assertThat(displayData, hasDisplayItem("filePattern", "/foo.*")); + } + } + + /** NeedsRunner tests. 
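 * Each test runs twice, with {@code withBeamSchemas} set to true and to false, so both the Avro-coder and Beam-schema read paths are covered.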
*/ + @RunWith(Parameterized.class) + @Category(NeedsRunner.class) + public static class NeedsRunnerTests implements Serializable { + @Rule public transient TestPipeline writePipeline = TestPipeline.create(); + + @Rule public transient TestPipeline readPipeline = TestPipeline.create(); + + @Rule public transient TestPipeline windowedAvroWritePipeline = TestPipeline.create(); + + @Rule public transient TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule public transient ExpectedException expectedException = ExpectedException.none(); + + @Parameterized.Parameters(name = "{index}: {0}") + public static Collection params() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameterized.Parameter public boolean withBeamSchemas; + + @DefaultCoder(AvroCoder.class) + static class GenericClass { + int intField; + String stringField; + + GenericClass() {} + + GenericClass(int intField, String stringField) { + this.intField = intField; + this.stringField = stringField; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("intField", intField) + .add("stringField", stringField) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hash(intField, stringField); + } + + @Override + public boolean equals(@Nullable Object other) { + if (other == null || !(other instanceof GenericClass)) { + return false; + } + GenericClass o = (GenericClass) other; + return intField == o.intField && Objects.equals(stringField, o.stringField); + } + } + + private static class ParseGenericClass + implements SerializableFunction { + @Override + public GenericClass apply(GenericRecord input) { + return new GenericClass((int) input.get("intField"), input.get("stringField").toString()); + } + + @Test + public void testWriteDisplayData() { + AvroIO.Write write = + AvroIO.write(GenericClass.class) + .to("/foo") + .withShardNameTemplate("-SS-of-NN-") + .withSuffix("bar") + .withNumShards(100) + .withCodec(CodecFactory.deflateCodec(6)); + + DisplayData displayData = DisplayData.from(write); + + assertThat(displayData, hasDisplayItem("filePrefix", "/foo")); + assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-")); + assertThat(displayData, hasDisplayItem("fileSuffix", "bar")); + assertThat( + displayData, + hasDisplayItem( + "schema", + "{\"type\":\"record\",\"name\":\"GenericClass\",\"namespace\":\"org.apache.beam.sdk.io" + + ".AvroIOTest$\",\"fields\":[{\"name\":\"intField\",\"type\":\"int\"}," + + "{\"name\":\"stringField\",\"type\":\"string\"}]}")); + assertThat(displayData, hasDisplayItem("numShards", 100)); + assertThat(displayData, hasDisplayItem("codec", CodecFactory.deflateCodec(6).toString())); + } + } + + private enum Sharding { + RUNNER_DETERMINED, + WITHOUT_SHARDING, + FIXED_3_SHARDS + } + + private enum WriteMethod { + AVROIO_WRITE, + AVROIO_SINK_WITH_CLASS, + AVROIO_SINK_WITH_SCHEMA, + /** @deprecated Test code for the deprecated {AvroIO.RecordFormatter}. 
*/ + @Deprecated + AVROIO_SINK_WITH_FORMATTER + } + + private static final String SCHEMA_STRING = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"AvroGeneratedUser\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},\n" + + " {\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}\n" + + " ]\n" + + "}"; + + private static final Schema SCHEMA = new Schema.Parser().parse(SCHEMA_STRING); + + @Test + @Category(NeedsRunner.class) + public void testWriteThenReadJavaClass() throws Throwable { + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + writePipeline + .apply(Create.of(values)) + .apply( + AvroIO.write(GenericClass.class) + .to(writePipeline.newProvider(outputFile.getAbsolutePath())) + .withoutSharding()); + writePipeline.run(); + + PAssert.that( + readPipeline.apply( + "Read", + AvroIO.read(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .from(readPipeline.newProvider(outputFile.getAbsolutePath())))) + .containsInAnyOrder(values); + + readPipeline.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteThenReadCustomType() throws Throwable { + List values = Arrays.asList(0L, 1L, 2L); + File outputFile = tmpFolder.newFile("output.avro"); + + writePipeline + .apply(Create.of(values)) + .apply( + AvroIO.writeCustomType() + .to(writePipeline.newProvider(outputFile.getAbsolutePath())) + .withFormatFunction(new CreateGenericClass()) + .withSchema(ReflectData.get().getSchema(GenericClass.class)) + .withoutSharding()); + writePipeline.run(); + + PAssert.that( + readPipeline + .apply( + "Read", + AvroIO.read(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .from(readPipeline.newProvider(outputFile.getAbsolutePath()))) + .apply( + MapElements.via( + new SimpleFunction() { + @Override + public Long apply(GenericClass input) { + return (long) input.intField; + } + }))) + .containsInAnyOrder(values); + + readPipeline.run(); + } + + private void testWriteThenReadGeneratedClass( + AvroIO.Write writeTransform, AvroIO.Read readTransform) throws Exception { + File outputFile = tmpFolder.newFile("output.avro"); + + List values = + ImmutableList.of( + (T) new AvroGeneratedUser("Bob", 256, null), + (T) new AvroGeneratedUser("Alice", 128, null), + (T) new AvroGeneratedUser("Ted", null, "white")); + + writePipeline + .apply(Create.of(values)) + .apply( + writeTransform + .to(writePipeline.newProvider(outputFile.getAbsolutePath())) + .withoutSharding()); + writePipeline.run(); + + PAssert.that( + readPipeline.apply( + "Read", + readTransform.from(readPipeline.newProvider(outputFile.getAbsolutePath())))) + .containsInAnyOrder(values); + + readPipeline.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteThenReadGeneratedClassWithClass() throws Throwable { + testWriteThenReadGeneratedClass( + AvroIO.write(AvroGeneratedUser.class), + AvroIO.read(AvroGeneratedUser.class).withBeamSchemas(withBeamSchemas)); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteThenReadGeneratedClassWithSchema() throws Throwable { + testWriteThenReadGeneratedClass( + AvroIO.writeGenericRecords(SCHEMA), + AvroIO.readGenericRecords(SCHEMA).withBeamSchemas(withBeamSchemas)); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteThenReadGeneratedClassWithSchemaString() throws Throwable { + 
testWriteThenReadGeneratedClass( + AvroIO.writeGenericRecords(SCHEMA.toString()), + AvroIO.readGenericRecords(SCHEMA.toString()).withBeamSchemas(withBeamSchemas)); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteSingleFileThenReadUsingAllMethods() throws Throwable { + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + writePipeline + .apply(Create.of(values)) + .apply( + AvroIO.write(GenericClass.class).to(outputFile.getAbsolutePath()).withoutSharding()); + writePipeline.run(); + + // Test the same data using all versions of read(). + PCollection path = + readPipeline.apply("Create path", Create.of(outputFile.getAbsolutePath())); + PAssert.that( + readPipeline.apply( + "Read", + AvroIO.read(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .from(outputFile.getAbsolutePath()))) + .containsInAnyOrder(values); + PAssert.that( + readPipeline.apply( + "Read withHintMatchesManyFiles", + AvroIO.read(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .from(outputFile.getAbsolutePath()) + .withHintMatchesManyFiles())) + .containsInAnyOrder(values); + PAssert.that( + path.apply("MatchAllReadFiles", FileIO.matchAll()) + .apply("ReadMatchesReadFiles", FileIO.readMatches().withCompression(AUTO)) + .apply( + "ReadFiles", + AvroIO.readFiles(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(values); + PAssert.that( + path.apply( + "ReadAll", + AvroIO.readAll(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(values); + PAssert.that( + readPipeline.apply( + "Parse", + AvroIO.parseGenericRecords(new ParseGenericClass()) + .from(outputFile.getAbsolutePath()) + .withCoder(AvroCoder.of(GenericClass.class)))) + .containsInAnyOrder(values); + PAssert.that( + readPipeline.apply( + "Parse withHintMatchesManyFiles", + AvroIO.parseGenericRecords(new ParseGenericClass()) + .from(outputFile.getAbsolutePath()) + .withCoder(AvroCoder.of(GenericClass.class)) + .withHintMatchesManyFiles())) + .containsInAnyOrder(values); + PAssert.that( + path.apply("MatchAllParseFilesGenericRecords", FileIO.matchAll()) + .apply( + "ReadMatchesParseFilesGenericRecords", + FileIO.readMatches() + .withDirectoryTreatment(FileIO.ReadMatches.DirectoryTreatment.PROHIBIT)) + .apply( + "ParseFilesGenericRecords", + AvroIO.parseFilesGenericRecords(new ParseGenericClass()) + .withCoder(AvroCoder.of(GenericClass.class)) + .withUsesReshuffle(false) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(values); + PAssert.that( + path.apply("MatchAllParseFilesGenericRecordsWithShuffle", FileIO.matchAll()) + .apply( + "ReadMatchesParseFilesGenericRecordsWithShuffle", + FileIO.readMatches() + .withDirectoryTreatment(FileIO.ReadMatches.DirectoryTreatment.PROHIBIT)) + .apply( + "ParseFilesGenericRecordsWithShuffle", + AvroIO.parseFilesGenericRecords(new ParseGenericClass()) + .withCoder(AvroCoder.of(GenericClass.class)) + .withUsesReshuffle(true) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(values); + PAssert.that( + path.apply( + "ParseAllGenericRecords", + AvroIO.parseAllGenericRecords(new ParseGenericClass()) + .withCoder(AvroCoder.of(GenericClass.class)) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(values); + + readPipeline.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testWriteThenReadMultipleFilepatterns() { + List firstValues = new ArrayList<>(); 
+ List secondValues = new ArrayList<>(); + for (int i = 0; i < 10; ++i) { + firstValues.add(new GenericClass(i, "a" + i)); + secondValues.add(new GenericClass(i, "b" + i)); + } + writePipeline + .apply("Create first", Create.of(firstValues)) + .apply( + "Write first", + AvroIO.write(GenericClass.class) + .to(tmpFolder.getRoot().getAbsolutePath() + "/first") + .withNumShards(2)); + writePipeline + .apply("Create second", Create.of(secondValues)) + .apply( + "Write second", + AvroIO.write(GenericClass.class) + .to(tmpFolder.getRoot().getAbsolutePath() + "/second") + .withNumShards(3)); + writePipeline.run(); + + // Test readFiles(), readAll(), parseFilesGenericRecords() and parseAllGenericRecords(). + PCollection paths = + readPipeline.apply( + "Create paths", + Create.of( + tmpFolder.getRoot().getAbsolutePath() + "/first*", + tmpFolder.getRoot().getAbsolutePath() + "/second*")); + PAssert.that( + paths + .apply("MatchAllReadFiles", FileIO.matchAll()) + .apply("ReadMatchesReadFiles", FileIO.readMatches().withCompression(AUTO)) + .apply( + "ReadFiles", + AvroIO.readFiles(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + PAssert.that( + paths.apply( + "ReadAll", + AvroIO.readAll(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + PAssert.that( + paths + .apply("MatchAllParseFilesGenericRecords", FileIO.matchAll()) + .apply( + "ReadMatchesParseFilesGenericRecords", + FileIO.readMatches() + .withDirectoryTreatment(FileIO.ReadMatches.DirectoryTreatment.PROHIBIT)) + .apply( + "ParseFilesGenericRecords", + AvroIO.parseFilesGenericRecords(new ParseGenericClass()) + .withCoder(AvroCoder.of(GenericClass.class)) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + PAssert.that( + paths.apply( + "ParseAllGenericRecords", + AvroIO.parseAllGenericRecords(new ParseGenericClass()) + .withCoder(AvroCoder.of(GenericClass.class)) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + + readPipeline.run(); + } + + private static class CreateGenericClass extends SimpleFunction { + @Override + public GenericClass apply(Long i) { + return new GenericClass(i.intValue(), "value" + i); + } + } + + @Test + @Category({NeedsRunner.class, UsesUnboundedSplittableParDo.class}) + public void testContinuouslyWriteAndReadMultipleFilepatterns() { + SimpleFunction mapFn = new CreateGenericClass(); + List firstValues = new ArrayList<>(); + List secondValues = new ArrayList<>(); + for (int i = 0; i < 7; ++i) { + (i < 3 ? firstValues : secondValues).add(mapFn.apply((long) i)); + } + // Configure windowing of the input so that it fires every time a new element is generated, + // so that files are written continuously. 
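// Fixed 100ms windows that fire on every element (AfterPane.elementCountAtLeast(1)) with discarding panes, so the windowed writes emit new files incrementally and the watchForNewFiles() reads further down can pick them up as they appear.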
+ Window window = + Window.into(FixedWindows.of(Duration.millis(100))) + .withAllowedLateness(Duration.ZERO) + .triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))) + .discardingFiredPanes(); + readPipeline + .apply("Sequence first", GenerateSequence.from(0).to(3).withRate(1, Duration.millis(300))) + .apply("Window first", window) + .apply("Map first", MapElements.via(mapFn)) + .apply( + "Write first", + AvroIO.write(GenericClass.class) + .to(tmpFolder.getRoot().getAbsolutePath() + "/first") + .withNumShards(2) + .withWindowedWrites()); + readPipeline + .apply( + "Sequence second", GenerateSequence.from(3).to(7).withRate(1, Duration.millis(300))) + .apply("Window second", window) + .apply("Map second", MapElements.via(mapFn)) + .apply( + "Write second", + AvroIO.write(GenericClass.class) + .to(tmpFolder.getRoot().getAbsolutePath() + "/second") + .withNumShards(3) + .withWindowedWrites()); + + // Test read(), readFiles(), readAll(), parse(), parseFilesGenericRecords() and + // parseAllGenericRecords() with watchForNewFiles(). + PAssert.that( + readPipeline.apply( + "Read", + AvroIO.read(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .from(tmpFolder.getRoot().getAbsolutePath() + "/first*") + .watchForNewFiles( + Duration.millis(100), + Watch.Growth.afterTimeSinceNewOutput(Duration.standardSeconds(3))))) + .containsInAnyOrder(firstValues); + PAssert.that( + readPipeline.apply( + "Parse", + AvroIO.parseGenericRecords(new ParseGenericClass()) + .from(tmpFolder.getRoot().getAbsolutePath() + "/first*") + .watchForNewFiles( + Duration.millis(100), + Watch.Growth.afterTimeSinceNewOutput(Duration.standardSeconds(3))))) + .containsInAnyOrder(firstValues); + + PCollection paths = + readPipeline.apply( + "Create paths", + Create.of( + tmpFolder.getRoot().getAbsolutePath() + "/first*", + tmpFolder.getRoot().getAbsolutePath() + "/second*")); + PAssert.that( + paths + .apply( + "Match All Read files", + FileIO.matchAll() + .continuously( + Duration.millis(100), + Watch.Growth.afterTimeSinceNewOutput(Duration.standardSeconds(3)))) + .apply( + "Read Matches Read files", + FileIO.readMatches() + .withDirectoryTreatment(FileIO.ReadMatches.DirectoryTreatment.PROHIBIT)) + .apply( + "Read files", + AvroIO.readFiles(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + PAssert.that( + paths.apply( + "Read all", + AvroIO.readAll(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .watchForNewFiles( + Duration.millis(100), + Watch.Growth.afterTimeSinceNewOutput(Duration.standardSeconds(3))) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + PAssert.that( + paths + .apply( + "Match All ParseFilesGenericRecords", + FileIO.matchAll() + .continuously( + Duration.millis(100), + Watch.Growth.afterTimeSinceNewOutput(Duration.standardSeconds(3)))) + .apply( + "Match Matches ParseFilesGenericRecords", + FileIO.readMatches() + .withDirectoryTreatment(FileIO.ReadMatches.DirectoryTreatment.PROHIBIT)) + .apply( + "ParseFilesGenericRecords", + AvroIO.parseFilesGenericRecords(new ParseGenericClass()) + .withCoder(AvroCoder.of(GenericClass.class)) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + PAssert.that( + paths.apply( + "ParseAllGenericRecords", + AvroIO.parseAllGenericRecords(new ParseGenericClass()) + .withCoder(AvroCoder.of(GenericClass.class)) + .watchForNewFiles( + 
Duration.millis(100), + Watch.Growth.afterTimeSinceNewOutput(Duration.standardSeconds(3))) + .withDesiredBundleSizeBytes(10))) + .containsInAnyOrder(Iterables.concat(firstValues, secondValues)); + readPipeline.run(); + } + + @Test + @SuppressWarnings("unchecked") + @Category(NeedsRunner.class) + public void testCompressedWriteAndReadASingleFile() throws Throwable { + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + writePipeline + .apply(Create.of(values)) + .apply( + AvroIO.write(GenericClass.class) + .to(outputFile.getAbsolutePath()) + .withoutSharding() + .withCodec(CodecFactory.deflateCodec(9))); + writePipeline.run(); + + PAssert.that( + readPipeline.apply( + AvroIO.read(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .from(outputFile.getAbsolutePath()))) + .containsInAnyOrder(values); + readPipeline.run(); + + try (DataFileStream dataFileStream = + new DataFileStream(new FileInputStream(outputFile), new GenericDatumReader())) { + assertEquals("deflate", dataFileStream.getMetaString("avro.codec")); + } + } + + @Test + @SuppressWarnings("unchecked") + @Category(NeedsRunner.class) + public void testWriteThenReadASingleFileWithNullCodec() throws Throwable { + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + writePipeline + .apply(Create.of(values)) + .apply( + AvroIO.write(GenericClass.class) + .to(outputFile.getAbsolutePath()) + .withoutSharding() + .withCodec(CodecFactory.nullCodec())); + writePipeline.run(); + + PAssert.that( + readPipeline.apply( + AvroIO.read(GenericClass.class) + .withBeamSchemas(withBeamSchemas) + .from(outputFile.getAbsolutePath()))) + .containsInAnyOrder(values); + readPipeline.run(); + + try (DataFileStream dataFileStream = + new DataFileStream(new FileInputStream(outputFile), new GenericDatumReader())) { + assertEquals("null", dataFileStream.getMetaString("avro.codec")); + } + } + + @DefaultCoder(AvroCoder.class) + static class GenericClassV2 { + int intField; + String stringField; + @org.apache.avro.reflect.Nullable String nullableField; + + GenericClassV2() {} + + GenericClassV2(int intValue, String stringValue, String nullableValue) { + this.intField = intValue; + this.stringField = stringValue; + this.nullableField = nullableValue; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(getClass()) + .add("intField", intField) + .add("stringField", stringField) + .add("nullableField", nullableField) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hash(intField, stringField, nullableField); + } + + @Override + public boolean equals(@Nullable Object other) { + if (!(other instanceof GenericClassV2)) { + return false; + } + GenericClassV2 o = (GenericClassV2) other; + return intField == o.intField + && Objects.equals(stringField, o.stringField) + && Objects.equals(nullableField, o.nullableField); + } + } + + /** + * Tests that {@code AvroIO} can read an upgraded version of an old class, as long as the schema + * resolution process succeeds. This test covers the case when a new, {@code @Nullable} field + * has been added. + * + *
<p>
For more information, see http://avro.apache.org/docs/1.7.7/spec.html#Schema+Resolution + */ + @Test + @Category(NeedsRunner.class) + public void testWriteThenReadSchemaUpgrade() throws Throwable { + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + writePipeline + .apply(Create.of(values)) + .apply( + AvroIO.write(GenericClass.class).to(outputFile.getAbsolutePath()).withoutSharding()); + writePipeline.run(); + + List expected = + ImmutableList.of(new GenericClassV2(3, "hi", null), new GenericClassV2(5, "bar", null)); + + PAssert.that( + readPipeline.apply( + AvroIO.read(GenericClassV2.class) + .withBeamSchemas(withBeamSchemas) + .from(outputFile.getAbsolutePath()))) + .containsInAnyOrder(expected); + readPipeline.run(); + } + + private static class WindowedFilenamePolicy extends FilenamePolicy { + final ResourceId outputFilePrefix; + + WindowedFilenamePolicy(ResourceId outputFilePrefix) { + this.outputFilePrefix = outputFilePrefix; + } + + @Override + public ResourceId windowedFilename( + int shardNumber, + int numShards, + BoundedWindow window, + PaneInfo paneInfo, + OutputFileHints outputFileHints) { + String filenamePrefix = + outputFilePrefix.isDirectory() ? "" : firstNonNull(outputFilePrefix.getFilename(), ""); + + IntervalWindow interval = (IntervalWindow) window; + String windowStr = + String.format("%s-%s", interval.start().toString(), interval.end().toString()); + String filename = + String.format( + "%s-%s-%s-of-%s-pane-%s%s%s.avro", + filenamePrefix, + windowStr, + shardNumber, + numShards, + paneInfo.getIndex(), + paneInfo.isLast() ? "-last" : "", + outputFileHints.getSuggestedFilenameSuffix()); + return outputFilePrefix.getCurrentDirectory().resolve(filename, RESOLVE_FILE); + } + + @Override + public ResourceId unwindowedFilename( + int shardNumber, int numShards, OutputFileHints outputFileHints) { + throw new UnsupportedOperationException("Expecting windowed outputs only"); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add( + DisplayData.item("fileNamePrefix", outputFilePrefix.toString()) + .withLabel("File Name Prefix")); + } + } + + @Test + @Category({NeedsRunner.class, UsesTestStream.class}) + public void testWriteWindowed() throws Throwable { + testWindowedAvroIOWriteUsingMethod(WriteMethod.AVROIO_WRITE); + } + + @Test + @Category({NeedsRunner.class, UsesTestStream.class}) + public void testWindowedAvroIOWriteViaSink() throws Throwable { + testWindowedAvroIOWriteUsingMethod(WriteMethod.AVROIO_SINK_WITH_CLASS); + } + + void testWindowedAvroIOWriteUsingMethod(WriteMethod method) throws IOException { + Path baseDir = Files.createTempDirectory(tmpFolder.getRoot().toPath(), "testwrite"); + final String baseFilename = baseDir.resolve("prefix").toString(); + + Instant base = new Instant(0); + ArrayList allElements = new ArrayList<>(); + ArrayList> firstWindowElements = new ArrayList<>(); + ArrayList firstWindowTimestamps = + Lists.newArrayList( + base.plus(Duration.ZERO), base.plus(Duration.standardSeconds(10)), + base.plus(Duration.standardSeconds(20)), base.plus(Duration.standardSeconds(30))); + + Random random = new Random(); + for (int i = 0; i < 100; ++i) { + GenericClass item = new GenericClass(i, String.valueOf(i)); + allElements.add(item); + firstWindowElements.add( + TimestampedValue.of( + item, firstWindowTimestamps.get(random.nextInt(firstWindowTimestamps.size())))); + } + + ArrayList> secondWindowElements = new 
ArrayList<>(); + ArrayList secondWindowTimestamps = + Lists.newArrayList( + base.plus(Duration.standardSeconds(60)), base.plus(Duration.standardSeconds(70)), + base.plus(Duration.standardSeconds(80)), base.plus(Duration.standardSeconds(90))); + for (int i = 100; i < 200; ++i) { + GenericClass item = new GenericClass(i, String.valueOf(i)); + allElements.add(new GenericClass(i, String.valueOf(i))); + secondWindowElements.add( + TimestampedValue.of( + item, secondWindowTimestamps.get(random.nextInt(secondWindowTimestamps.size())))); + } + + TimestampedValue[] firstWindowArray = + firstWindowElements.toArray(new TimestampedValue[100]); + TimestampedValue[] secondWindowArray = + secondWindowElements.toArray(new TimestampedValue[100]); + + TestStream values = + TestStream.create(AvroCoder.of(GenericClass.class)) + .advanceWatermarkTo(new Instant(0)) + .addElements( + firstWindowArray[0], + Arrays.copyOfRange(firstWindowArray, 1, firstWindowArray.length)) + .advanceWatermarkTo(new Instant(0).plus(Duration.standardMinutes(1))) + .addElements( + secondWindowArray[0], + Arrays.copyOfRange(secondWindowArray, 1, secondWindowArray.length)) + .advanceWatermarkToInfinity(); + + final PTransform, WriteFilesResult> write; + switch (method) { + case AVROIO_WRITE: + { + FilenamePolicy policy = + new WindowedFilenamePolicy( + FileBasedSink.convertToFileResourceIfPossible(baseFilename)); + write = + AvroIO.write(GenericClass.class) + .to(policy) + .withTempDirectory( + StaticValueProvider.of( + FileSystems.matchNewResource(baseDir.toString(), true))) + .withWindowedWrites() + .withNumShards(2) + .withOutputFilenames(); + break; + } + + case AVROIO_SINK_WITH_CLASS: + { + write = + FileIO.write() + .via(AvroIO.sink(GenericClass.class)) + .to(baseDir.toString()) + .withPrefix("prefix") + .withSuffix(".avro") + .withTempDirectory(baseDir.toString()) + .withNumShards(2); + break; + } + + default: + throw new UnsupportedOperationException(); + } + windowedAvroWritePipeline + .apply(values) + .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1)))) + .apply(write); + windowedAvroWritePipeline.run(); + + // Validate that the data written matches the expected elements in the expected order + List expectedFiles = new ArrayList<>(); + for (int shard = 0; shard < 2; shard++) { + for (int window = 0; window < 2; window++) { + Instant windowStart = new Instant(0).plus(Duration.standardMinutes(window)); + IntervalWindow iw = new IntervalWindow(windowStart, Duration.standardMinutes(1)); + String baseAndWindow = baseFilename + "-" + iw.start() + "-" + iw.end(); + switch (method) { + case AVROIO_WRITE: + expectedFiles.add(new File(baseAndWindow + "-" + shard + "-of-2-pane-0-last.avro")); + break; + case AVROIO_SINK_WITH_CLASS: + expectedFiles.add(new File(baseAndWindow + "-0000" + shard + "-of-00002.avro")); + break; + default: + throw new UnsupportedOperationException("Unknown write method " + method); + } + } + } + + List actualElements = new ArrayList<>(); + for (File outputFile : expectedFiles) { + assertTrue("Expected output file " + outputFile.getAbsolutePath(), outputFile.exists()); + try (DataFileReader reader = + new DataFileReader<>( + outputFile, + new ReflectDatumReader<>(ReflectData.get().getSchema(GenericClass.class)))) { + Iterators.addAll(actualElements, reader); + } + outputFile.delete(); + } + assertThat(actualElements, containsInAnyOrder(allElements.toArray())); + } + + private static final String SCHEMA_TEMPLATE_STRING = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " 
\"name\": \"$$TestTemplateSchema\",\n" + + " \"fields\": [\n" + + " {\"name\": \"$$full\", \"type\": \"string\"},\n" + + " {\"name\": \"$$suffix\", \"type\": [\"string\", \"null\"]}\n" + + " ]\n" + + "}"; + + private static String schemaFromPrefix(String prefix) { + return SCHEMA_TEMPLATE_STRING.replace("$$", prefix); + } + + private static GenericRecord createRecord(String record, String prefix, Schema schema) { + GenericRecord genericRecord = new GenericData.Record(schema); + genericRecord.put(prefix + "full", record); + genericRecord.put(prefix + "suffix", record.substring(1)); + return genericRecord; + } + + private static class TestDynamicDestinations + extends DynamicAvroDestinations { + final ResourceId baseDir; + final PCollectionView> schemaView; + + TestDynamicDestinations(ResourceId baseDir, PCollectionView> schemaView) { + this.baseDir = baseDir; + this.schemaView = schemaView; + } + + @Override + public Schema getSchema(String destination) { + // Return a per-destination schema. + String schema = sideInput(schemaView).get(destination); + return new Schema.Parser().parse(schema); + } + + @Override + public List> getSideInputs() { + return ImmutableList.of(schemaView); + } + + @Override + public GenericRecord formatRecord(String record) { + String prefix = record.substring(0, 1); + return createRecord(record, prefix, getSchema(prefix)); + } + + @Override + public String getDestination(String element) { + // Destination is based on first character of string. + return element.substring(0, 1); + } + + @Override + public String getDefaultDestination() { + return ""; + } + + @Override + public FilenamePolicy getFilenamePolicy(String destination) { + return DefaultFilenamePolicy.fromStandardParameters( + StaticValueProvider.of(baseDir.resolve("file_" + destination, RESOLVE_FILE)), + "-SSSSS-of-NNNNN", + ".avro", + false); + } + } + + /** + * Example of a {@link Coder} for a collection of Avro records with different schemas. + * + *
<p>
All the schemas are known at pipeline construction, and are keyed internally on the prefix + * character (lower byte only for UTF-8 data). + */ + private static class AvroMultiplexCoder extends Coder { + + /** Lookup table for the possible schemas, keyed on the prefix character. */ + private final Map> coderMap = Maps.newHashMap(); + + protected AvroMultiplexCoder(Map schemaMap) { + for (Map.Entry entry : schemaMap.entrySet()) { + coderMap.put( + entry.getKey().charAt(0), AvroCoder.of(new Schema.Parser().parse(entry.getValue()))); + } + } + + @Override + public void encode(GenericRecord value, OutputStream outStream) throws IOException { + char prefix = value.getSchema().getName().charAt(0); + outStream.write(prefix); // Only reads and writes the low byte. + coderMap.get(prefix).encode(value, outStream); + } + + @Override + public GenericRecord decode(InputStream inStream) throws CoderException, IOException { + char prefix = (char) inStream.read(); + return coderMap.get(prefix).decode(inStream); + } + + @Override + public List> getCoderArguments() { + return Collections.emptyList(); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + for (AvroCoder internalCoder : coderMap.values()) { + internalCoder.verifyDeterministic(); + } + } + } + + private void testDynamicDestinationsUnwindowedWithSharding( + WriteMethod writeMethod, Sharding sharding) throws Exception { + final ResourceId baseDir = + FileSystems.matchNewResource( + Files.createTempDirectory(tmpFolder.getRoot().toPath(), "testDynamicDestinations") + .toString(), + true); + + List elements = Lists.newArrayList("aaaa", "aaab", "baaa", "baab", "caaa", "caab"); + Multimap expectedElements = ArrayListMultimap.create(); + Map schemaMap = Maps.newHashMap(); + for (String element : elements) { + String prefix = element.substring(0, 1); + String jsonSchema = schemaFromPrefix(prefix); + schemaMap.put(prefix, jsonSchema); + expectedElements.put( + prefix, createRecord(element, prefix, new Schema.Parser().parse(jsonSchema))); + } + final PCollectionView> schemaView = + writePipeline.apply("createSchemaView", Create.of(schemaMap)).apply(View.asMap()); + + PCollection input = + writePipeline.apply("createInput", Create.of(elements).withCoder(StringUtf8Coder.of())); + + switch (writeMethod) { + case AVROIO_WRITE: + { + AvroIO.TypedWrite write = + AvroIO.writeCustomTypeToGenericRecords() + .to(new TestDynamicDestinations(baseDir, schemaView)) + .withTempDirectory(baseDir); + + switch (sharding) { + case RUNNER_DETERMINED: + break; + case WITHOUT_SHARDING: + write = write.withoutSharding(); + break; + case FIXED_3_SHARDS: + write = write.withNumShards(3); + break; + default: + throw new IllegalArgumentException("Unknown sharding " + sharding); + } + + input.apply(write); + break; + } + + case AVROIO_SINK_WITH_SCHEMA: + { + FileIO.Write write = + FileIO.writeDynamic() + .by( + fn( + (element, c) -> { + c.sideInput(schemaView); // Ignore result + return element.getSchema().getName().substring(0, 1); + }, + requiresSideInputs(schemaView))) + .via( + fn( + (dest, c) -> { + Schema schema = + new Schema.Parser().parse(c.sideInput(schemaView).get(dest)); + return AvroIO.sink(schema); + }, + requiresSideInputs(schemaView))) + .to(baseDir.toString()) + .withNaming( + fn( + (dest, c) -> { + c.sideInput(schemaView); // Ignore result + return FileIO.Write.defaultNaming("file_" + dest, ".avro"); + }, + requiresSideInputs(schemaView))) + .withTempDirectory(baseDir.toString()) + .withDestinationCoder(StringUtf8Coder.of()) + 
.withIgnoreWindowing(); + switch (sharding) { + case RUNNER_DETERMINED: + break; + case WITHOUT_SHARDING: + write = write.withNumShards(1); + break; + case FIXED_3_SHARDS: + write = write.withNumShards(3); + break; + default: + throw new IllegalArgumentException("Unknown sharding " + sharding); + } + + MapElements toRecord = + MapElements.via( + new SimpleFunction() { + @Override + public GenericRecord apply(String element) { + String prefix = element.substring(0, 1); + GenericRecord record = + new GenericData.Record( + new Schema.Parser().parse(schemaFromPrefix(prefix))); + record.put(prefix + "full", element); + record.put(prefix + "suffix", element.substring(1)); + return record; + } + }); + + input.apply(toRecord).setCoder(new AvroMultiplexCoder(schemaMap)).apply(write); + break; + } + + case AVROIO_SINK_WITH_FORMATTER: + { + final AvroIO.RecordFormatter formatter = + (element, schema) -> { + String prefix = element.substring(0, 1); + GenericRecord record = new GenericData.Record(schema); + record.put(prefix + "full", element); + record.put(prefix + "suffix", element.substring(1)); + return record; + }; + FileIO.Write write = + FileIO.writeDynamic() + .by( + fn( + (element, c) -> { + c.sideInput(schemaView); // Ignore result + return element.substring(0, 1); + }, + requiresSideInputs(schemaView))) + .via( + fn( + (dest, c) -> { + Schema schema = + new Schema.Parser().parse(c.sideInput(schemaView).get(dest)); + return AvroIO.sinkViaGenericRecords(schema, formatter); + }, + requiresSideInputs(schemaView))) + .to(baseDir.toString()) + .withNaming( + fn( + (dest, c) -> { + c.sideInput(schemaView); // Ignore result + return FileIO.Write.defaultNaming("file_" + dest, ".avro"); + }, + requiresSideInputs(schemaView))) + .withTempDirectory(baseDir.toString()) + .withDestinationCoder(StringUtf8Coder.of()) + .withIgnoreWindowing(); + switch (sharding) { + case RUNNER_DETERMINED: + break; + case WITHOUT_SHARDING: + write = write.withNumShards(1); + break; + case FIXED_3_SHARDS: + write = write.withNumShards(3); + break; + default: + throw new IllegalArgumentException("Unknown sharding " + sharding); + } + + input.apply(write); + break; + } + default: + throw new UnsupportedOperationException("Unknown write method " + writeMethod); + } + + writePipeline.run(); + + // Validate that the data written matches the expected elements in the expected order. 
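+ // Each destination prefix was written under its own "file_<prefix>" pattern with a
+ // per-prefix schema, so read each filepattern back separately and compare it against the
+ // records expected for that prefix.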
+ + for (String prefix : expectedElements.keySet()) { + String shardPattern; + switch (sharding) { + case RUNNER_DETERMINED: + shardPattern = "-*"; + break; + case WITHOUT_SHARDING: + shardPattern = "-00000-of-00001"; + break; + case FIXED_3_SHARDS: + shardPattern = "-*-of-00003"; + break; + default: + throw new IllegalArgumentException("Unknown sharding " + sharding); + } + String expectedFilepattern = + baseDir.resolve("file_" + prefix + shardPattern + ".avro", RESOLVE_FILE).toString(); + + PCollection records = + readPipeline.apply( + "read_" + prefix, + AvroIO.readGenericRecords(schemaFromPrefix(prefix)) + .withBeamSchemas(withBeamSchemas) + .from(expectedFilepattern)); + PAssert.that(records).containsInAnyOrder(expectedElements.get(prefix)); + } + readPipeline.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsRunnerDeterminedSharding() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_WRITE, Sharding.RUNNER_DETERMINED); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsWithoutSharding() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_WRITE, Sharding.WITHOUT_SHARDING); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsWithNumShards() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_WRITE, Sharding.FIXED_3_SHARDS); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsViaSinkRunnerDeterminedSharding() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_SINK_WITH_SCHEMA, Sharding.RUNNER_DETERMINED); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsViaSinkWithoutSharding() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_SINK_WITH_SCHEMA, Sharding.WITHOUT_SHARDING); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsViaSinkWithNumShards() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_SINK_WITH_SCHEMA, Sharding.FIXED_3_SHARDS); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsViaSinkWithFormatterRunnerDeterminedSharding() + throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_SINK_WITH_FORMATTER, Sharding.RUNNER_DETERMINED); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsViaSinkWithFormatterWithoutSharding() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_SINK_WITH_FORMATTER, Sharding.WITHOUT_SHARDING); + } + + @Test + @Category(NeedsRunner.class) + public void testDynamicDestinationsViaSinkWithFormatterWithNumShards() throws Exception { + testDynamicDestinationsUnwindowedWithSharding( + WriteMethod.AVROIO_SINK_WITH_FORMATTER, Sharding.FIXED_3_SHARDS); + } + + @Test + @SuppressWarnings("unchecked") + @Category(NeedsRunner.class) + public void testMetadata() throws Exception { + List values = + ImmutableList.of(new GenericClass(3, "hi"), new GenericClass(5, "bar")); + File outputFile = tmpFolder.newFile("output.avro"); + + writePipeline + .apply(Create.of(values)) + .apply( + AvroIO.write(GenericClass.class) + .to(outputFile.getAbsolutePath()) + .withoutSharding() + .withMetadata( + ImmutableMap.of( + "stringKey", + "stringValue", + "longKey", + 100L, + "bytesKey", + "bytesValue".getBytes(Charsets.UTF_8)))); + 
writePipeline.run(); + + try (DataFileStream dataFileStream = + new DataFileStream(new FileInputStream(outputFile), new GenericDatumReader())) { + assertEquals("stringValue", dataFileStream.getMetaString("stringKey")); + assertEquals(100L, dataFileStream.getMetaLong("longKey")); + assertArrayEquals( + "bytesValue".getBytes(Charsets.UTF_8), dataFileStream.getMeta("bytesKey")); + } + } + + // using AvroCoder#createDatumReader for tests. + private void runTestWrite(String[] expectedElements, int numShards) throws IOException { + File baseOutputFile = new File(tmpFolder.getRoot(), "prefix"); + String outputFilePrefix = baseOutputFile.getAbsolutePath(); + + AvroIO.Write write = + AvroIO.write(String.class).to(outputFilePrefix).withSuffix(".avro"); + if (numShards > 1) { + write = write.withNumShards(numShards); + } else { + write = write.withoutSharding(); + } + writePipeline.apply(Create.of(ImmutableList.copyOf(expectedElements))).apply(write); + writePipeline.run(); + + String shardNameTemplate = + firstNonNull( + write.inner.getShardTemplate(), + DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE); + + assertTestOutputs(expectedElements, numShards, outputFilePrefix, shardNameTemplate); + } + + static void assertTestOutputs( + String[] expectedElements, int numShards, String outputFilePrefix, String shardNameTemplate) + throws IOException { + // Validate that the data written matches the expected elements in the expected order + List expectedFiles = new ArrayList<>(); + for (int i = 0; i < numShards; i++) { + expectedFiles.add( + new File( + DefaultFilenamePolicy.constructName( + FileBasedSink.convertToFileResourceIfPossible(outputFilePrefix), + shardNameTemplate, + ".avro", + i, + numShards, + null, + null) + .toString())); + } + + List actualElements = new ArrayList<>(); + for (File outputFile : expectedFiles) { + assertTrue("Expected output file " + outputFile.getName(), outputFile.exists()); + try (DataFileReader reader = + new DataFileReader<>( + outputFile, new ReflectDatumReader(ReflectData.get().getSchema(String.class)))) { + Iterators.addAll(actualElements, reader); + } + } + assertThat(actualElements, containsInAnyOrder(expectedElements)); + } + + @Test + @Category(NeedsRunner.class) + public void testAvroSinkWrite() throws Exception { + String[] expectedElements = new String[] {"first", "second", "third"}; + + runTestWrite(expectedElements, 1); + } + + @Test + @Category(NeedsRunner.class) + public void testAvroSinkShardedWrite() throws Exception { + String[] expectedElements = new String[] {"first", "second", "third", "fourth", "fifth"}; + + runTestWrite(expectedElements, 4); + } + + @Test + @Category(NeedsRunner.class) + public void testAvroSinkWriteWithCustomFactory() throws Exception { + Integer[] expectedElements = new Integer[] {1, 2, 3, 4, 5}; + + File baseOutputFile = new File(tmpFolder.getRoot(), "prefix"); + String outputFilePrefix = baseOutputFile.getAbsolutePath(); + + Schema recordSchema = SchemaBuilder.record("root").fields().requiredInt("i1").endRecord(); + + AvroIO.TypedWrite write = + AvroIO.writeCustomType() + .to(outputFilePrefix) + .withSchema(recordSchema) + .withFormatFunction(f -> f) + .withDatumWriterFactory( + f -> + new DatumWriter() { + private DatumWriter inner = new GenericDatumWriter<>(f); + + @Override + public void setSchema(Schema schema) { + inner.setSchema(schema); + } + + @Override + public void write(Integer datum, Encoder out) throws IOException { + GenericRecord record = + new GenericRecordBuilder(f).set("i1", datum).build(); + 
inner.write(record, out); + } + }) + .withSuffix(".avro"); + + write = write.withoutSharding(); + + writePipeline.apply(Create.of(ImmutableList.copyOf(expectedElements))).apply(write); + writePipeline.run(); + + File expectedFile = + new File( + DefaultFilenamePolicy.constructName( + FileBasedSink.convertToFileResourceIfPossible(outputFilePrefix), + "", + ".avro", + 1, + 1, + null, + null) + .toString()); + + assertTrue("Expected output file " + expectedFile.getName(), expectedFile.exists()); + DataFileReader dataFileReader = + new DataFileReader<>(expectedFile, new GenericDatumReader<>(recordSchema)); + + List actualRecords = new ArrayList<>(); + Iterators.addAll(actualRecords, dataFileReader); + + GenericRecord[] expectedRecords = + Arrays.stream(expectedElements) + .map(i -> new GenericRecordBuilder(recordSchema).set("i1", i).build()) + .toArray(GenericRecord[]::new); + + assertThat(actualRecords, containsInAnyOrder(expectedRecords)); + } + + // TODO: for Write only, test withSuffix, + // withShardNameTemplate and withoutSharding. + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProviderTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProviderTest.java new file mode 100644 index 000000000000..b003597200eb --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSchemaIOProviderTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.io; + +import static org.junit.Assert.assertEquals; + +import java.io.File; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.io.SchemaIO; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test for AvroSchemaIOProvider. 
*/ +@RunWith(JUnit4.class) +public class AvroSchemaIOProviderTest { + @Rule public TestPipeline writePipeline = TestPipeline.create(); + @Rule public TestPipeline readPipeline = TestPipeline.create(); + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final Schema SCHEMA = + Schema.builder().addInt64Field("age").addStringField("age_str").build(); + + private Row createRow(long l) { + return Row.withSchema(SCHEMA).addValues(l, Long.valueOf(l).toString()).build(); + } + + @Test + @Category({NeedsRunner.class}) + public void testWriteAndReadTable() { + File destinationFile = new File(tempFolder.getRoot(), "person-info.avro"); + + AvroSchemaIOProvider provider = new AvroSchemaIOProvider(); + Row configuration = Row.withSchema(provider.configurationSchema()).addValue(null).build(); + SchemaIO io = provider.from(destinationFile.getAbsolutePath(), configuration, SCHEMA); + + List rowsList = Arrays.asList(createRow(1L), createRow(3L), createRow(4L)); + PCollection rows = + writePipeline.apply("Create", Create.of(rowsList).withCoder(RowCoder.of(SCHEMA))); + rows.apply(io.buildWriter()); + writePipeline.run(); + + PCollection read = readPipeline.begin().apply(io.buildReader()); + PAssert.that(read).containsInAnyOrder(rowsList); + readPipeline.run(); + } + + @Test + @Category({NeedsRunner.class}) + public void testStreamingWriteDefault() throws Exception { + File destinationFile = new File(tempFolder.getRoot(), "person-info"); + + AvroSchemaIOProvider provider = new AvroSchemaIOProvider(); + Row config = Row.withSchema(provider.configurationSchema()).addValue(null).build(); + SchemaIO writeIO = provider.from(destinationFile.getAbsolutePath(), config, SCHEMA); + + TestStream createEvents = + TestStream.create(RowCoder.of(SCHEMA)) + .addElements(TimestampedValue.of(createRow(1L), new Instant(1L))) + .addElements(TimestampedValue.of(createRow(2L), Instant.ofEpochSecond(120L))) + .advanceWatermarkToInfinity(); + + writePipeline.apply("create", createEvents).apply("write", writeIO.buildWriter()); + writePipeline.run(); + + // Verify we wrote two files. + String wildcardPath = destinationFile.getAbsolutePath() + "*"; + MatchResult result = FileSystems.match(wildcardPath); + assertEquals(2, result.metadata().size()); + + // Verify results of the files. + SchemaIO readIO = provider.from(wildcardPath, config, SCHEMA); + PCollection read = readPipeline.begin().apply("read", readIO.buildReader()); + PAssert.that(read).containsInAnyOrder(createRow(1L), createRow(2L)); + readPipeline.run(); + } + + @Test + @Category({NeedsRunner.class}) + public void testStreamingCustomWindowSize() throws Exception { + File destinationFile = new File(tempFolder.getRoot(), "person-info"); + + AvroSchemaIOProvider provider = new AvroSchemaIOProvider(); + Row config = + Row.withSchema(provider.configurationSchema()) + .addValue(Duration.ofMinutes(4).getSeconds()) + .build(); + SchemaIO writeIO = provider.from(destinationFile.getAbsolutePath(), config, SCHEMA); + + TestStream createEvents = + TestStream.create(RowCoder.of(SCHEMA)) + .addElements(TimestampedValue.of(createRow(1L), new Instant(1L))) + .addElements(TimestampedValue.of(createRow(2L), Instant.ofEpochSecond(120L))) + .advanceWatermarkToInfinity(); + + writePipeline.apply("create", createEvents).apply("write", writeIO.buildWriter()); + writePipeline.run(); + + // Verify we wrote one file. 
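+ // (The configured window size is 4 minutes and both element timestamps, roughly 0s and
+ // 120s, fall inside the first window, so a single output file is expected.)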
+ String wildcardPath = destinationFile.getAbsolutePath() + "*"; + MatchResult result = FileSystems.match(wildcardPath); + assertEquals(1, result.metadata().size()); + + // Verify results of the files. + SchemaIO readIO = provider.from(wildcardPath, config, SCHEMA); + PCollection read = readPipeline.begin().apply("read", readIO.buildReader()); + PAssert.that(read).containsInAnyOrder(createRow(1L), createRow(2L)); + readPipeline.run(); + } + + @Test + @Category({NeedsRunner.class}) + public void testBatchCustomWindowSize() throws Exception { + File destinationFile = new File(tempFolder.getRoot(), "person-info"); + + AvroSchemaIOProvider provider = new AvroSchemaIOProvider(); + Row config = + Row.withSchema(provider.configurationSchema()) + .addValue(Duration.ofMinutes(4).getSeconds()) + .build(); + SchemaIO writeIO = provider.from(destinationFile.getAbsolutePath(), config, SCHEMA); + + List rowsList = Arrays.asList(createRow(1L), createRow(3L), createRow(4L)); + PCollection rows = + writePipeline.apply("Create", Create.of(rowsList).withCoder(RowCoder.of(SCHEMA))); + + rows.apply("write", writeIO.buildWriter()); + writePipeline.run(); + + // Verify we wrote one file. + String wildcardPath = destinationFile.getAbsolutePath() + "*"; + MatchResult result = FileSystems.match(wildcardPath); + assertEquals(1, result.metadata().size()); + + // Verify results of the files. + SchemaIO readIO = provider.from(wildcardPath, config, SCHEMA); + PCollection read = readPipeline.begin().apply("read", readIO.buildReader()); + PAssert.that(read).containsInAnyOrder(rowsList); + readPipeline.run(); + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSourceTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSourceTest.java new file mode 100644 index 000000000000..df382d86f215 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/AvroSourceTest.java @@ -0,0 +1,846 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.io; + +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Random; +import java.util.stream.Collectors; +import org.apache.avro.Schema; +import org.apache.avro.file.CodecFactory; +import org.apache.avro.file.DataFileConstants; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.DatumWriter; +import org.apache.avro.io.Decoder; +import org.apache.avro.reflect.AvroDefault; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumWriter; +import org.apache.beam.sdk.coders.DefaultCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.io.BlockBasedSource; +import org.apache.beam.sdk.io.BlockBasedSource.BlockBasedReader; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.io.FileBasedSource; +import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.fs.MatchResult.Metadata; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.SourceTestUtils; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for AvroSource. */ +@RunWith(JUnit4.class) +public class AvroSourceTest { + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Rule public ExpectedException expectedException = ExpectedException.none(); + + private enum SyncBehavior { + SYNC_REGULAR, // Sync at regular, user defined intervals + SYNC_RANDOM, // Sync at random intervals + SYNC_DEFAULT // Sync at default intervals (i.e., no manual syncing). + } + + private static final int DEFAULT_RECORD_COUNT = 1000; + + /** + * Generates an input Avro file containing the given records in the temporary directory and + * returns the full path of the file. + */ + private String generateTestFile( + String filename, + List elems, + SyncBehavior syncBehavior, + int syncInterval, + AvroCoder coder, + String codec) + throws IOException { + Random random = new Random(0); + File tmpFile = tmpFolder.newFile(filename); + String path = tmpFile.toString(); + + FileOutputStream os = new FileOutputStream(tmpFile); + DatumWriter datumWriter = + coder.getType().equals(GenericRecord.class) + ? 
new GenericDatumWriter<>(coder.getSchema()) + : new ReflectDatumWriter<>(coder.getSchema()); + try (DataFileWriter writer = new DataFileWriter<>(datumWriter)) { + writer.setCodec(CodecFactory.fromString(codec)); + writer.create(coder.getSchema(), os); + + int recordIndex = 0; + int syncIndex = syncBehavior == SyncBehavior.SYNC_RANDOM ? random.nextInt(syncInterval) : 0; + + for (T elem : elems) { + writer.append(elem); + recordIndex++; + + switch (syncBehavior) { + case SYNC_REGULAR: + if (recordIndex == syncInterval) { + recordIndex = 0; + writer.sync(); + } + break; + case SYNC_RANDOM: + if (recordIndex == syncIndex) { + recordIndex = 0; + writer.sync(); + syncIndex = random.nextInt(syncInterval); + } + break; + case SYNC_DEFAULT: + default: + } + } + } + return path; + } + + @Test + public void testReadWithDifferentCodecs() throws Exception { + // Test reading files generated using all codecs. + String[] codecs = { + DataFileConstants.NULL_CODEC, + DataFileConstants.BZIP2_CODEC, + DataFileConstants.DEFLATE_CODEC, + DataFileConstants.SNAPPY_CODEC, + DataFileConstants.XZ_CODEC, + }; + // As Avro's default block size is 64KB, write 64K records to ensure at least one full block. + // We could make this smaller than 64KB assuming each record is at least B bytes, but then the + // test could silently stop testing the failure condition from BEAM-422. + List expected = createRandomRecords(1 << 16); + + for (String codec : codecs) { + String filename = + generateTestFile( + codec, expected, SyncBehavior.SYNC_DEFAULT, 0, AvroCoder.of(Bird.class), codec); + AvroSource source = AvroSource.from(filename).withSchema(Bird.class); + List actual = SourceTestUtils.readFromSource(source, null); + assertThat(expected, containsInAnyOrder(actual.toArray())); + } + } + + @Test + public void testSplitAtFraction() throws Exception { + // A reduced dataset is enough here. + List expected = createFixedRecords(DEFAULT_RECORD_COUNT); + // Create an AvroSource where each block is 1/10th of the total set of records. + String filename = + generateTestFile( + "tmp.avro", + expected, + SyncBehavior.SYNC_REGULAR, + DEFAULT_RECORD_COUNT / 10 /* max records per block */, + AvroCoder.of(FixedRecord.class), + DataFileConstants.NULL_CODEC); + File file = new File(filename); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + List> splits = source.split(file.length() / 3, null); + for (BoundedSource subSource : splits) { + int items = SourceTestUtils.readFromSource(subSource, null).size(); + // Shouldn't split while unstarted. 
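+ // With zero records consumed, splitting is expected to fail (checked at fractions 0.0 and
+ // 0.7); once at least one record has been read, a split at 0.7 should succeed, as the
+ // assertions below verify.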
+ SourceTestUtils.assertSplitAtFractionFails(subSource, 0, 0.0, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, 0, 0.7, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent(subSource, 1, 0.7, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent( + subSource, DEFAULT_RECORD_COUNT / 100, 0.7, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent( + subSource, DEFAULT_RECORD_COUNT / 10, 0.1, null); + SourceTestUtils.assertSplitAtFractionFails( + subSource, DEFAULT_RECORD_COUNT / 10 + 1, 0.1, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, DEFAULT_RECORD_COUNT / 3, 0.3, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, items, 0.9, null); + SourceTestUtils.assertSplitAtFractionFails(subSource, items, 1.0, null); + SourceTestUtils.assertSplitAtFractionSucceedsAndConsistent(subSource, items, 0.999, null); + } + } + + @Test + public void testGetProgressFromUnstartedReader() throws Exception { + List records = createFixedRecords(DEFAULT_RECORD_COUNT); + String filename = + generateTestFile( + "tmp.avro", + records, + SyncBehavior.SYNC_DEFAULT, + 1000, + AvroCoder.of(FixedRecord.class), + DataFileConstants.NULL_CODEC); + File file = new File(filename); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BoundedReader reader = source.createReader(null)) { + assertEquals(Double.valueOf(0.0), reader.getFractionConsumed()); + } + + List> splits = source.split(file.length() / 3, null); + for (BoundedSource subSource : splits) { + try (BoundedReader reader = subSource.createReader(null)) { + assertEquals(Double.valueOf(0.0), reader.getFractionConsumed()); + } + } + } + + @Test + public void testProgress() throws Exception { + // 5 records, 2 per block. + List records = createFixedRecords(5); + String filename = + generateTestFile( + "tmp.avro", + records, + SyncBehavior.SYNC_REGULAR, + 2, + AvroCoder.of(FixedRecord.class), + DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BoundedReader readerOrig = source.createReader(null)) { + assertThat(readerOrig, Matchers.instanceOf(BlockBasedReader.class)); + BlockBasedReader reader = (BlockBasedReader) readerOrig; + + // Before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // First 2 records are in the same block. + assertTrue(reader.start()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + // continued + assertTrue(reader.advance()); + assertFalse(reader.isAtSplitPoint()); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Second block -> parallelism consumed becomes 1. + assertTrue(reader.advance()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + // continued + assertTrue(reader.advance()); + assertFalse(reader.isAtSplitPoint()); + assertEquals(1, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // Third and final block -> parallelism consumed becomes 2, remaining becomes 1. 
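+ // (At this point the reader is positioned in the final block, so it reports exactly one
+ // remaining split point, which the assertions below check.)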
+ assertTrue(reader.advance()); + assertTrue(reader.isAtSplitPoint()); + assertEquals(2, reader.getSplitPointsConsumed()); + assertEquals(1, reader.getSplitPointsRemaining()); + + // Done + assertFalse(reader.advance()); + assertEquals(3, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + } + } + + @Test + public void testProgressEmptySource() throws Exception { + // 0 records, 20 per block. + List records = Collections.emptyList(); + String filename = + generateTestFile( + "tmp.avro", + records, + SyncBehavior.SYNC_REGULAR, + 2, + AvroCoder.of(FixedRecord.class), + DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BoundedReader readerOrig = source.createReader(null)) { + assertThat(readerOrig, Matchers.instanceOf(BlockBasedReader.class)); + BlockBasedReader reader = (BlockBasedReader) readerOrig; + + // before starting + assertEquals(0.0, reader.getFractionConsumed(), 1e-6); + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(BoundedReader.SPLIT_POINTS_UNKNOWN, reader.getSplitPointsRemaining()); + + // confirm empty + assertFalse(reader.start()); + + // after reading empty source + assertEquals(0, reader.getSplitPointsConsumed()); + assertEquals(0, reader.getSplitPointsRemaining()); + assertEquals(1.0, reader.getFractionConsumed(), 1e-6); + } + } + + @Test + public void testGetCurrentFromUnstartedReader() throws Exception { + List records = createFixedRecords(DEFAULT_RECORD_COUNT); + String filename = + generateTestFile( + "tmp.avro", + records, + SyncBehavior.SYNC_DEFAULT, + 1000, + AvroCoder.of(FixedRecord.class), + DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + try (BlockBasedSource.BlockBasedReader reader = + (BlockBasedSource.BlockBasedReader) source.createReader(null)) { + assertEquals(null, reader.getCurrentBlock()); + + expectedException.expect(NoSuchElementException.class); + expectedException.expectMessage("No block has been successfully read from"); + reader.getCurrent(); + } + } + + @Test + public void testSplitAtFractionExhaustive() throws Exception { + // A small-sized input is sufficient, because the test verifies that splitting is non-vacuous. + List expected = createFixedRecords(20); + String filename = + generateTestFile( + "tmp.avro", + expected, + SyncBehavior.SYNC_REGULAR, + 5, + AvroCoder.of(FixedRecord.class), + DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FixedRecord.class); + SourceTestUtils.assertSplitAtFractionExhaustive(source, null); + } + + @Test + public void testSplitsWithSmallBlocks() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + // Test reading from an object file with many small random-sized blocks. + // The file itself doesn't have to be big; we can use a decreased record count. 
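+ // (SYNC_RANDOM with a cap of DEFAULT_RECORD_COUNT / 20 records per block produces many
+ // small, randomly sized blocks for the bundle-size splits below to exercise.)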
+ List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + String filename = + generateTestFile( + "tmp.avro", + expected, + SyncBehavior.SYNC_RANDOM, + DEFAULT_RECORD_COUNT / 20 /* max records/block */, + AvroCoder.of(Bird.class), + DataFileConstants.NULL_CODEC); + File file = new File(filename); + + // Small minimum bundle size + AvroSource source = + AvroSource.from(filename).withSchema(Bird.class).withMinBundleSize(100L); + + // Assert that the source produces the expected records + assertEquals(expected, SourceTestUtils.readFromSource(source, options)); + + List> splits; + int nonEmptySplits; + + // Split with the minimum bundle size + splits = source.split(100L, options); + assertTrue(splits.size() > 2); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); + nonEmptySplits = 0; + for (BoundedSource subSource : splits) { + if (SourceTestUtils.readFromSource(subSource, options).size() > 0) { + nonEmptySplits += 1; + } + } + assertTrue(nonEmptySplits > 2); + + // Split with larger bundle size + splits = source.split(file.length() / 4, options); + assertTrue(splits.size() > 2); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); + nonEmptySplits = 0; + for (BoundedSource subSource : splits) { + if (SourceTestUtils.readFromSource(subSource, options).size() > 0) { + nonEmptySplits += 1; + } + } + assertTrue(nonEmptySplits > 2); + + // Split with the file length + splits = source.split(file.length(), options); + assertTrue(splits.size() == 1); + SourceTestUtils.assertSourcesEqualReferenceSource(source, splits, options); + } + + @Test + public void testMultipleFiles() throws Exception { + String baseName = "tmp-"; + List expected = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + List contents = createRandomRecords(DEFAULT_RECORD_COUNT / 10); + expected.addAll(contents); + generateTestFile( + baseName + i, + contents, + SyncBehavior.SYNC_DEFAULT, + 0, + AvroCoder.of(Bird.class), + DataFileConstants.NULL_CODEC); + } + + AvroSource source = + AvroSource.from(new File(tmpFolder.getRoot().toString(), baseName + "*").toString()) + .withSchema(Bird.class); + List actual = SourceTestUtils.readFromSource(source, null); + assertThat(actual, containsInAnyOrder(expected.toArray())); + } + + @Test + public void testCreationWithSchema() throws Exception { + List expected = createRandomRecords(100); + String filename = + generateTestFile( + "tmp.avro", + expected, + SyncBehavior.SYNC_DEFAULT, + 0, + AvroCoder.of(Bird.class), + DataFileConstants.NULL_CODEC); + + // Create a source with a schema object + Schema schema = ReflectData.get().getSchema(Bird.class); + AvroSource source = AvroSource.from(filename).withSchema(schema); + List records = SourceTestUtils.readFromSource(source, null); + assertEqualsWithGeneric(expected, records); + + // Create a source with a JSON schema + String schemaString = ReflectData.get().getSchema(Bird.class).toString(); + source = AvroSource.from(filename).withSchema(schemaString); + records = SourceTestUtils.readFromSource(source, null); + assertEqualsWithGeneric(expected, records); + } + + @Test + public void testSchemaUpdate() throws Exception { + List birds = createRandomRecords(100); + String filename = + generateTestFile( + "tmp.avro", + birds, + SyncBehavior.SYNC_DEFAULT, + 0, + AvroCoder.of(Bird.class), + DataFileConstants.NULL_CODEC); + + AvroSource source = AvroSource.from(filename).withSchema(FancyBird.class); + List actual = SourceTestUtils.readFromSource(source, null); + + List expected = new 
ArrayList<>(); + for (Bird bird : birds) { + expected.add( + new FancyBird( + bird.number, bird.species, bird.quality, bird.quantity, null, "MAXIMUM OVERDRIVE")); + } + + assertThat(actual, containsInAnyOrder(expected.toArray())); + } + + @Test + public void testSchemaStringIsInterned() throws Exception { + List birds = createRandomRecords(100); + String filename = + generateTestFile( + "tmp.avro", + birds, + SyncBehavior.SYNC_DEFAULT, + 0, + AvroCoder.of(Bird.class), + DataFileConstants.NULL_CODEC); + Metadata fileMetadata = FileSystems.matchSingleFileSpec(filename); + String schema = AvroSource.readMetadataFromFile(fileMetadata.resourceId()).getSchemaString(); + // Add "" to the schema to make sure it is not interned. + AvroSource sourceA = AvroSource.from(filename).withSchema("" + schema); + AvroSource sourceB = AvroSource.from(filename).withSchema("" + schema); + assertSame(sourceA.getReaderSchemaString(), sourceB.getReaderSchemaString()); + + // Ensure that deserialization still goes through interning + AvroSource sourceC = SerializableUtils.clone(sourceB); + assertSame(sourceA.getReaderSchemaString(), sourceC.getReaderSchemaString()); + } + + @Test + public void testParseFn() throws Exception { + List expected = createRandomRecords(100); + String filename = + generateTestFile( + "tmp.avro", + expected, + SyncBehavior.SYNC_DEFAULT, + 0, + AvroCoder.of(Bird.class), + DataFileConstants.NULL_CODEC); + + AvroSource source = + AvroSource.from(filename) + .withParseFn( + input -> + new Bird( + (long) input.get("number"), + input.get("species").toString(), + input.get("quality").toString(), + (long) input.get("quantity")), + AvroCoder.of(Bird.class)); + List actual = SourceTestUtils.readFromSource(source, null); + assertThat(actual, containsInAnyOrder(expected.toArray())); + } + + @Test + public void testDatumReaderFactoryWithGenericRecord() throws Exception { + List inputBirds = createRandomRecords(100); + + String filename = + generateTestFile( + "tmp.avro", + inputBirds, + SyncBehavior.SYNC_DEFAULT, + 0, + AvroCoder.of(Bird.class), + DataFileConstants.NULL_CODEC); + + AvroSource.DatumReaderFactory factory = + (writer, reader) -> + new GenericDatumReader(writer, reader) { + @Override + protected Object readString(Object old, Decoder in) throws IOException { + return super.readString(old, in) + "_custom"; + } + }; + + AvroSource source = + AvroSource.from(filename) + .withParseFn( + input -> + new Bird( + (long) input.get("number"), + input.get("species").toString(), + input.get("quality").toString(), + (long) input.get("quantity")), + AvroCoder.of(Bird.class)) + .withDatumReaderFactory(factory); + List actual = SourceTestUtils.readFromSource(source, null); + List expected = + inputBirds.stream() + .map(b -> new Bird(b.number, b.species + "_custom", b.quality + "_custom", b.quantity)) + .collect(Collectors.toList()); + + assertThat(actual, containsInAnyOrder(expected.toArray())); + } + + private void assertEqualsWithGeneric(List expected, List actual) { + assertEquals(expected.size(), actual.size()); + for (int i = 0; i < expected.size(); i++) { + Bird fixed = expected.get(i); + GenericRecord generic = actual.get(i); + assertEquals(fixed.number, generic.get("number")); + assertEquals(fixed.quality, generic.get("quality").toString()); // From Avro util.Utf8 + assertEquals(fixed.quantity, generic.get("quantity")); + assertEquals(fixed.species, generic.get("species").toString()); + } + } + + @Test + public void testDisplayData() { + AvroSource source = + 
AvroSource.from("foobar.txt").withSchema(Bird.class).withMinBundleSize(1234); + + DisplayData displayData = DisplayData.from(source); + assertThat(displayData, hasDisplayItem("filePattern", "foobar.txt")); + assertThat(displayData, hasDisplayItem("minBundleSize", 1234)); + } + + @Test + public void testReadMetadataWithCodecs() throws Exception { + // Test reading files generated using all codecs. + String[] codecs = { + DataFileConstants.NULL_CODEC, + DataFileConstants.BZIP2_CODEC, + DataFileConstants.DEFLATE_CODEC, + DataFileConstants.SNAPPY_CODEC, + DataFileConstants.XZ_CODEC + }; + List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + + for (String codec : codecs) { + String filename = + generateTestFile( + codec, expected, SyncBehavior.SYNC_DEFAULT, 0, AvroCoder.of(Bird.class), codec); + + Metadata fileMeta = FileSystems.matchSingleFileSpec(filename); + AvroSource.AvroMetadata metadata = AvroSource.readMetadataFromFile(fileMeta.resourceId()); + assertEquals(codec, metadata.getCodec()); + } + } + + @Test + public void testReadSchemaString() throws Exception { + List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + String codec = DataFileConstants.NULL_CODEC; + String filename = + generateTestFile( + codec, expected, SyncBehavior.SYNC_DEFAULT, 0, AvroCoder.of(Bird.class), codec); + Metadata fileMeta = FileSystems.matchSingleFileSpec(filename); + AvroSource.AvroMetadata metadata = AvroSource.readMetadataFromFile(fileMeta.resourceId()); + // By default, parse validates the schema, which is what we want. + Schema schema = new Schema.Parser().parse(metadata.getSchemaString()); + assertEquals(4, schema.getFields().size()); + } + + @Test + public void testCreateFromMetadata() throws Exception { + List expected = createRandomRecords(DEFAULT_RECORD_COUNT); + String codec = DataFileConstants.NULL_CODEC; + String filename = + generateTestFile( + codec, expected, SyncBehavior.SYNC_DEFAULT, 0, AvroCoder.of(Bird.class), codec); + Metadata fileMeta = FileSystems.matchSingleFileSpec(filename); + + AvroSource source = AvroSource.from(fileMeta); + AvroSource sourceWithSchema = source.withSchema(Bird.class); + AvroSource sourceWithSchemaWithMinBundleSize = sourceWithSchema.withMinBundleSize(1234); + + assertEquals(FileBasedSource.Mode.SINGLE_FILE_OR_SUBRANGE, source.getMode()); + assertEquals(FileBasedSource.Mode.SINGLE_FILE_OR_SUBRANGE, sourceWithSchema.getMode()); + assertEquals( + FileBasedSource.Mode.SINGLE_FILE_OR_SUBRANGE, sourceWithSchemaWithMinBundleSize.getMode()); + } + + /** + * Class that will encode to a fixed size: 16 bytes. + * + *
<p>
Each object has a 15-byte array. Avro encodes an object of this type as a byte array, so + * each encoded object will consist of 1 byte that encodes the length of the array, followed by 15 + * bytes. + */ + @DefaultCoder(AvroCoder.class) + public static class FixedRecord { + private byte[] value = new byte[15]; + + public FixedRecord() { + this(0); + } + + public FixedRecord(int i) { + value[0] = (byte) i; + value[1] = (byte) (i >> 8); + value[2] = (byte) (i >> 16); + value[3] = (byte) (i >> 24); + } + + public int asInt() { + return value[0] | (value[1] << 8) | (value[2] << 16) | (value[3] << 24); + } + + @Override + public boolean equals(@Nullable Object o) { + if (o instanceof FixedRecord) { + FixedRecord other = (FixedRecord) o; + return this.asInt() == other.asInt(); + } + return false; + } + + @Override + public int hashCode() { + return toString().hashCode(); + } + + @Override + public String toString() { + return Integer.toString(this.asInt()); + } + } + + /** Create a list of count 16-byte records. */ + private static List createFixedRecords(int count) { + List records = new ArrayList<>(); + for (int i = 0; i < count; i++) { + records.add(new FixedRecord(i)); + } + return records; + } + + /** Class used as the record type in tests. */ + @DefaultCoder(AvroCoder.class) + static class Bird { + long number; + String species; + String quality; + long quantity; + + public Bird() {} + + public Bird(long number, String species, String quality, long quantity) { + this.number = number; + this.species = species; + this.quality = quality; + this.quantity = quantity; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(Bird.class) + .addValue(number) + .addValue(species) + .addValue(quantity) + .addValue(quality) + .toString(); + } + + @Override + public boolean equals(@Nullable Object obj) { + if (obj instanceof Bird) { + Bird other = (Bird) obj; + return Objects.equals(species, other.species) + && Objects.equals(quality, other.quality) + && quantity == other.quantity + && number == other.number; + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(number, species, quality, quantity); + } + } + + /** + * Class used as the record type in tests. + * + *
<p>
Contains nullable fields and fields with default values. Can be read using a file written + * with the Bird schema. + */ + @DefaultCoder(AvroCoder.class) + public static class FancyBird { + long number; + String species; + String quality; + long quantity; + + @org.apache.avro.reflect.Nullable String habitat; + + @AvroDefault("\"MAXIMUM OVERDRIVE\"") + String fancinessLevel; + + public FancyBird() {} + + public FancyBird( + long number, + String species, + String quality, + long quantity, + String habitat, + String fancinessLevel) { + this.number = number; + this.species = species; + this.quality = quality; + this.quantity = quantity; + this.habitat = habitat; + this.fancinessLevel = fancinessLevel; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(FancyBird.class) + .addValue(number) + .addValue(species) + .addValue(quality) + .addValue(quantity) + .addValue(habitat) + .addValue(fancinessLevel) + .toString(); + } + + @Override + public boolean equals(@Nullable Object obj) { + if (obj instanceof FancyBird) { + FancyBird other = (FancyBird) obj; + return Objects.equals(species, other.species) + && Objects.equals(quality, other.quality) + && quantity == other.quantity + && number == other.number + && Objects.equals(fancinessLevel, other.fancinessLevel) + && Objects.equals(habitat, other.habitat); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(number, species, quality, quantity, habitat, fancinessLevel); + } + } + + /** Create a list of n random records. */ + private static List createRandomRecords(long n) { + String[] qualities = { + "miserable", "forelorn", "fidgity", "squirrelly", "fanciful", "chipper", "lazy" + }; + String[] species = {"pigeons", "owls", "gulls", "hawks", "robins", "jays"}; + Random random = new Random(0); + + List records = new ArrayList<>(); + for (long i = 0; i < n; i++) { + Bird bird = new Bird(); + bird.quality = qualities[random.nextInt(qualities.length)]; + bird.species = species[random.nextInt(species.length)]; + bird.number = i; + bird.quantity = random.nextLong(); + records.add(bird); + } + return records; + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactoryTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactoryTest.java new file mode 100644 index 000000000000..241ad11635a8 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/io/SerializableAvroCodecFactoryTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.io; + +import static org.apache.avro.file.DataFileConstants.BZIP2_CODEC; +import static org.apache.avro.file.DataFileConstants.DEFLATE_CODEC; +import static org.apache.avro.file.DataFileConstants.NULL_CODEC; +import static org.apache.avro.file.DataFileConstants.SNAPPY_CODEC; +import static org.apache.avro.file.DataFileConstants.XZ_CODEC; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.List; +import org.apache.avro.file.CodecFactory; +import org.apache.beam.sdk.util.SerializableUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests of SerializableAvroCodecFactory. */ +@RunWith(JUnit4.class) +public class SerializableAvroCodecFactoryTest { + private final List avroCodecs = + Arrays.asList(NULL_CODEC, SNAPPY_CODEC, DEFLATE_CODEC, XZ_CODEC, BZIP2_CODEC); + + @Test + public void testDefaultCodecsIn() throws Exception { + for (String codec : avroCodecs) { + SerializableAvroCodecFactory codecFactory = + new SerializableAvroCodecFactory(CodecFactory.fromString(codec)); + + assertEquals(CodecFactory.fromString(codec).toString(), codecFactory.getCodec().toString()); + } + } + + @Test + public void testDefaultCodecsSerDe() throws Exception { + for (String codec : avroCodecs) { + SerializableAvroCodecFactory codecFactory = + new SerializableAvroCodecFactory(CodecFactory.fromString(codec)); + + SerializableAvroCodecFactory serdeC = SerializableUtils.clone(codecFactory); + + assertEquals(CodecFactory.fromString(codec).toString(), serdeC.getCodec().toString()); + } + } + + @Test + public void testDeflateCodecSerDeWithLevels() throws Exception { + for (int i = 0; i < 10; ++i) { + SerializableAvroCodecFactory codecFactory = + new SerializableAvroCodecFactory(CodecFactory.deflateCodec(i)); + + SerializableAvroCodecFactory serdeC = SerializableUtils.clone(codecFactory); + + assertEquals(CodecFactory.deflateCodec(i).toString(), serdeC.getCodec().toString()); + } + } + + @Test + public void testXZCodecSerDeWithLevels() throws Exception { + for (int i = 0; i < 10; ++i) { + SerializableAvroCodecFactory codecFactory = + new SerializableAvroCodecFactory(CodecFactory.xzCodec(i)); + + SerializableAvroCodecFactory serdeC = SerializableUtils.clone(codecFactory); + + assertEquals(CodecFactory.xzCodec(i).toString(), serdeC.getCodec().toString()); + } + } + + @Test(expected = NullPointerException.class) + public void testNullCodecToString() throws Exception { + // use default CTR (available cause Serializable) + SerializableAvroCodecFactory codec = new SerializableAvroCodecFactory(); + assertEquals("null", codec.toString()); + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/AvroSchemaTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/AvroSchemaTest.java new file mode 100644 index 000000000000..3c235ae4d510 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/AvroSchemaTest.java @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.schemas; + +import static org.junit.Assert.assertEquals; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.reflect.AvroIgnore; +import org.apache.avro.reflect.AvroName; +import org.apache.avro.reflect.AvroSchema; +import org.apache.avro.util.Utf8; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes; +import org.apache.beam.sdk.schemas.transforms.Group; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.ValidatesRunner; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.Days; +import org.joda.time.LocalDate; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +/** Tests for AVRO schema classes. */ +public class AvroSchemaTest { + /** A test POJO that corresponds to our AVRO schema. */ + public static class AvroSubPojo { + @AvroName("BOOL_NON_NULLABLE") + public boolean boolNonNullable; + + @AvroName("int") + @org.apache.avro.reflect.Nullable + public Integer anInt; + + public AvroSubPojo(boolean boolNonNullable, Integer anInt) { + this.boolNonNullable = boolNonNullable; + this.anInt = anInt; + } + + public AvroSubPojo() {} + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AvroSubPojo)) { + return false; + } + AvroSubPojo that = (AvroSubPojo) o; + return boolNonNullable == that.boolNonNullable && Objects.equals(anInt, that.anInt); + } + + @Override + public int hashCode() { + return Objects.hash(boolNonNullable, anInt); + } + + @Override + public String toString() { + return "AvroSubPojo{" + "boolNonNullable=" + boolNonNullable + ", anInt=" + anInt + '}'; + } + } + + /** A test POJO that corresponds to our AVRO schema. 
*/ + public static class AvroPojo { + public @AvroName("bool_non_nullable") boolean boolNonNullable; + + @org.apache.avro.reflect.Nullable + public @AvroName("int") Integer anInt; + + @org.apache.avro.reflect.Nullable + public @AvroName("long") Long aLong; + + @AvroName("float") + @org.apache.avro.reflect.Nullable + public Float aFloat; + + @AvroName("double") + @org.apache.avro.reflect.Nullable + public Double aDouble; + + @org.apache.avro.reflect.Nullable public String string; + @org.apache.avro.reflect.Nullable public ByteBuffer bytes; + + @AvroSchema("{\"type\": \"fixed\", \"size\": 4, \"name\": \"fixed4\"}") + public byte[] fixed; + + @AvroSchema("{\"type\": \"int\", \"logicalType\": \"date\"}") + public LocalDate date; + + @AvroSchema("{\"type\": \"long\", \"logicalType\": \"timestamp-millis\"}") + public DateTime timestampMillis; + + @AvroSchema("{\"name\": \"TestEnum\", \"type\": \"enum\", \"symbols\": [\"abc\",\"cde\"] }") + public TestEnum testEnum; + + @org.apache.avro.reflect.Nullable public AvroSubPojo row; + @org.apache.avro.reflect.Nullable public List array; + @org.apache.avro.reflect.Nullable public Map map; + @AvroIgnore String extraField; + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AvroPojo)) { + return false; + } + AvroPojo avroPojo = (AvroPojo) o; + return boolNonNullable == avroPojo.boolNonNullable + && Objects.equals(anInt, avroPojo.anInt) + && Objects.equals(aLong, avroPojo.aLong) + && Objects.equals(aFloat, avroPojo.aFloat) + && Objects.equals(aDouble, avroPojo.aDouble) + && Objects.equals(string, avroPojo.string) + && Objects.equals(bytes, avroPojo.bytes) + && Arrays.equals(fixed, avroPojo.fixed) + && Objects.equals(date, avroPojo.date) + && Objects.equals(timestampMillis, avroPojo.timestampMillis) + && Objects.equals(testEnum, avroPojo.testEnum) + && Objects.equals(row, avroPojo.row) + && Objects.equals(array, avroPojo.array) + && Objects.equals(map, avroPojo.map); + } + + @Override + public int hashCode() { + return Objects.hash( + boolNonNullable, + anInt, + aLong, + aFloat, + aDouble, + string, + bytes, + Arrays.hashCode(fixed), + date, + timestampMillis, + testEnum, + row, + array, + map); + } + + public AvroPojo( + boolean boolNonNullable, + int anInt, + long aLong, + float aFloat, + double aDouble, + String string, + ByteBuffer bytes, + byte[] fixed, + LocalDate date, + DateTime timestampMillis, + TestEnum testEnum, + AvroSubPojo row, + List array, + Map map) { + this.boolNonNullable = boolNonNullable; + this.anInt = anInt; + this.aLong = aLong; + this.aFloat = aFloat; + this.aDouble = aDouble; + this.string = string; + this.bytes = bytes; + this.fixed = fixed; + this.date = date; + this.timestampMillis = timestampMillis; + this.testEnum = testEnum; + this.row = row; + this.array = array; + this.map = map; + this.extraField = ""; + } + + public AvroPojo() {} + + @Override + public String toString() { + return "AvroPojo{" + + "boolNonNullable=" + + boolNonNullable + + ", anInt=" + + anInt + + ", aLong=" + + aLong + + ", aFloat=" + + aFloat + + ", aDouble=" + + aDouble + + ", string='" + + string + + '\'' + + ", bytes=" + + bytes + + ", fixed=" + + Arrays.toString(fixed) + + ", date=" + + date + + ", timestampMillis=" + + timestampMillis + + ", testEnum=" + + testEnum + + ", row=" + + row + + ", array=" + + array + + ", map=" + + map + + ", extraField='" + + extraField + + '\'' + + '}'; + } + } + + private static final Schema SUBSCHEMA = + Schema.builder() + 
.addField("BOOL_NON_NULLABLE", FieldType.BOOLEAN) + .addNullableField("int", FieldType.INT32) + .build(); + private static final FieldType SUB_TYPE = FieldType.row(SUBSCHEMA).withNullable(true); + + private static final EnumerationType TEST_ENUM_TYPE = EnumerationType.create("abc", "cde"); + + private static final Schema SCHEMA = + Schema.builder() + .addField("bool_non_nullable", FieldType.BOOLEAN) + .addNullableField("int", FieldType.INT32) + .addNullableField("long", FieldType.INT64) + .addNullableField("float", FieldType.FLOAT) + .addNullableField("double", FieldType.DOUBLE) + .addNullableField("string", FieldType.STRING) + .addNullableField("bytes", FieldType.BYTES) + .addField("fixed", FieldType.logicalType(FixedBytes.of(4))) + .addField("date", FieldType.DATETIME) + .addField("timestampMillis", FieldType.DATETIME) + .addField("TestEnum", FieldType.logicalType(TEST_ENUM_TYPE)) + .addNullableField("row", SUB_TYPE) + .addNullableField("array", FieldType.array(SUB_TYPE)) + .addNullableField("map", FieldType.map(FieldType.STRING, SUB_TYPE)) + .build(); + + private static final Schema POJO_SCHEMA = + Schema.builder() + .addField("bool_non_nullable", FieldType.BOOLEAN) + .addNullableField("int", FieldType.INT32) + .addNullableField("long", FieldType.INT64) + .addNullableField("float", FieldType.FLOAT) + .addNullableField("double", FieldType.DOUBLE) + .addNullableField("string", FieldType.STRING) + .addNullableField("bytes", FieldType.BYTES) + .addField("fixed", FieldType.logicalType(FixedBytes.of(4))) + .addField("date", FieldType.DATETIME) + .addField("timestampMillis", FieldType.DATETIME) + .addField("testEnum", FieldType.logicalType(TEST_ENUM_TYPE)) + .addNullableField("row", SUB_TYPE) + .addNullableField("array", FieldType.array(SUB_TYPE.withNullable(false))) + .addNullableField("map", FieldType.map(FieldType.STRING, SUB_TYPE.withNullable(false))) + .build(); + + private static final byte[] BYTE_ARRAY = new byte[] {1, 2, 3, 4}; + private static final DateTime DATE_TIME = + new DateTime().withDate(1979, 3, 14).withTime(1, 2, 3, 4); + private static final LocalDate DATE = new LocalDate(1979, 3, 14); + private static final TestAvroNested AVRO_NESTED_SPECIFIC_RECORD = new TestAvroNested(true, 42); + private static final TestAvro AVRO_SPECIFIC_RECORD = + new TestAvro( + true, + 43, + 44L, + (float) 44.1, + (double) 44.2, + "mystring", + ByteBuffer.wrap(BYTE_ARRAY), + new fixed4(BYTE_ARRAY), + DATE, + DATE_TIME, + TestEnum.abc, + AVRO_NESTED_SPECIFIC_RECORD, + ImmutableList.of(AVRO_NESTED_SPECIFIC_RECORD, AVRO_NESTED_SPECIFIC_RECORD), + ImmutableMap.of("k1", AVRO_NESTED_SPECIFIC_RECORD, "k2", AVRO_NESTED_SPECIFIC_RECORD)); + private static final GenericRecord AVRO_NESTED_GENERIC_RECORD = + new GenericRecordBuilder(TestAvroNested.SCHEMA$) + .set("BOOL_NON_NULLABLE", true) + .set("int", 42) + .build(); + private static final GenericRecord AVRO_GENERIC_RECORD = + new GenericRecordBuilder(TestAvro.SCHEMA$) + .set("bool_non_nullable", true) + .set("int", 43) + .set("long", 44L) + .set("float", (float) 44.1) + .set("double", (double) 44.2) + .set("string", new Utf8("mystring")) + .set("bytes", ByteBuffer.wrap(BYTE_ARRAY)) + .set( + "fixed", + GenericData.get() + .createFixed( + null, BYTE_ARRAY, org.apache.avro.Schema.createFixed("fixed4", "", "", 4))) + .set("date", (int) Days.daysBetween(new LocalDate(1970, 1, 1), DATE).getDays()) + .set("timestampMillis", DATE_TIME.getMillis()) + .set("TestEnum", TestEnum.abc) + .set("row", AVRO_NESTED_GENERIC_RECORD) + .set("array", 
ImmutableList.of(AVRO_NESTED_GENERIC_RECORD, AVRO_NESTED_GENERIC_RECORD)) + .set( + "map", + ImmutableMap.of( + new Utf8("k1"), AVRO_NESTED_GENERIC_RECORD, + new Utf8("k2"), AVRO_NESTED_GENERIC_RECORD)) + .build(); + + private static final Row NESTED_ROW = Row.withSchema(SUBSCHEMA).addValues(true, 42).build(); + private static final Row ROW = + Row.withSchema(SCHEMA) + .addValues( + true, + 43, + 44L, + (float) 44.1, + (double) 44.2, + "mystring", + ByteBuffer.wrap(BYTE_ARRAY), + BYTE_ARRAY, + DATE.toDateTimeAtStartOfDay(DateTimeZone.UTC), + DATE_TIME, + TEST_ENUM_TYPE.valueOf("abc"), + NESTED_ROW, + ImmutableList.of(NESTED_ROW, NESTED_ROW), + ImmutableMap.of("k1", NESTED_ROW, "k2", NESTED_ROW)) + .build(); + + @Test + public void testSpecificRecordSchema() { + assertEquals( + SCHEMA, + new org.apache.beam.sdk.schemas.AvroRecordSchema() + .schemaFor(TypeDescriptor.of(TestAvro.class))); + } + + @Test + public void testPojoSchema() { + assertEquals( + POJO_SCHEMA, + new org.apache.beam.sdk.schemas.AvroRecordSchema() + .schemaFor(TypeDescriptor.of(AvroPojo.class))); + } + + @Test + public void testSpecificRecordToRow() { + SerializableFunction toRow = + new org.apache.beam.sdk.schemas.AvroRecordSchema() + .toRowFunction(TypeDescriptor.of(TestAvro.class)); + assertEquals(ROW, toRow.apply(AVRO_SPECIFIC_RECORD)); + } + + @Test + public void testRowToSpecificRecord() { + SerializableFunction fromRow = + new org.apache.beam.sdk.schemas.AvroRecordSchema() + .fromRowFunction(TypeDescriptor.of(TestAvro.class)); + assertEquals(AVRO_SPECIFIC_RECORD, fromRow.apply(ROW)); + } + + @Test + public void testGenericRecordToRow() { + SerializableFunction toRow = + AvroUtils.getGenericRecordToRowFunction(SCHEMA); + assertEquals(ROW, toRow.apply(AVRO_GENERIC_RECORD)); + } + + @Test + public void testRowToGenericRecord() { + SerializableFunction fromRow = + AvroUtils.getRowToGenericRecordFunction(TestAvro.SCHEMA$); + assertEquals(AVRO_GENERIC_RECORD, fromRow.apply(ROW)); + } + + private static final AvroSubPojo SUB_POJO = new AvroSubPojo(true, 42); + private static final AvroPojo AVRO_POJO = + new AvroPojo( + true, + 43, + 44L, + (float) 44.1, + (double) 44.2, + "mystring", + ByteBuffer.wrap(BYTE_ARRAY), + BYTE_ARRAY, + DATE, + DATE_TIME, + TestEnum.abc, + SUB_POJO, + ImmutableList.of(SUB_POJO, SUB_POJO), + ImmutableMap.of("k1", SUB_POJO, "k2", SUB_POJO)); + + private static final Row ROW_FOR_POJO = + Row.withSchema(POJO_SCHEMA) + .addValues( + true, + 43, + 44L, + (float) 44.1, + (double) 44.2, + "mystring", + ByteBuffer.wrap(BYTE_ARRAY), + BYTE_ARRAY, + DATE.toDateTimeAtStartOfDay(DateTimeZone.UTC), + DATE_TIME, + TEST_ENUM_TYPE.valueOf("abc"), + NESTED_ROW, + ImmutableList.of(NESTED_ROW, NESTED_ROW), + ImmutableMap.of("k1", NESTED_ROW, "k2", NESTED_ROW)) + .build(); + + @Test + public void testPojoRecordToRow() { + SerializableFunction toRow = + new org.apache.beam.sdk.schemas.AvroRecordSchema() + .toRowFunction(TypeDescriptor.of(AvroPojo.class)); + assertEquals(ROW_FOR_POJO, toRow.apply(AVRO_POJO)); + } + + @Test + public void testRowToPojo() { + SerializableFunction fromRow = + new org.apache.beam.sdk.schemas.AvroRecordSchema() + .fromRowFunction(TypeDescriptor.of(AvroPojo.class)); + assertEquals(AVRO_POJO, fromRow.apply(ROW_FOR_POJO)); + } + + @Test + public void testPojoRecordToRowSerializable() { + SerializableUtils.ensureSerializableRoundTrip( + new org.apache.beam.sdk.schemas.AvroRecordSchema() + .toRowFunction(TypeDescriptor.of(AvroPojo.class))); + } + + @Test + public void 
testPojoRecordFromRowSerializable() { + SerializableUtils.ensureSerializableRoundTrip( + new AvroRecordSchema().fromRowFunction(TypeDescriptor.of(AvroPojo.class))); + } + + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + @Test + @Category(ValidatesRunner.class) + public void testAvroPipelineGroupBy() { + PCollection input = pipeline.apply(Create.of(ROW_FOR_POJO).withRowSchema(POJO_SCHEMA)); + + PCollection output = input.apply(Group.byFieldNames("string")); + Schema keySchema = Schema.builder().addStringField("string").build(); + Schema outputSchema = + Schema.builder() + .addRowField("key", keySchema) + .addIterableField("value", FieldType.row(POJO_SCHEMA)) + .build(); + PAssert.that(output) + .containsInAnyOrder( + Row.withSchema(outputSchema) + .addValue(Row.withSchema(keySchema).addValue("mystring").build()) + .addIterable(ImmutableList.of(ROW_FOR_POJO)) + .build()); + + pipeline.run(); + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/io/AvroPayloadSerializerProviderTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/io/AvroPayloadSerializerProviderTest.java new file mode 100644 index 000000000000..9c56ffcdc084 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/io/AvroPayloadSerializerProviderTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.schemas.io; + +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.schemas.io.payloads.AvroPayloadSerializerProvider; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AvroPayloadSerializerProviderTest { + private static final Schema SCHEMA = + Schema.builder().addInt64Field("abc").addStringField("xyz").build(); + private static final org.apache.avro.Schema AVRO_SCHEMA = AvroUtils.toAvroSchema(SCHEMA); + private static final AvroCoder AVRO_CODER = AvroCoder.of(AVRO_SCHEMA); + private static final Row DESERIALIZED = + Row.withSchema(SCHEMA).withFieldValue("abc", 3L).withFieldValue("xyz", "qqq").build(); + private static final GenericRecord SERIALIZED = + new GenericRecordBuilder(AVRO_SCHEMA).set("abc", 3L).set("xyz", "qqq").build(); + + private final AvroPayloadSerializerProvider provider = new AvroPayloadSerializerProvider(); + + @Test + public void serialize() throws Exception { + byte[] bytes = provider.getSerializer(SCHEMA, ImmutableMap.of()).serialize(DESERIALIZED); + GenericRecord record = AVRO_CODER.decode(new ByteArrayInputStream(bytes)); + assertEquals(3L, record.get("abc")); + assertEquals("qqq", record.get("xyz").toString()); + } + + @Test + public void deserialize() throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + AVRO_CODER.encode(SERIALIZED, os); + Row row = provider.getSerializer(SCHEMA, ImmutableMap.of()).deserialize(os.toByteArray()); + assertEquals(DESERIALIZED, row); + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroGenerators.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroGenerators.java new file mode 100644 index 000000000000..fa7d7cceecce --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroGenerators.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.avro.schemas.utils; + +import com.pholser.junit.quickcheck.generator.GenerationStatus; +import com.pholser.junit.quickcheck.generator.Generator; +import com.pholser.junit.quickcheck.random.SourceOfRandomness; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.avro.Schema; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ObjectArrays; + +/** QuickCheck generators for AVRO. */ +class AvroGenerators { + + /** Generates arbitrary AVRO schemas. */ + public static class SchemaGenerator extends BaseSchemaGenerator { + + public static final SchemaGenerator INSTANCE = new SchemaGenerator(); + + private static final ImmutableList PRIMITIVE_TYPES = + ImmutableList.of( + Schema.Type.STRING, + Schema.Type.BYTES, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.BOOLEAN); + + private static final ImmutableList ALL_TYPES = + ImmutableList.builder() + .addAll(PRIMITIVE_TYPES) + .add(Schema.Type.FIXED) + .add(Schema.Type.ENUM) + .add(Schema.Type.RECORD) + .add(Schema.Type.ARRAY) + .add(Schema.Type.MAP) + .add(Schema.Type.UNION) + .add(Schema.Type.ARRAY) + .build(); + + private static final int MAX_NESTING = 10; + + @Override + public Schema generate(SourceOfRandomness random, GenerationStatus status) { + Schema.Type type; + + if (nesting(status) >= MAX_NESTING) { + type = random.choose(PRIMITIVE_TYPES); + } else { + type = random.choose(ALL_TYPES); + } + + if (PRIMITIVE_TYPES.contains(type)) { + return Schema.create(type); + } else { + nestingInc(status); + + if (type == Schema.Type.FIXED) { + int size = random.choose(Arrays.asList(1, 5, 12)); + return Schema.createFixed("fixed_" + branch(status), "", "", size); + } else if (type == Schema.Type.UNION) { + // only nullable fields, everything else isn't supported in row conversion code + return UnionSchemaGenerator.INSTANCE.generate(random, status); + } else if (type == Schema.Type.ENUM) { + return EnumSchemaGenerator.INSTANCE.generate(random, status); + } else if (type == Schema.Type.RECORD) { + return RecordSchemaGenerator.INSTANCE.generate(random, status); + } else if (type == Schema.Type.MAP) { + return Schema.createMap(generate(random, status)); + } else if (type == Schema.Type.ARRAY) { + return Schema.createArray(generate(random, status)); + } else { + throw new AssertionError("Unexpected AVRO type: " + type); + } + } + } + } + + public static class RecordSchemaGenerator extends BaseSchemaGenerator { + + public static final RecordSchemaGenerator INSTANCE = new RecordSchemaGenerator(); + + @Override + public Schema generate(SourceOfRandomness random, GenerationStatus status) { + List fields = + IntStream.range(0, random.nextInt(0, status.size()) + 1) + .mapToObj( + i -> { + // deterministically avoid collisions in record names + branchPush(status, String.valueOf(i)); + Schema.Field field = + createField(i, SchemaGenerator.INSTANCE.generate(random, status)); + branchPop(status); + return field; + }) + .collect(Collectors.toList()); + + return Schema.createRecord("record_" + branch(status), "", "example", false, fields); + } + + private Schema.Field createField(int i, 
Schema schema) { + return new Schema.Field("field_" + i, schema, null, (Object) null); + } + } + + static class UnionSchemaGenerator extends BaseSchemaGenerator { + + public static final UnionSchemaGenerator INSTANCE = new UnionSchemaGenerator(); + + @Override + public Schema generate(SourceOfRandomness random, GenerationStatus status) { + Map schemaMap = + IntStream.range(0, random.nextInt(0, status.size()) + 1) + .mapToObj( + i -> { + // deterministically avoid collisions in record names + branchPush(status, String.valueOf(i)); + Schema schema = + SchemaGenerator.INSTANCE + // nested unions aren't supported in AVRO + .filter(x -> x.getType() != Schema.Type.UNION) + .generate(random, status); + branchPop(status); + return schema; + }) + // AVRO requires uniqueness by full name + .collect(Collectors.toMap(Schema::getFullName, Function.identity(), (x, y) -> x)); + + List schemas = new ArrayList<>(schemaMap.values()); + + if (random.nextBoolean()) { + Schema nullSchema = Schema.create(Schema.Type.NULL); + schemas.add(nullSchema); + Collections.shuffle(schemas, random.toJDKRandom()); + } + + return Schema.createUnion(schemas); + } + } + + static class EnumSchemaGenerator extends BaseSchemaGenerator { + + public static final EnumSchemaGenerator INSTANCE = new EnumSchemaGenerator(); + + private static final Schema FRUITS = + Schema.createEnum("Fruit", "", "example", Arrays.asList("banana", "apple", "pear")); + + private static final Schema STATUS = + Schema.createEnum("Status", "", "example", Arrays.asList("OK", "ERROR", "WARNING")); + + @Override + public Schema generate(final SourceOfRandomness random, final GenerationStatus status) { + return random.choose(Arrays.asList(FRUITS, STATUS)); + } + } + + abstract static class BaseSchemaGenerator extends Generator { + + private static final GenerationStatus.Key NESTING_KEY = + new GenerationStatus.Key<>("nesting", Integer.class); + + private static final GenerationStatus.Key BRANCH_KEY = + new GenerationStatus.Key<>("branch", String[].class); + + BaseSchemaGenerator() { + super(Schema.class); + } + + void branchPush(GenerationStatus status, String value) { + String[] current = status.valueOf(BRANCH_KEY).orElse(new String[0]); + String[] next = ObjectArrays.concat(current, value); + + status.setValue(BRANCH_KEY, next); + } + + void branchPop(GenerationStatus status) { + String[] current = status.valueOf(BRANCH_KEY).orElse(new String[0]); + String[] next = Arrays.copyOf(current, current.length - 1); + + status.setValue(BRANCH_KEY, next); + } + + String branch(GenerationStatus status) { + return Joiner.on("_").join(status.valueOf(BRANCH_KEY).orElse(new String[0])); + } + + int nesting(GenerationStatus status) { + return status.valueOf(NESTING_KEY).orElse(0); + } + + void nestingInc(GenerationStatus status) { + status.setValue(NESTING_KEY, nesting(status) + 1); + } + } +} diff --git a/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtilsTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtilsTest.java new file mode 100644 index 000000000000..2c33f6439665 --- /dev/null +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtilsTest.java @@ -0,0 +1,895 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.avro.schemas.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.pholser.junit.quickcheck.From; +import com.pholser.junit.quickcheck.Property; +import com.pholser.junit.quickcheck.runner.JUnitQuickcheck; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.JDBCType; +import java.util.List; +import java.util.Map; +import org.apache.avro.Conversions; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.RandomData; +import org.apache.avro.Schema.Type; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.util.Utf8; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.io.AvroGeneratedUser; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.testing.CoderProperties; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.Days; +import org.joda.time.Instant; +import org.joda.time.LocalTime; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Tests for conversion between AVRO records and Beam rows. 
*/ +@RunWith(JUnitQuickcheck.class) +@SuppressWarnings({ + "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) +}) +public class AvroUtilsTest { + + private static final org.apache.avro.Schema NULL_SCHEMA = + org.apache.avro.Schema.create(Type.NULL); + + @Property(trials = 1000) + @SuppressWarnings("unchecked") + public void supportsAnyAvroSchema( + @From(AvroGenerators.RecordSchemaGenerator.class) org.apache.avro.Schema avroSchema) { + + Schema schema = AvroUtils.toBeamSchema(avroSchema); + Iterable iterable = new RandomData(avroSchema, 10); + List records = Lists.newArrayList((Iterable) iterable); + + for (GenericRecord record : records) { + AvroUtils.toBeamRowStrict(record, schema); + } + } + + @Property(trials = 1000) + @SuppressWarnings("unchecked") + public void avroToBeamRoundTrip( + @From(AvroGenerators.RecordSchemaGenerator.class) org.apache.avro.Schema avroSchema) { + + Schema schema = AvroUtils.toBeamSchema(avroSchema); + Iterable iterable = new RandomData(avroSchema, 10); + List records = Lists.newArrayList((Iterable) iterable); + + for (GenericRecord record : records) { + Row row = AvroUtils.toBeamRowStrict(record, schema); + GenericRecord out = AvroUtils.toGenericRecord(row, avroSchema); + assertEquals(record, out); + } + } + + @Test + public void testUnwrapNullableSchema() { + org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createUnion( + org.apache.avro.Schema.create(Type.NULL), org.apache.avro.Schema.create(Type.STRING)); + + AvroUtils.TypeWithNullability typeWithNullability = + new AvroUtils.TypeWithNullability(avroSchema); + assertTrue(typeWithNullability.nullable); + assertEquals(org.apache.avro.Schema.create(Type.STRING), typeWithNullability.type); + } + + @Test + public void testUnwrapNullableSchemaReordered() { + org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createUnion( + org.apache.avro.Schema.create(Type.STRING), org.apache.avro.Schema.create(Type.NULL)); + + AvroUtils.TypeWithNullability typeWithNullability = + new AvroUtils.TypeWithNullability(avroSchema); + assertTrue(typeWithNullability.nullable); + assertEquals(org.apache.avro.Schema.create(Type.STRING), typeWithNullability.type); + } + + @Test + public void testUnwrapNullableSchemaToUnion() { + org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createUnion( + org.apache.avro.Schema.create(Type.STRING), + org.apache.avro.Schema.create(Type.LONG), + org.apache.avro.Schema.create(Type.NULL)); + + AvroUtils.TypeWithNullability typeWithNullability = + new AvroUtils.TypeWithNullability(avroSchema); + assertTrue(typeWithNullability.nullable); + assertEquals( + org.apache.avro.Schema.createUnion( + org.apache.avro.Schema.create(Type.STRING), org.apache.avro.Schema.create(Type.LONG)), + typeWithNullability.type); + } + + @Test + public void testNullableArrayFieldToBeamArrayField() { + org.apache.avro.Schema.Field avroField = + new org.apache.avro.Schema.Field( + "arrayField", + ReflectData.makeNullable( + org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(Type.INT))), + "", + null); + + Field expectedBeamField = Field.nullable("arrayField", FieldType.array(FieldType.INT32)); + + Field beamField = AvroUtils.toBeamField(avroField); + assertEquals(expectedBeamField, beamField); + } + + @Test + public void testNullableBeamArrayFieldToAvroField() { + Field beamField = Field.nullable("arrayField", FieldType.array(FieldType.INT32)); + + org.apache.avro.Schema.Field expectedAvroField = + new org.apache.avro.Schema.Field( + "arrayField", + 
ReflectData.makeNullable( + org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(Type.INT))), + "", + null); + + org.apache.avro.Schema.Field avroField = AvroUtils.toAvroField(beamField, "ignored"); + assertEquals(expectedAvroField, avroField); + } + + private static List getAvroSubSchemaFields() { + List fields = Lists.newArrayList(); + fields.add( + new org.apache.avro.Schema.Field( + "bool", org.apache.avro.Schema.create(Type.BOOLEAN), "", null)); + fields.add( + new org.apache.avro.Schema.Field("int", org.apache.avro.Schema.create(Type.INT), "", null)); + return fields; + } + + private static org.apache.avro.Schema getAvroSubSchema(String name) { + return org.apache.avro.Schema.createRecord( + name, null, "topLevelRecord", false, getAvroSubSchemaFields()); + } + + private static org.apache.avro.Schema getAvroSchema() { + List fields = Lists.newArrayList(); + fields.add( + new org.apache.avro.Schema.Field( + "bool", org.apache.avro.Schema.create(Type.BOOLEAN), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "int", org.apache.avro.Schema.create(Type.INT), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "long", org.apache.avro.Schema.create(Type.LONG), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "float", org.apache.avro.Schema.create(Type.FLOAT), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "double", org.apache.avro.Schema.create(Type.DOUBLE), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "string", org.apache.avro.Schema.create(Type.STRING), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "bytes", org.apache.avro.Schema.create(Type.BYTES), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "decimal", + LogicalTypes.decimal(Integer.MAX_VALUE) + .addToSchema(org.apache.avro.Schema.create(Type.BYTES)), + "", + (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "timestampMillis", + LogicalTypes.timestampMillis().addToSchema(org.apache.avro.Schema.create(Type.LONG)), + "", + (Object) null)); + fields.add(new org.apache.avro.Schema.Field("row", getAvroSubSchema("row"), "", (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "array", + org.apache.avro.Schema.createArray(getAvroSubSchema("array")), + "", + (Object) null)); + fields.add( + new org.apache.avro.Schema.Field( + "map", org.apache.avro.Schema.createMap(getAvroSubSchema("map")), "", (Object) null)); + return org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, fields); + } + + private static Schema getBeamSubSchema() { + return new Schema.Builder() + .addField(Field.of("bool", FieldType.BOOLEAN)) + .addField(Field.of("int", FieldType.INT32)) + .build(); + } + + private Schema getBeamSchema() { + Schema subSchema = getBeamSubSchema(); + return new Schema.Builder() + .addField(Field.of("bool", FieldType.BOOLEAN)) + .addField(Field.of("int", FieldType.INT32)) + .addField(Field.of("long", FieldType.INT64)) + .addField(Field.of("float", FieldType.FLOAT)) + .addField(Field.of("double", FieldType.DOUBLE)) + .addField(Field.of("string", FieldType.STRING)) + .addField(Field.of("bytes", FieldType.BYTES)) + .addField(Field.of("decimal", FieldType.DECIMAL)) + .addField(Field.of("timestampMillis", FieldType.DATETIME)) + .addField(Field.of("row", FieldType.row(subSchema))) + .addField(Field.of("array", FieldType.array(FieldType.row(subSchema)))) + .addField(Field.of("map", FieldType.map(FieldType.STRING, 
FieldType.row(subSchema)))) + .build(); + } + + private static final byte[] BYTE_ARRAY = new byte[] {1, 2, 3, 4}; + private static final DateTime DATE_TIME = + new DateTime().withDate(1979, 3, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC); + private static final BigDecimal BIG_DECIMAL = new BigDecimal(3600); + + private Row getBeamRow() { + Row subRow = Row.withSchema(getBeamSubSchema()).addValues(true, 42).build(); + return Row.withSchema(getBeamSchema()) + .addValue(true) + .addValue(43) + .addValue(44L) + .addValue((float) 44.1) + .addValue((double) 44.2) + .addValue("string") + .addValue(BYTE_ARRAY) + .addValue(BIG_DECIMAL) + .addValue(DATE_TIME) + .addValue(subRow) + .addValue(ImmutableList.of(subRow, subRow)) + .addValue(ImmutableMap.of("k1", subRow, "k2", subRow)) + .build(); + } + + private static GenericRecord getSubGenericRecord(String name) { + return new GenericRecordBuilder(getAvroSubSchema(name)) + .set("bool", true) + .set("int", 42) + .build(); + } + + private static GenericRecord getGenericRecord() { + + LogicalType decimalType = + LogicalTypes.decimal(Integer.MAX_VALUE) + .addToSchema(org.apache.avro.Schema.create(Type.BYTES)) + .getLogicalType(); + ByteBuffer encodedDecimal = + new Conversions.DecimalConversion().toBytes(BIG_DECIMAL, null, decimalType); + + return new GenericRecordBuilder(getAvroSchema()) + .set("bool", true) + .set("int", 43) + .set("long", 44L) + .set("float", (float) 44.1) + .set("double", (double) 44.2) + .set("string", new Utf8("string")) + .set("bytes", ByteBuffer.wrap(BYTE_ARRAY)) + .set("decimal", encodedDecimal) + .set("timestampMillis", DATE_TIME.getMillis()) + .set("row", getSubGenericRecord("row")) + .set("array", ImmutableList.of(getSubGenericRecord("array"), getSubGenericRecord("array"))) + .set( + "map", + ImmutableMap.of( + new Utf8("k1"), + getSubGenericRecord("map"), + new Utf8("k2"), + getSubGenericRecord("map"))) + .build(); + } + + @Test + public void testFromAvroSchema() { + assertEquals(getBeamSchema(), AvroUtils.toBeamSchema(getAvroSchema())); + } + + @Test + public void testFromBeamSchema() { + Schema beamSchema = getBeamSchema(); + org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(beamSchema); + assertEquals(getAvroSchema(), avroSchema); + } + + @Test + public void testAvroSchemaFromBeamSchemaCanBeParsed() { + org.apache.avro.Schema convertedSchema = AvroUtils.toAvroSchema(getBeamSchema()); + org.apache.avro.Schema validatedSchema = + new org.apache.avro.Schema.Parser().parse(convertedSchema.toString()); + assertEquals(convertedSchema, validatedSchema); + } + + @Test + public void testAvroSchemaFromBeamSchemaWithFieldCollisionCanBeParsed() { + + // Two similar schemas, the only difference is the "street" field type in the nested record. 
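+      // Note: Avro does not allow two different named types with the same full name in one schema,
+      // so the conversion below has to generate distinct names for these colliding nested record
+      // types; re-parsing the generated schema string is what would surface a redefinition error.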
+ Schema contact = + new Schema.Builder() + .addField(Field.of("name", FieldType.STRING)) + .addField( + Field.of( + "address", + FieldType.row( + new Schema.Builder() + .addField(Field.of("street", FieldType.STRING)) + .addField(Field.of("city", FieldType.STRING)) + .build()))) + .build(); + + Schema contactMultiline = + new Schema.Builder() + .addField(Field.of("name", FieldType.STRING)) + .addField( + Field.of( + "address", + FieldType.row( + new Schema.Builder() + .addField(Field.of("street", FieldType.array(FieldType.STRING))) + .addField(Field.of("city", FieldType.STRING)) + .build()))) + .build(); + + // Ensure that no collisions happen between two sibling fields with same-named child fields + // (with different schemas, between a parent field and a sub-record field with the same name, + // and artificially with the generated field name. + Schema beamSchema = + new Schema.Builder() + .addField(Field.of("home", FieldType.row(contact))) + .addField(Field.of("work", FieldType.row(contactMultiline))) + .addField(Field.of("address", FieldType.row(contact))) + .addField(Field.of("topLevelRecord", FieldType.row(contactMultiline))) + .build(); + + org.apache.avro.Schema convertedSchema = AvroUtils.toAvroSchema(beamSchema); + org.apache.avro.Schema validatedSchema = + new org.apache.avro.Schema.Parser().parse(convertedSchema.toString()); + assertEquals(convertedSchema, validatedSchema); + } + + @Test + public void testNullableFieldInAvroSchema() { + List fields = Lists.newArrayList(); + fields.add( + new org.apache.avro.Schema.Field( + "int", ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT)), "", null)); + fields.add( + new org.apache.avro.Schema.Field( + "array", + org.apache.avro.Schema.createArray( + ReflectData.makeNullable(org.apache.avro.Schema.create(Type.BYTES))), + "", + null)); + fields.add( + new org.apache.avro.Schema.Field( + "map", + org.apache.avro.Schema.createMap( + ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT))), + "", + null)); + fields.add( + new org.apache.avro.Schema.Field( + "enum", + ReflectData.makeNullable( + org.apache.avro.Schema.createEnum( + "fruit", "", "", ImmutableList.of("banana", "apple", "pear"))), + "", + null)); + + org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, fields); + + Schema expectedSchema = + Schema.builder() + .addNullableField("int", FieldType.INT32) + .addArrayField("array", FieldType.BYTES.withNullable(true)) + .addMapField("map", FieldType.STRING, FieldType.INT32.withNullable(true)) + .addField( + "enum", + FieldType.logicalType(EnumerationType.create("banana", "apple", "pear")) + .withNullable(true)) + .build(); + assertEquals(expectedSchema, AvroUtils.toBeamSchema(avroSchema)); + + Map nullMap = Maps.newHashMap(); + nullMap.put("k1", null); + GenericRecord genericRecord = + new GenericRecordBuilder(avroSchema) + .set("int", null) + .set("array", Lists.newArrayList((Object) null)) + .set("map", nullMap) + .set("enum", null) + .build(); + Row expectedRow = + Row.withSchema(expectedSchema) + .addValue(null) + .addValue(Lists.newArrayList((Object) null)) + .addValue(nullMap) + .addValue(null) + .build(); + assertEquals(expectedRow, AvroUtils.toBeamRowStrict(genericRecord, expectedSchema)); + } + + @Test + public void testNullableFieldsInBeamSchema() { + Schema beamSchema = + Schema.builder() + .addNullableField("int", FieldType.INT32) + .addArrayField("array", FieldType.INT32.withNullable(true)) + .addMapField("map", FieldType.STRING, 
FieldType.INT32.withNullable(true)) + .build(); + + List fields = Lists.newArrayList(); + fields.add( + new org.apache.avro.Schema.Field( + "int", ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT)), "", null)); + fields.add( + new org.apache.avro.Schema.Field( + "array", + org.apache.avro.Schema.createArray( + ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT))), + "", + null)); + fields.add( + new org.apache.avro.Schema.Field( + "map", + org.apache.avro.Schema.createMap( + ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT))), + "", + null)); + org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, fields); + assertEquals(avroSchema, AvroUtils.toAvroSchema(beamSchema)); + + Map nullMapUtf8 = Maps.newHashMap(); + nullMapUtf8.put(new Utf8("k1"), null); + Map nullMapString = Maps.newHashMap(); + nullMapString.put("k1", null); + + GenericRecord expectedGenericRecord = + new GenericRecordBuilder(avroSchema) + .set("int", null) + .set("array", Lists.newArrayList((Object) null)) + .set("map", nullMapUtf8) + .build(); + Row row = + Row.withSchema(beamSchema) + .addValue(null) + .addValue(Lists.newArrayList((Object) null)) + .addValue(nullMapString) + .build(); + assertEquals(expectedGenericRecord, AvroUtils.toGenericRecord(row, avroSchema)); + } + + @Test + public void testUnionFieldInAvroSchema() { + + List fields = Lists.newArrayList(); + List unionFields = Lists.newArrayList(); + + unionFields.add(org.apache.avro.Schema.create(Type.INT)); + unionFields.add(org.apache.avro.Schema.create(Type.STRING)); + + fields.add( + new org.apache.avro.Schema.Field( + "union", org.apache.avro.Schema.createUnion(unionFields), "", null)); + org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, fields); + OneOfType oneOfType = + OneOfType.create(Field.of("int", FieldType.INT32), Field.of("string", FieldType.STRING)); + + Schema expectedSchema = Schema.builder().addLogicalTypeField("union", oneOfType).build(); + assertEquals(expectedSchema, AvroUtils.toBeamSchema(avroSchema)); + GenericRecord genericRecord = new GenericRecordBuilder(avroSchema).set("union", 23423).build(); + Row expectedRow = + Row.withSchema(expectedSchema).addValue(oneOfType.createValue(0, 23423)).build(); + assertEquals(expectedRow, AvroUtils.toBeamRowStrict(genericRecord, expectedSchema)); + } + + @Test + public void testUnionFieldInBeamSchema() { + OneOfType oneOfType = + OneOfType.create(Field.of("int", FieldType.INT32), Field.of("string", FieldType.STRING)); + + Schema beamSchema = Schema.builder().addLogicalTypeField("union", oneOfType).build(); + List fields = Lists.newArrayList(); + List unionFields = Lists.newArrayList(); + + unionFields.add(org.apache.avro.Schema.create(Type.INT)); + unionFields.add(org.apache.avro.Schema.create(Type.STRING)); + fields.add( + new org.apache.avro.Schema.Field( + "union", org.apache.avro.Schema.createUnion(unionFields), "", null)); + org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, fields); + GenericRecord expectedGenericRecord = + new GenericRecordBuilder(avroSchema).set("union", 23423).build(); + Row row = Row.withSchema(beamSchema).addValue(oneOfType.createValue(0, 23423)).build(); + assertEquals(expectedGenericRecord, AvroUtils.toGenericRecord(row, avroSchema)); + } + + @Test + public void testJdbcLogicalVarCharRowDataToAvroSchema() { + String expectedAvroSchemaJson = + "{ " + + " 
\"name\": \"topLevelRecord\", " + + " \"type\": \"record\", " + + " \"fields\": [{ " + + " \"name\": \"my_varchar_field\", " + + " \"type\": {\"type\": \"string\", \"logicalType\": \"varchar\", \"maxLength\": 10}" + + " }, " + + " { " + + " \"name\": \"my_longvarchar_field\", " + + " \"type\": {\"type\": \"string\", \"logicalType\": \"varchar\", \"maxLength\": 50}" + + " }, " + + " { " + + " \"name\": \"my_nvarchar_field\", " + + " \"type\": {\"type\": \"string\", \"logicalType\": \"varchar\", \"maxLength\": 10}" + + " }, " + + " { " + + " \"name\": \"my_longnvarchar_field\", " + + " \"type\": {\"type\": \"string\", \"logicalType\": \"varchar\", \"maxLength\": 50}" + + " }, " + + " { " + + " \"name\": \"fixed_length_char_field\", " + + " \"type\": {\"type\": \"string\", \"logicalType\": \"char\", \"maxLength\": 25}" + + " } " + + " ] " + + "}"; + + Schema beamSchema = + Schema.builder() + .addField( + Field.of( + "my_varchar_field", FieldType.logicalType(JdbcType.StringType.varchar(10)))) + .addField( + Field.of( + "my_longvarchar_field", + FieldType.logicalType(JdbcType.StringType.longvarchar(50)))) + .addField( + Field.of( + "my_nvarchar_field", FieldType.logicalType(JdbcType.StringType.nvarchar(10)))) + .addField( + Field.of( + "my_longnvarchar_field", + FieldType.logicalType(JdbcType.StringType.longnvarchar(50)))) + .addField( + Field.of( + "fixed_length_char_field", + FieldType.logicalType(JdbcType.StringType.fixedLengthChar(25)))) + .build(); + + assertEquals( + new org.apache.avro.Schema.Parser().parse(expectedAvroSchemaJson), + AvroUtils.toAvroSchema(beamSchema)); + } + + @Test + public void testJdbcLogicalVarCharRowDataToGenericRecord() { + Schema beamSchema = + Schema.builder() + .addField( + Field.of( + "my_varchar_field", FieldType.logicalType(JdbcType.StringType.varchar(10)))) + .addField( + Field.of( + "my_longvarchar_field", + FieldType.logicalType(JdbcType.StringType.longvarchar(50)))) + .addField( + Field.of( + "my_nvarchar_field", FieldType.logicalType(JdbcType.StringType.nvarchar(10)))) + .addField( + Field.of( + "my_longnvarchar_field", + FieldType.logicalType(JdbcType.StringType.longnvarchar(50)))) + .build(); + + Row rowData = + Row.withSchema(beamSchema) + .addValue("varchar_value") + .addValue("longvarchar_value") + .addValue("nvarchar_value") + .addValue("longnvarchar_value") + .build(); + + org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(beamSchema); + GenericRecord expectedRecord = + new GenericRecordBuilder(avroSchema) + .set("my_varchar_field", "varchar_value") + .set("my_longvarchar_field", "longvarchar_value") + .set("my_nvarchar_field", "nvarchar_value") + .set("my_longnvarchar_field", "longnvarchar_value") + .build(); + + assertEquals(expectedRecord, AvroUtils.toGenericRecord(rowData, avroSchema)); + } + + @Test + public void testJdbcLogicalDateAndTimeRowDataToAvroSchema() { + String expectedAvroSchemaJson = + "{ " + + " \"name\": \"topLevelRecord\", " + + " \"type\": \"record\", " + + " \"fields\": [{ " + + " \"name\": \"my_date_field\", " + + " \"type\": { \"type\": \"int\", \"logicalType\": \"date\" }" + + " }, " + + " { " + + " \"name\": \"my_time_field\", " + + " \"type\": { \"type\": \"int\", \"logicalType\": \"time-millis\" }" + + " }" + + " ] " + + "}"; + + Schema beamSchema = + Schema.builder() + .addField(Field.of("my_date_field", FieldType.logicalType(JdbcType.DATE))) + .addField(Field.of("my_time_field", FieldType.logicalType(JdbcType.TIME))) + .build(); + + assertEquals( + new 
+        new org.apache.avro.Schema.Parser().parse(expectedAvroSchemaJson),
+        AvroUtils.toAvroSchema(beamSchema));
+  }
+
+  @Test
+  public void testJdbcLogicalDateAndTimeRowDataToGenericRecord() {
+    // Test a fixed clock at 2021-05-29T11:15:16.234Z.
+    DateTime testDateTime = DateTime.parse("2021-05-29T11:15:16.234Z");
+
+    Schema beamSchema =
+        Schema.builder()
+            .addField(Field.of("my_date_field", FieldType.logicalType(JdbcType.DATE)))
+            .addField(Field.of("my_time_field", FieldType.logicalType(JdbcType.TIME)))
+            .build();
+
+    Row rowData =
+        Row.withSchema(beamSchema)
+            .addValue(testDateTime.toLocalDate().toDateTime(LocalTime.MIDNIGHT).toInstant())
+            .addValue(Instant.ofEpochMilli(testDateTime.toLocalTime().millisOfDay().get()))
+            .build();
+
+    int daysFromEpoch =
+        Days.daysBetween(
+                Instant.EPOCH,
+                testDateTime.toLocalDate().toDateTime(LocalTime.MIDNIGHT).toInstant())
+            .getDays();
+    int timeSinceMidNight = testDateTime.toLocalTime().getMillisOfDay();
+
+    org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(beamSchema);
+    GenericRecord expectedRecord =
+        new GenericRecordBuilder(avroSchema)
+            .set("my_date_field", daysFromEpoch)
+            .set("my_time_field", timeSinceMidNight)
+            .build();
+
+    assertEquals(expectedRecord, AvroUtils.toGenericRecord(rowData, avroSchema));
+  }
+
+  @Test
+  public void testBeamRowToGenericRecord() {
+    GenericRecord genericRecord = AvroUtils.toGenericRecord(getBeamRow(), null);
+    assertEquals(getAvroSchema(), genericRecord.getSchema());
+    assertEquals(getGenericRecord(), genericRecord);
+  }
+
+  @Test
+  public void testBeamRowToGenericRecordInferSchema() {
+    GenericRecord genericRecord = AvroUtils.toGenericRecord(getBeamRow());
+    assertEquals(getAvroSchema(), genericRecord.getSchema());
+    assertEquals(getGenericRecord(), genericRecord);
+  }
+
+  @Test
+  public void testRowToGenericRecordFunction() {
+    SerializableUtils.ensureSerializable(AvroUtils.getRowToGenericRecordFunction(NULL_SCHEMA));
+    SerializableUtils.ensureSerializable(AvroUtils.getRowToGenericRecordFunction(null));
+  }
+
+  @Test
+  public void testGenericRecordToBeamRow() {
+    GenericRecord genericRecord = getGenericRecord();
+    Row row = AvroUtils.toBeamRowStrict(getGenericRecord(), null);
+    assertEquals(getBeamRow(), row);
+
+    // Alternatively, a timestamp-millis logical type can have a joda datum.
+    genericRecord.put("timestampMillis", new DateTime(genericRecord.get("timestampMillis")));
+    row = AvroUtils.toBeamRowStrict(getGenericRecord(), null);
+    assertEquals(getBeamRow(), row);
+  }
+
+  @Test
+  public void testGenericRecordToRowFunction() {
+    SerializableUtils.ensureSerializable(AvroUtils.getGenericRecordToRowFunction(Schema.of()));
+    SerializableUtils.ensureSerializable(AvroUtils.getGenericRecordToRowFunction(null));
+  }
+
+  @Test
+  public void testAvroSchemaCoders() {
+    Pipeline pipeline = Pipeline.create();
+    org.apache.avro.Schema schema =
+        org.apache.avro.Schema.createRecord(
+            "TestSubRecord",
+            "TestSubRecord doc",
+            "org.apache.beam.sdk.schemas.utils",
+            false,
+            getAvroSubSchemaFields());
+    GenericRecord record =
+        new GenericRecordBuilder(getAvroSubSchema("simple"))
+            .set("bool", true)
+            .set("int", 42)
+            .build();
+
+    PCollection<GenericRecord> records =
+        pipeline.apply(Create.of(record).withCoder(AvroCoder.of(schema)));
+    assertFalse(records.hasSchema());
+    records.setCoder(AvroUtils.schemaCoder(schema));
+    assertTrue(records.hasSchema());
+    CoderProperties.coderSerializable(records.getCoder());
+
+    AvroGeneratedUser user = new AvroGeneratedUser("foo", 42, "green");
+    PCollection<AvroGeneratedUser> users =
+        pipeline.apply(Create.of(user).withCoder(AvroCoder.of(AvroGeneratedUser.class)));
+    assertFalse(users.hasSchema());
+    users.setCoder(AvroUtils.schemaCoder((AvroCoder<AvroGeneratedUser>) users.getCoder()));
+    assertTrue(users.hasSchema());
+    CoderProperties.coderSerializable(users.getCoder());
+  }
+
+  @Test
+  public void testAvroBytesToRowAndRowToAvroBytesFunctions() {
+    Schema schema =
+        Schema.builder()
+            .addInt32Field("f_int")
+            .addInt64Field("f_long")
+            .addDoubleField("f_double")
+            .addStringField("f_string")
+            .build();
+
+    SimpleFunction<Row, byte[]> toBytesFn = AvroUtils.getRowToAvroBytesFunction(schema);
+    SimpleFunction<byte[], Row> toRowFn = AvroUtils.getAvroBytesToRowFunction(schema);
+
+    Row row = Row.withSchema(schema).attachValues(1, 1L, 1d, "string");
+
+    byte[] serializedRow = toBytesFn.apply(row);
+    Row deserializedRow = toRowFn.apply(serializedRow);
+
+    assertEquals(row, deserializedRow);
+  }
+
+  @Test
+  public void testNullSchemas() {
+    assertEquals(
+        AvroUtils.getFromRowFunction(GenericRecord.class),
+        AvroUtils.getFromRowFunction(GenericRecord.class));
+  }
+
+  /** Helper class that simulates JDBC logical types. */
+  private static class JdbcType<T> implements Schema.LogicalType<T, T> {
+
+    private static final JdbcType<Instant> DATE =
+        new JdbcType<>(JDBCType.DATE, FieldType.STRING, FieldType.DATETIME, "");
+    private static final JdbcType<Instant> TIME =
+        new JdbcType<>(JDBCType.TIME, FieldType.STRING, FieldType.DATETIME, "");
+
+    private final String identifier;
+    private final FieldType argumentType;
+    private final FieldType baseType;
+    private final Object argument;
+
+    private static class StringType extends JdbcType<String> {
+
+      private static StringType fixedLengthChar(int size) {
+        return new StringType(JDBCType.CHAR, size);
+      }
+
+      private static StringType varchar(int size) {
+        return new StringType(JDBCType.VARCHAR, size);
+      }
+
+      private static StringType longvarchar(int size) {
+        return new StringType(JDBCType.LONGVARCHAR, size);
+      }
+
+      private static StringType nvarchar(int size) {
+        return new StringType(JDBCType.NVARCHAR, size);
+      }
+
+      private static StringType longnvarchar(int size) {
+        return new StringType(JDBCType.LONGNVARCHAR, size);
+      }
+
+      private StringType(JDBCType type, int size) {
+        super(type, FieldType.INT32, FieldType.STRING, size);
+      }
+    }
+
+    private JdbcType(
+        JDBCType jdbcType, FieldType argumentType, FieldType baseType, Object argument) {
+      this.identifier = jdbcType.getName();
+      this.argumentType = argumentType;
+      this.baseType = baseType;
+      this.argument = argument;
+    }
+
+    @Override
+    public String getIdentifier() {
+      return identifier;
+    }
+
+    @Override
+    public @Nullable FieldType getArgumentType() {
+      return argumentType;
+    }
+
+    @Override
+    public FieldType getBaseType() {
+      return baseType;
+    }
+
+    @Override
+    @SuppressWarnings("TypeParameterUnusedInFormals")
+    public <T1> @Nullable T1 getArgument() {
+      return (T1) argument;
+    }
+
+    @Override
+    public @NonNull T toBaseType(@NonNull T input) {
+      return input;
+    }
+
+    @Override
+    public @NonNull T toInputType(@NonNull T base) {
+      return base;
+    }
+  }
+}
diff --git a/settings.gradle.kts b/settings.gradle.kts
index a9f4de327ca3..8f578795bc93 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -129,6 +129,7 @@ include(":sdks:java:core:jmh")
 include(":sdks:java:expansion-service")
 include(":sdks:java:expansion-service:app")
 include(":sdks:java:extensions:arrow")
+include(":sdks:java:extensions:avro")
 include(":sdks:java:extensions:euphoria")
 include(":sdks:java:extensions:kryo")
 include(":sdks:java:extensions:google-cloud-platform-core")