diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index abd9bc46bd46..bdd30ba47699 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -46,13 +46,13 @@ public static class AbstractGetterTypeSupplier implements FieldValueTypeSupplier public static final AbstractGetterTypeSupplier INSTANCE = new AbstractGetterTypeSupplier(); @Override - public List get(Class clazz) { + public List get(TypeDescriptor typeDescriptor) { // If the generated class is passed in, we want to look at the base class to find the getters. - Class targetClass = AutoValueUtils.getBaseAutoValueClass(clazz); + TypeDescriptor targetTypeDescriptor = AutoValueUtils.getBaseAutoValueClass(typeDescriptor); List methods = - ReflectUtils.getMethods(targetClass).stream() + ReflectUtils.getMethods(targetTypeDescriptor.getRawType()).stream() .filter(ReflectUtils::isGetter) // All AutoValue getters are marked abstract. .filter(m -> Modifier.isAbstract(m.getModifiers())) @@ -62,7 +62,7 @@ public List get(Class clazz) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -89,9 +89,10 @@ private static void validateFieldNumbers(List types) } @Override - public List fieldValueGetters(Class targetClass, Schema schema) { + public List fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( - targetClass, + targetTypeDescriptor, schema, AbstractGetterTypeSupplier.INSTANCE, new DefaultTypeConversionsFactory()); @@ -99,17 +100,19 @@ public List fieldValueGetters(Class targetClass, Schema sch @Override public List fieldValueTypeInformations( - Class targetClass, Schema schema) { - return JavaBeanUtils.getFieldTypes(targetClass, schema, AbstractGetterTypeSupplier.INSTANCE); + TypeDescriptor targetTypeDescriptor, Schema schema) { + return JavaBeanUtils.getFieldTypes( + targetTypeDescriptor, schema, AbstractGetterTypeSupplier.INSTANCE); } @Override - public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { + public SchemaUserTypeCreator schemaTypeCreator( + TypeDescriptor targetTypeDescriptor, Schema schema) { // If a static method is marked with @SchemaCreate, use that. - Method annotated = ReflectUtils.getAnnotatedCreateMethod(targetClass); + Method annotated = ReflectUtils.getAnnotatedCreateMethod(targetTypeDescriptor.getRawType()); if (annotated != null) { return JavaBeanUtils.getStaticCreator( - targetClass, + targetTypeDescriptor, annotated, schema, AbstractGetterTypeSupplier.INSTANCE, @@ -119,7 +122,8 @@ public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema sche // Try to find a generated builder class. If one exists, use that to generate a // SchemaTypeCreator for creating AutoValue objects. SchemaUserTypeCreator creatorFactory = - AutoValueUtils.getBuilderCreator(targetClass, schema, AbstractGetterTypeSupplier.INSTANCE); + AutoValueUtils.getBuilderCreator( + targetTypeDescriptor, schema, AbstractGetterTypeSupplier.INSTANCE); if (creatorFactory != null) { return creatorFactory; } @@ -128,9 +132,10 @@ public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema sche // class. Use that for creating AutoValue objects. creatorFactory = AutoValueUtils.getConstructorCreator( - targetClass, schema, AbstractGetterTypeSupplier.INSTANCE); + targetTypeDescriptor, schema, AbstractGetterTypeSupplier.INSTANCE); if (creatorFactory == null) { - throw new RuntimeException("Could not find a way to create AutoValue class " + targetClass); + throw new RuntimeException( + "Could not find a way to create AutoValue class " + targetTypeDescriptor); } return creatorFactory; @@ -139,6 +144,6 @@ public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema sche @Override public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { return JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor.getRawType(), AbstractGetterTypeSupplier.INSTANCE); + typeDescriptor, AbstractGetterTypeSupplier.INSTANCE); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java index 2c140bd1dfef..8725833bc1da 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java @@ -19,6 +19,7 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.sdk.values.TypeDescriptor; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -36,7 +37,7 @@ "rawtypes" }) public class CachingFactory implements Factory { - private transient @Nullable ConcurrentHashMap cache = null; + private transient @Nullable ConcurrentHashMap, CreatedT> cache = null; private final Factory innerFactory; @@ -45,16 +46,16 @@ public CachingFactory(Factory innerFactory) { } @Override - public CreatedT create(Class clazz, Schema schema) { + public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { if (cache == null) { cache = new ConcurrentHashMap<>(); } - CreatedT cached = cache.get(clazz); + CreatedT cached = cache.get(typeDescriptor); if (cached != null) { return cached; } - cached = innerFactory.create(clazz, schema); - cache.put(clazz, cached); + cached = innerFactory.create(typeDescriptor, schema); + cache.put(typeDescriptor, cached); return cached; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Factory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Factory.java index f9da36b97c77..f302f20cfb64 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Factory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Factory.java @@ -19,9 +19,10 @@ import java.io.Serializable; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.values.TypeDescriptor; /** A Factory interface for schema-related objects for a specific Java type. */ @Internal public interface Factory extends Serializable { - T create(Class clazz, Schema schema); + T create(TypeDescriptor typeDescriptor, Schema schema); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 750709192c08..3a0ebf13aff7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -28,6 +28,7 @@ import java.util.Collections; import java.util.Map; import java.util.stream.Stream; +import org.apache.beam.sdk.schemas.AutoValue_FieldValueTypeInformation.Builder; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -78,7 +79,7 @@ public abstract class FieldValueTypeInformation implements Serializable { /** If the field has a description, returns the description for the field. */ public abstract @Nullable String getDescription(); - abstract Builder toBuilder(); + public abstract Builder toBuilder(); @AutoValue.Builder public abstract static class Builder { @@ -106,7 +107,7 @@ public abstract static class Builder { public abstract Builder setDescription(@Nullable String fieldDescription); - abstract FieldValueTypeInformation build(); + public abstract FieldValueTypeInformation build(); } public static FieldValueTypeInformation forOneOf( @@ -125,8 +126,9 @@ public static FieldValueTypeInformation forOneOf( .build(); } - public static FieldValueTypeInformation forField(Field field, int index) { - TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); + public static FieldValueTypeInformation forField( + TypeDescriptor typeDescriptor, Field field, int index) { + TypeDescriptor type = typeDescriptor.resolveType(field.getGenericType()); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -134,9 +136,9 @@ public static FieldValueTypeInformation forField(Field field, int index) { .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field)) - .setMapKeyType(getMapKeyType(field)) - .setMapValueType(getMapValueType(field)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -184,7 +186,8 @@ public static String getNameOverride( return fieldDescription.value(); } - public static FieldValueTypeInformation forGetter(Method method, int index) { + public static FieldValueTypeInformation forGetter( + TypeDescriptor typeDescriptor, Method method, int index) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -193,8 +196,7 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { } else { throw new RuntimeException("Getter has wrong prefix " + method.getName()); } - - TypeDescriptor type = TypeDescriptor.of(method.getGenericReturnType()); + TypeDescriptor type = typeDescriptor.resolveType(method.getGenericReturnType()); boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -252,11 +254,12 @@ private static boolean isNullableAnnotation(Annotation annotation) { return annotation.annotationType().getSimpleName().equals("Nullable"); } - public static FieldValueTypeInformation forSetter(Method method) { - return forSetter(method, "set"); + public static FieldValueTypeInformation forSetter(TypeDescriptor typeDescriptor, Method method) { + return forSetter(typeDescriptor, method, "set"); } - public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { + public static FieldValueTypeInformation forSetter( + TypeDescriptor typeDescriptor, Method method, String setterPrefix) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -264,7 +267,7 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr throw new RuntimeException("Setter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericParameterTypes()[0]); + TypeDescriptor type = typeDescriptor.resolveType(method.getGenericParameterTypes()[0]); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -279,15 +282,27 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr .build(); } + public static FieldValueTypeInformation.Builder builder() { + return new AutoValue_FieldValueTypeInformation.Builder(); + } + public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation getIterableComponentType(Field field) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType())); + public FieldValueTypeInformation withTypesFrom(FieldValueTypeInformation other) { + return toBuilder() + .setType(other.getType()) + .setRawType(other.getRawType()) + .setElementType(other.getElementType()) + .setMapKeyType(other.getMapKeyType()) + .setMapValueType(other.getMapValueType()) + .setOneOfTypes(other.getOneOfTypes()) + .build(); } - static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { + public static @Nullable FieldValueTypeInformation getIterableComponentType( + TypeDescriptor valueType) { // TODO: Figure out nullable elements. TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); if (componentType == null) { @@ -308,22 +323,14 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { // If the Field is a map type, returns the key type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType())); - } - - private static @Nullable FieldValueTypeInformation getMapKeyType( + public static @Nullable FieldValueTypeInformation getMapKeyType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 0); } // If the Field is a map type, returns the value type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapValueType(Field field) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1); - } - - private static @Nullable FieldValueTypeInformation getMapValueType( + public static @Nullable FieldValueTypeInformation getMapValueType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 1); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java index 53c098599c36..c0be0910a264 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.RowWithGetters; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Function; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; @@ -47,23 +48,28 @@ "rawtypes" }) class FromRowUsingCreator implements SerializableFunction, Function { - private final Class clazz; + private final TypeDescriptor typeDescriptor; private final GetterBasedSchemaProvider schemaProvider; private final Factory schemaTypeCreatorFactory; @SuppressFBWarnings("SE_TRANSIENT_FIELD_NOT_RESTORED") private transient @MonotonicNonNull Function[] fieldConverters; - public FromRowUsingCreator(Class clazz, GetterBasedSchemaProvider schemaProvider) { - this(clazz, schemaProvider, new CachingFactory<>(schemaProvider::schemaTypeCreator), null); + public FromRowUsingCreator( + TypeDescriptor typeDescriptor, GetterBasedSchemaProvider schemaProvider) { + this( + typeDescriptor, + schemaProvider, + new CachingFactory<>(schemaProvider::schemaTypeCreator), + null); } private FromRowUsingCreator( - Class clazz, + TypeDescriptor typeDescriptor, GetterBasedSchemaProvider schemaProvider, Factory schemaTypeCreatorFactory, @Nullable Function[] fieldConverters) { - this.clazz = clazz; + this.typeDescriptor = typeDescriptor; this.schemaProvider = schemaProvider; this.schemaTypeCreatorFactory = schemaTypeCreatorFactory; this.fieldConverters = fieldConverters; @@ -76,10 +82,10 @@ public T apply(Row row) { return null; } if (row instanceof RowWithGetters) { - Object target = ((RowWithGetters) row).getGetterTarget(); - if (target.getClass().equals(clazz)) { + RowWithGetters rowWithGetters = (RowWithGetters) row; + if (rowWithGetters.getGetterTargetType().equals(typeDescriptor)) { // Efficient path: simply extract the underlying object instead of creating a new one. - return (T) target; + return (T) rowWithGetters.getGetterTarget(); } } if (fieldConverters == null) { @@ -91,7 +97,8 @@ public T apply(Row row) { for (int i = 0; i < row.getFieldCount(); ++i) { params[i] = fieldConverters[i].apply(row.getValue(i)); } - SchemaUserTypeCreator creator = schemaTypeCreatorFactory.create(clazz, row.getSchema()); + SchemaUserTypeCreator creator = + schemaTypeCreatorFactory.create(typeDescriptor, row.getSchema()); return (T) creator.create(params); } @@ -99,13 +106,15 @@ private synchronized void initFieldConverters(Schema schema) { if (fieldConverters == null) { CachingFactory> typeFactory = new CachingFactory<>(schemaProvider::fieldValueTypeInformations); - fieldConverters = fieldConverters(clazz, schema, typeFactory); + fieldConverters = fieldConverters(typeDescriptor, schema, typeFactory); } } private Function[] fieldConverters( - Class clazz, Schema schema, Factory> typeFactory) { - List typeInfos = typeFactory.create(clazz, schema); + TypeDescriptor typeDescriptor, + Schema schema, + Factory> typeFactory) { + List typeInfos = typeFactory.create(typeDescriptor, schema); checkState( typeInfos.size() == schema.getFieldCount(), "Did not have a matching number of type informations and fields."); @@ -133,10 +142,9 @@ private Function fieldConverter( if (!needsConversion(type)) { return FieldConverter.IDENTITY; } else if (TypeName.ROW.equals(type.getTypeName())) { - Function[] converters = - fieldConverters(typeInfo.getRawType(), type.getRowSchema(), typeFactory); + Function[] converters = fieldConverters(typeInfo.getType(), type.getRowSchema(), typeFactory); return new FromRowUsingCreator( - typeInfo.getRawType(), schemaProvider, schemaTypeCreatorFactory, converters); + typeInfo.getType(), schemaProvider, schemaTypeCreatorFactory, converters); } else if (TypeName.ARRAY.equals(type.getTypeName())) { return new ConvertCollection( fieldConverter(type.getCollectionElementType(), typeInfo.getElementType(), typeFactory)); @@ -271,11 +279,11 @@ public boolean equals(@Nullable Object o) { return false; } FromRowUsingCreator that = (FromRowUsingCreator) o; - return clazz.equals(that.clazz) && schemaProvider.equals(that.schemaProvider); + return typeDescriptor.equals(that.typeDescriptor) && schemaProvider.equals(that.schemaProvider); } @Override public int hashCode() { - return Objects.hash(clazz, schemaProvider); + return Objects.hash(typeDescriptor, schemaProvider); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index 2b697bebd815..db97029a5b00 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -20,15 +20,20 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; @@ -48,31 +53,37 @@ }) public abstract class GetterBasedSchemaProvider implements SchemaProvider { /** Implementing class should override to return FieldValueGetters. */ - public abstract List fieldValueGetters(Class targetClass, Schema schema); + public abstract List fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema); /** Implementing class should override to return a list of type-informations. */ public abstract List fieldValueTypeInformations( - Class targetClass, Schema schema); + TypeDescriptor targetTypeDescriptor, Schema schema); /** Implementing class should override to return a constructor. */ - public abstract SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema); + public abstract SchemaUserTypeCreator schemaTypeCreator( + TypeDescriptor targetTypeDescriptor, Schema schema); private class ToRowWithValueGetters implements SerializableFunction { private final Schema schema; + private final TypeDescriptor getterTargetType; private final Factory> getterFactory; - public ToRowWithValueGetters(Schema schema) { + public ToRowWithValueGetters(Schema schema, TypeDescriptor getterTargetType) { this.schema = schema; + this.getterTargetType = getterTargetType; // Since we know that this factory is always called from inside the lambda with the same // schema, return a caching factory that caches the first value seen for each class. This // prevents having to lookup the getter list each time createGetters is called. this.getterFactory = - RowValueGettersFactory.of(GetterBasedSchemaProvider.this::fieldValueGetters); + RowValueGettersFactory.of( + GetterBasedSchemaProvider.this::fieldValueGetters, + GetterBasedSchemaProvider.this::fieldValueTypeInformations); } @Override public Row apply(T input) { - return Row.withSchema(schema).withFieldValueGetters(getterFactory, input); + return Row.withSchema(schema).withFieldValueGetters(getterTargetType, getterFactory, input); } private GetterBasedSchemaProvider getOuter() { @@ -107,14 +118,13 @@ public SerializableFunction toRowFunction(TypeDescriptor typeDesc // workers would see different versions of the schema. Schema schema = schemaFor(typeDescriptor); - return new ToRowWithValueGetters<>(schema); + return new ToRowWithValueGetters<>(schema, typeDescriptor); } @Override @SuppressWarnings("unchecked") public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { - Class clazz = (Class) typeDescriptor.getType(); - return new FromRowUsingCreator<>(clazz, this); + return new FromRowUsingCreator<>(typeDescriptor, this); } @Override @@ -130,22 +140,35 @@ public boolean equals(@Nullable Object obj) { private static class RowValueGettersFactory implements Factory> { private final Factory> gettersFactory; private final Factory> cachingGettersFactory; + private final Factory> typeInfoFactory; - static Factory> of(Factory> gettersFactory) { - return new RowValueGettersFactory(gettersFactory).cachingGettersFactory; + static Factory> of( + Factory> gettersFactory, + Factory> typeInfoFactory) { + return new RowValueGettersFactory(gettersFactory, typeInfoFactory).cachingGettersFactory; } - RowValueGettersFactory(Factory> gettersFactory) { + RowValueGettersFactory( + Factory> gettersFactory, + Factory> typeInfoFactory) { this.gettersFactory = gettersFactory; this.cachingGettersFactory = new CachingFactory<>(this); + this.typeInfoFactory = typeInfoFactory; } @Override - public List create(Class clazz, Schema schema) { - List getters = gettersFactory.create(clazz, schema); + public List create(TypeDescriptor typeDescriptor, Schema schema) { + List getters = gettersFactory.create(typeDescriptor, schema); + Map typeInfoByName = + typeInfoFactory.create(typeDescriptor, schema).stream() + .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); List rowGetters = new ArrayList<>(getters.size()); for (int i = 0; i < getters.size(); i++) { - rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType())); + rowGetters.add( + rowValueGetter( + getters.get(i), + schema.getField(i).getType(), + typeInfoByName.get(getters.get(i).name()).getType())); } return rowGetters; } @@ -161,22 +184,48 @@ && needsConversion(type.getCollectionElementType())) || needsConversion(type.getMapValueType()))); } - FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { + FieldValueGetter rowValueGetter( + FieldValueGetter base, FieldType type, @Nullable TypeDescriptor getterReturnType) { TypeName typeName = type.getTypeName(); + if (!needsConversion(type)) { return base; } if (typeName.equals(TypeName.ROW)) { - return new GetRow(base, type.getRowSchema(), cachingGettersFactory); + return new GetRow(base, getterReturnType, type.getRowSchema(), cachingGettersFactory); } else if (typeName.equals(TypeName.ARRAY)) { FieldType elementType = type.getCollectionElementType(); + TypeDescriptor elementTypeDescriptor = + Optional.ofNullable(getterReturnType) + .map(getterType -> ReflectUtils.getIterableComponentType(getterType)) + .orElse(null); return elementType.getTypeName().equals(TypeName.ROW) - ? new GetEagerCollection(base, converter(elementType)) - : new GetCollection(base, converter(elementType)); + ? new GetEagerCollection(base, converter(elementType, elementTypeDescriptor)) + : new GetCollection(base, converter(elementType, elementTypeDescriptor)); } else if (typeName.equals(TypeName.ITERABLE)) { - return new GetIterable(base, converter(type.getCollectionElementType())); + TypeDescriptor elementTypeDescriptor = + Optional.ofNullable(getterReturnType) + .map(getterType -> ReflectUtils.getIterableComponentType(getterType)) + .orElse(null); + return new GetIterable( + base, converter(type.getCollectionElementType(), elementTypeDescriptor)); } else if (typeName.equals(TypeName.MAP)) { - return new GetMap(base, converter(type.getMapKeyType()), converter(type.getMapValueType())); + TypeDescriptor[] resolvedKeyValueTypes = + Optional.ofNullable(getterReturnType) + .map( + getterType -> + Arrays.stream(Map.class.getTypeParameters()) + .map( + typeVar -> { + TypeDescriptor resolved = getterType.resolveType(typeVar); + return resolved.hasUnresolvedParameters() ? null : resolved; + }) + .toArray(TypeDescriptor[]::new)) + .orElse(new TypeDescriptor[] {null, null}); + return new GetMap( + base, + converter(type.getMapKeyType(), resolvedKeyValueTypes[0]), + converter(type.getMapValueType(), resolvedKeyValueTypes[1])); } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = type.getLogicalType(OneOfType.class); Schema oneOfSchema = oneOfType.getOneOfSchema(); @@ -185,7 +234,7 @@ FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { Map converters = Maps.newHashMapWithExpectedSize(values.size()); for (Map.Entry kv : values.entrySet()) { FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType(); - FieldValueGetter converter = converter(fieldType); + FieldValueGetter converter = converter(fieldType, null); converters.put(kv.getValue(), converter); } @@ -196,23 +245,33 @@ FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { return base; } - FieldValueGetter converter(FieldType type) { - return rowValueGetter(IDENTITY, type); + FieldValueGetter converter(FieldType type, @Nullable TypeDescriptor getterType) { + return rowValueGetter(IDENTITY, type, getterType); } static class GetRow extends Converter { final Schema schema; final Factory> factory; + @Nullable final TypeDescriptor valueType; - GetRow(FieldValueGetter getter, Schema schema, Factory> factory) { + GetRow( + FieldValueGetter getter, + @Nullable TypeDescriptor valueType, + Schema schema, + Factory> factory) { super(getter); this.schema = schema; this.factory = factory; + this.valueType = valueType; } @Override Object convert(Object value) { - return Row.withSchema(schema).withFieldValueGetters(factory, value); + return Row.withSchema(schema) + .withFieldValueGetters( + Optional.ofNullable(valueType).orElse(TypeDescriptor.of(value.getClass())), + factory, + value); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index 7024e8be86cf..acc588b201a8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -60,15 +60,15 @@ public static class GetterTypeSupplier implements FieldValueTypeSupplier { public static final GetterTypeSupplier INSTANCE = new GetterTypeSupplier(); @Override - public List get(Class clazz) { + public List get(TypeDescriptor typeDescriptor) { List methods = - ReflectUtils.getMethods(clazz).stream() + ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isGetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -110,29 +110,32 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier { private static final SetterTypeSupplier INSTANCE = new SetterTypeSupplier(); @Override - public List get(Class clazz) { - return ReflectUtils.getMethods(clazz).stream() + public List get(TypeDescriptor typeDescriptor) { + return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(FieldValueTypeInformation::forSetter) + .map(m -> FieldValueTypeInformation.forSetter(typeDescriptor, m)) .map( t -> { if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { throw new RuntimeException( String.format( - "@SchemaFieldNumber can only be used on getters in Java Beans. Found on setter '%s'", + "@SchemaFieldNumber can only be used on getters in Java Beans. Found on" + + " setter '%s'", t.getMethod().getName())); } if (t.getMethod().getAnnotation(SchemaFieldName.class) != null) { throw new RuntimeException( String.format( - "@SchemaFieldName can only be used on getters in Java Beans. Found on setter '%s'", + "@SchemaFieldName can only be used on getters in Java Beans. Found on" + + " setter '%s'", t.getMethod().getName())); } if (t.getMethod().getAnnotation(SchemaCaseFormat.class) != null) { throw new RuntimeException( String.format( - "@SchemaCaseFormat can only be used on getters in Java Beans. Found on setter '%s'", + "@SchemaCaseFormat can only be used on getters in Java Beans. Found on" + + " setter '%s'", t.getMethod().getName())); } return t; @@ -154,40 +157,44 @@ public boolean equals(@Nullable Object obj) { @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { Schema schema = - JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor.getRawType(), GetterTypeSupplier.INSTANCE); + JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE); // If there are no creator methods, then validate that we have setters for every field. // Otherwise, we will have no way of creating instances of the class. if (ReflectUtils.getAnnotatedCreateMethod(typeDescriptor.getRawType()) == null && ReflectUtils.getAnnotatedConstructor(typeDescriptor.getRawType()) == null) { JavaBeanUtils.validateJavaBean( - GetterTypeSupplier.INSTANCE.get(typeDescriptor.getRawType(), schema), - SetterTypeSupplier.INSTANCE.get(typeDescriptor.getRawType(), schema), + GetterTypeSupplier.INSTANCE.get(typeDescriptor, schema), + SetterTypeSupplier.INSTANCE.get(typeDescriptor, schema), schema); } return schema; } @Override - public List fieldValueGetters(Class targetClass, Schema schema) { + public List fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( - targetClass, schema, GetterTypeSupplier.INSTANCE, new DefaultTypeConversionsFactory()); + targetTypeDescriptor, + schema, + GetterTypeSupplier.INSTANCE, + new DefaultTypeConversionsFactory()); } @Override public List fieldValueTypeInformations( - Class targetClass, Schema schema) { - return JavaBeanUtils.getFieldTypes(targetClass, schema, GetterTypeSupplier.INSTANCE); + TypeDescriptor targetTypeDescriptor, Schema schema) { + return JavaBeanUtils.getFieldTypes(targetTypeDescriptor, schema, GetterTypeSupplier.INSTANCE); } @Override - public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { + public SchemaUserTypeCreator schemaTypeCreator( + TypeDescriptor targetTypeDescriptor, Schema schema) { // If a static method is marked with @SchemaCreate, use that. - Method annotated = ReflectUtils.getAnnotatedCreateMethod(targetClass); + Method annotated = ReflectUtils.getAnnotatedCreateMethod(targetTypeDescriptor.getRawType()); if (annotated != null) { return JavaBeanUtils.getStaticCreator( - targetClass, + targetTypeDescriptor, annotated, schema, GetterTypeSupplier.INSTANCE, @@ -195,10 +202,11 @@ public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema sche } // If a Constructor was tagged with @SchemaCreate, invoke that constructor. - Constructor constructor = ReflectUtils.getAnnotatedConstructor(targetClass); + Constructor constructor = + ReflectUtils.getAnnotatedConstructor(targetTypeDescriptor.getRawType()); if (constructor != null) { return JavaBeanUtils.getConstructorCreator( - targetClass, + targetTypeDescriptor, constructor, schema, GetterTypeSupplier.INSTANCE, @@ -208,15 +216,15 @@ public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema sche // Else try to make a setter-based creator Factory setterBasedFactory = new SetterBasedCreatorFactory(new JavaBeanSetterFactory()); - return setterBasedFactory.create(targetClass, schema); + return setterBasedFactory.create(targetTypeDescriptor, schema); } /** A factory for creating {@link FieldValueSetter} objects for a JavaBean object. */ private static class JavaBeanSetterFactory implements Factory> { @Override - public List create(Class targetClass, Schema schema) { + public List create(TypeDescriptor typeDescriptor, Schema schema) { return JavaBeanUtils.getSetters( - targetClass, schema, SetterTypeSupplier.INSTANCE, new DefaultTypeConversionsFactory()); + typeDescriptor, schema, SetterTypeSupplier.INSTANCE, new DefaultTypeConversionsFactory()); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index 16b96f1c7ae1..f6a198821c5e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -57,22 +57,22 @@ public static class JavaFieldTypeSupplier implements FieldValueTypeSupplier { public static final JavaFieldTypeSupplier INSTANCE = new JavaFieldTypeSupplier(); @Override - public List get(Class clazz) { + public List get(TypeDescriptor typeDescriptor) { List fields = - ReflectUtils.getFields(clazz).stream() + ReflectUtils.getFields(typeDescriptor.getRawType()).stream() .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(fields.size()); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i)); + types.add(FieldValueTypeInformation.forField(typeDescriptor, fields.get(i), i)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); // If there are no creators registered, then make sure none of the schema fields are final, // as we (currently) have no way of creating classes in this case. - if (ReflectUtils.getAnnotatedCreateMethod(clazz) == null - && ReflectUtils.getAnnotatedConstructor(clazz) == null) { + if (ReflectUtils.getAnnotatedCreateMethod(typeDescriptor.getRawType()) == null + && ReflectUtils.getAnnotatedConstructor(typeDescriptor.getRawType()) == null) { Optional finalField = types.stream() .map(FieldValueTypeInformation::getField) @@ -81,7 +81,7 @@ public List get(Class clazz) { if (finalField.isPresent()) { throw new IllegalArgumentException( "Class " - + clazz + + typeDescriptor + " has final fields and no " + "registered creator. Cannot use as schema, as we don't know how to create this " + "object automatically"); @@ -111,29 +111,33 @@ private static void validateFieldNumbers(List types) @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - return POJOUtils.schemaFromPojoClass( - typeDescriptor.getRawType(), JavaFieldTypeSupplier.INSTANCE); + return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE); } @Override - public List fieldValueGetters(Class targetClass, Schema schema) { + public List fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return POJOUtils.getGetters( - targetClass, schema, JavaFieldTypeSupplier.INSTANCE, new DefaultTypeConversionsFactory()); + targetTypeDescriptor, + schema, + JavaFieldTypeSupplier.INSTANCE, + new DefaultTypeConversionsFactory()); } @Override public List fieldValueTypeInformations( - Class targetClass, Schema schema) { - return POJOUtils.getFieldTypes(targetClass, schema, JavaFieldTypeSupplier.INSTANCE); + TypeDescriptor targetTypeDescriptor, Schema schema) { + return POJOUtils.getFieldTypes(targetTypeDescriptor, schema, JavaFieldTypeSupplier.INSTANCE); } @Override - public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { + public SchemaUserTypeCreator schemaTypeCreator( + TypeDescriptor targetTypeDescriptor, Schema schema) { // If a static method is marked with @SchemaCreate, use that. - Method annotated = ReflectUtils.getAnnotatedCreateMethod(targetClass); + Method annotated = ReflectUtils.getAnnotatedCreateMethod(targetTypeDescriptor.getRawType()); if (annotated != null) { return POJOUtils.getStaticCreator( - targetClass, + targetTypeDescriptor, annotated, schema, JavaFieldTypeSupplier.INSTANCE, @@ -141,10 +145,11 @@ public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema sche } // If a Constructor was tagged with @SchemaCreate, invoke that constructor. - Constructor constructor = ReflectUtils.getAnnotatedConstructor(targetClass); + Constructor constructor = + ReflectUtils.getAnnotatedConstructor(targetTypeDescriptor.getRawType()); if (constructor != null) { return POJOUtils.getConstructorCreator( - targetClass, + targetTypeDescriptor, constructor, schema, JavaFieldTypeSupplier.INSTANCE, @@ -152,6 +157,9 @@ public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema sche } return POJOUtils.getSetFieldCreator( - targetClass, schema, JavaFieldTypeSupplier.INSTANCE, new DefaultTypeConversionsFactory()); + targetTypeDescriptor, + schema, + JavaFieldTypeSupplier.INSTANCE, + new DefaultTypeConversionsFactory()); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SetterBasedCreatorFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SetterBasedCreatorFactory.java index 7663651ae7c9..e7ded3c52af5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SetterBasedCreatorFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SetterBasedCreatorFactory.java @@ -19,6 +19,7 @@ import java.lang.reflect.InvocationTargetException; import java.util.List; +import org.apache.beam.sdk.values.TypeDescriptor; /** * A {@link Factory} that uses a default constructor and a list of setters to construct a {@link @@ -35,14 +36,14 @@ public SetterBasedCreatorFactory(Factory> setterFactory) } @Override - public SchemaUserTypeCreator create(Class clazz, Schema schema) { - List setters = setterFactory.create(clazz, schema); + public SchemaUserTypeCreator create(TypeDescriptor typeDescriptor, Schema schema) { + List setters = setterFactory.create(typeDescriptor, schema); return new SchemaUserTypeCreator() { @Override public Object create(Object... params) { Object object; try { - object = clazz.getDeclaredConstructor().newInstance(); + object = typeDescriptor.getRawType().getDeclaredConstructor().newInstance(); } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java index ddebbeb2bffe..6f3e598f5314 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java @@ -101,7 +101,7 @@ public ProviderAndDescriptor( try { return new ProviderAndDescriptor( providerClass.getDeclaredConstructor().newInstance(), - TypeDescriptor.of(clazz)); + typeDescriptor.getSupertype((Class) clazz)); } catch (NoSuchMethodException | InstantiationException | IllegalAccessException diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java index 2ec0a9a60cd6..54e2a595fa71 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java @@ -299,7 +299,7 @@ private static FunctionAndType createFunctionFromName(String name, String path) private static class EmptyFieldValueTypeSupplier implements org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier { @Override - public List get(Class clazz) { + public List get(TypeDescriptor typeDescriptor) { return Collections.emptyList(); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index dcbbf70888d3..bbf831ff0cc5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -71,18 +71,19 @@ "rawtypes" }) public class AutoValueUtils { - public static Class getBaseAutoValueClass(Class clazz) { + public static TypeDescriptor getBaseAutoValueClass(TypeDescriptor typeDescriptor) { // AutoValue extensions may be nested - while (clazz != null && clazz.getName().contains("AutoValue_")) { - clazz = clazz.getSuperclass(); + while (typeDescriptor != null && typeDescriptor.getRawType().getName().contains("AutoValue_")) { + typeDescriptor = typeDescriptor.getSupertype(typeDescriptor.getRawType().getSuperclass()); } - return clazz; + return typeDescriptor; } - private static Class getAutoValueGenerated(Class clazz) { - String generatedClassName = getAutoValueGeneratedName(clazz.getName()); + private static TypeDescriptor getAutoValueGenerated( + TypeDescriptor typeDescriptor) { + String generatedClassName = getAutoValueGeneratedName(typeDescriptor.getRawType().getName()); try { - return Class.forName(generatedClassName); + return typeDescriptor.getSubtype((Class) Class.forName(generatedClassName)); } catch (ClassNotFoundException e) { throw new IllegalStateException("AutoValue generated class not found: " + generatedClassName); } @@ -121,19 +122,22 @@ private static String getAutoValueGeneratedName(String baseClass) { * Try to find an accessible constructor for creating an AutoValue class. Otherwise return null. */ public static @Nullable SchemaUserTypeCreator getConstructorCreator( - Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { - Class generatedClass = getAutoValueGenerated(clazz); - List schemaTypes = fieldValueTypeSupplier.get(clazz, schema); + TypeDescriptor typeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { + TypeDescriptor generatedTypeDescriptor = getAutoValueGenerated(typeDescriptor); + List schemaTypes = + fieldValueTypeSupplier.get(typeDescriptor, schema); Optional> constructor = - Arrays.stream(generatedClass.getDeclaredConstructors()) + Arrays.stream(generatedTypeDescriptor.getRawType().getDeclaredConstructors()) .filter(c -> !Modifier.isPrivate(c.getModifiers())) - .filter(c -> matchConstructor(c, schemaTypes)) + .filter(c -> matchConstructor(generatedTypeDescriptor, c, schemaTypes)) .findAny(); return constructor .map( c -> JavaBeanUtils.getConstructorCreator( - generatedClass, + generatedTypeDescriptor, c, schema, fieldValueTypeSupplier, @@ -142,7 +146,9 @@ private static String getAutoValueGeneratedName(String baseClass) { } private static boolean matchConstructor( - Constructor constructor, List getterTypes) { + TypeDescriptor typeDescriptor, + Constructor constructor, + List getterTypes) { if (constructor.getParameters().length != getterTypes.size()) { return false; } @@ -158,7 +164,8 @@ private static boolean matchConstructor( // Verify that constructor parameters match (name and type) the inferred schema. for (Parameter parameter : constructor.getParameters()) { FieldValueTypeInformation type = typeMap.get(parameter.getName()); - if (type == null || type.getRawType() != parameter.getType()) { + if (type == null + || !type.getType().equals(typeDescriptor.resolveType(parameter.getParameterizedType()))) { valid = false; break; } @@ -187,8 +194,10 @@ private static boolean matchConstructor( * Try to find an accessible builder class for creating an AutoValue class. Otherwise return null. */ public static @Nullable SchemaUserTypeCreator getBuilderCreator( - Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { - Class builderClass = getAutoValueGeneratedBuilder(clazz); + TypeDescriptor typeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { + Class builderClass = getAutoValueGeneratedBuilder(typeDescriptor.getRawType()); if (builderClass == null) { return null; } @@ -196,12 +205,13 @@ private static boolean matchConstructor( Map setterTypes = ReflectUtils.getMethods(builderClass).stream() .filter(ReflectUtils::isSetter) - .map(FieldValueTypeInformation::forSetter) + .map(m -> FieldValueTypeInformation.forSetter(typeDescriptor, m)) .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); List setterMethods = Lists.newArrayList(); // The builder methods to call in order. - List schemaTypes = fieldValueTypeSupplier.get(clazz, schema); + List schemaTypes = + fieldValueTypeSupplier.get(typeDescriptor, schema); for (FieldValueTypeInformation type : schemaTypes) { String autoValueFieldName = ReflectUtils.stripGetterPrefix(type.getMethod().getName()); @@ -214,6 +224,10 @@ private static boolean matchConstructor( + "a setter for " + autoValueFieldName); } + if (setterType.getType().hasUnresolvedParameters()) { + // copy the types from the getter in the builder's target class + setterType = setterType.withTypesFrom(type); + } setterMethods.add(setterType); } @@ -298,11 +312,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { TypeConversion convertType = typeConversionsFactory.createTypeConversion(true); for (int i = 0; i < setters.size(); ++i) { - Method setterMethod = checkNotNull(setters.get(i).getMethod()); - Parameter parameter = setterMethod.getParameters()[0]; + FieldValueTypeInformation setterType = setters.get(i); + Method setterMethod = checkNotNull(setterType.getMethod()); ForLoadedType convertedType = - new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(parameter.getParameterizedType()))); + new ForLoadedType((Class) convertType.convert(setterType.getType())); StackManipulation readParameter = new StackManipulation.Compound( @@ -317,7 +330,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Duplication.SINGLE, typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(parameter.getType())), + .convert(setterType.getType()), MethodInvocation.invoke(new ForLoadedMethod(setterMethod)), Removal.SINGLE); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index c2b33c2d2315..81293d9afb79 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -153,7 +153,7 @@ protected String name(TypeDescription superClass) { private static boolean overridePackage(@Nullable String targetPackage) { return targetPackage != null && !targetPackage.startsWith("java."); } - }; + } static class IfNullElse implements StackManipulation { private final StackManipulation readValue; @@ -361,8 +361,8 @@ protected Type convertIterable(TypeDescriptor type) { } @Override - protected Type convertMap(TypeDescriptor type) { - return Map.class; + protected Type convertMap(TypeDescriptor type) { + return returnRawTypes ? Map.class : type.getSupertype(Map.class).getType(); } @Override @@ -395,11 +395,19 @@ protected Type convertDefault(TypeDescriptor type) { return returnRawTypes ? type.getRawType() : type.getType(); } + public static TypeDescriptor primitiveToWrapper(TypeDescriptor typeDescriptor) { + Class cls = typeDescriptor.getRawType(); + if (cls.isPrimitive()) { + return TypeDescriptor.of(ClassUtils.primitiveToWrapper(cls)); + } else { + return typeDescriptor; + } + } + @SuppressWarnings("unchecked") private TypeDescriptor> createCollectionType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = primitiveToWrapper(componentType); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -407,8 +415,7 @@ private TypeDescriptor> createCollectionType( @SuppressWarnings("unchecked") private TypeDescriptor> createIterableType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = primitiveToWrapper(componentType); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -1472,10 +1479,9 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Push all creator parameters on the stack. TypeConversion convertType = typeConversionsFactory.createTypeConversion(true); for (int i = 0; i < parameters.size(); i++) { - Parameter parameter = parameters.get(i); + FieldValueTypeInformation fieldType = fields.get(fieldMapping.get(i)); ForLoadedType convertedType = - new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(parameter.getType()))); + new ForLoadedType((Class) convertType.convert(fieldType.getType())); // The instruction to read the parameter. Use the fieldMapping to reorder parameters as // necessary. @@ -1490,7 +1496,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { stackManipulation, typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(parameter.getParameterizedType()))); + .convert(fieldType.getType())); } stackManipulation = new StackManipulation.Compound( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/FieldValueTypeSupplier.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/FieldValueTypeSupplier.java index d93456b21949..693997f64aa0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/FieldValueTypeSupplier.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/FieldValueTypeSupplier.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.values.TypeDescriptor; /** * A naming policy for schema fields. This maps a name from the class (field name or getter name) to @@ -28,7 +29,7 @@ */ public interface FieldValueTypeSupplier extends Serializable { /** Return all the FieldValueTypeInformations. */ - List get(Class clazz); + List get(TypeDescriptor typeDescriptor); /** * Return all the FieldValueTypeInformations. @@ -36,7 +37,7 @@ public interface FieldValueTypeSupplier extends Serializable { *

If the schema parameter is not null, then the returned list must be in the same order as * fields in the schema. */ - default List get(Class clazz, Schema schema) { - return StaticSchemaInference.sortBySchema(get(clazz), schema); + default List get(TypeDescriptor typeDescriptor, Schema schema) { + return StaticSchemaInference.sortBySchema(get(typeDescriptor), schema); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 6573f25c66e2..3f9c765872a7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -29,6 +29,7 @@ import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; import net.bytebuddy.description.method.MethodDescription.ForLoadedMethod; +import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.dynamic.DynamicType; import net.bytebuddy.dynamic.scaffold.InstrumentedType; import net.bytebuddy.implementation.FixedValue; @@ -37,6 +38,7 @@ import net.bytebuddy.implementation.bytecode.ByteCodeAppender.Size; import net.bytebuddy.implementation.bytecode.Removal; import net.bytebuddy.implementation.bytecode.StackManipulation; +import net.bytebuddy.implementation.bytecode.assign.TypeCasting; import net.bytebuddy.implementation.bytecode.member.MethodInvocation; import net.bytebuddy.implementation.bytecode.member.MethodReturn; import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; @@ -51,8 +53,9 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.InjectPackageStrategy; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.StaticFactoryMethodInstruction; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; -import org.apache.beam.sdk.schemas.utils.ReflectUtils.ClassWithSchema; +import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; /** A set of utilities to generate getter and setter classes for JavaBean objects. */ @@ -62,13 +65,20 @@ }) public class JavaBeanUtils { /** Create a {@link Schema} for a Java Bean class. */ + public static Schema schemaFromJavaBeanClass( + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); + } + public static Schema schemaFromJavaBeanClass( Class clazz, FieldValueTypeSupplier fieldValueTypeSupplier) { - return StaticSchemaInference.schemaFromClass(clazz, fieldValueTypeSupplier); + return schemaFromJavaBeanClass(TypeDescriptor.of(clazz), fieldValueTypeSupplier); } private static final String CONSTRUCTOR_HELP_STRING = - "In order to infer a Schema from a Java Bean, it must have a constructor annotated with @SchemaCreate, or it must have a compatible setter for every getter used as a Schema field."; + "In order to infer a Schema from a Java Bean, it must have a constructor annotated with" + + " @SchemaCreate, or it must have a compatible setter for every getter used as a Schema" + + " field."; // Make sure that there are matching setters and getters. public static void validateJavaBean( @@ -111,19 +121,31 @@ public static void validateJavaBean( // Static ByteBuddy instance used by all helpers. private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); - private static final Map> CACHED_FIELD_TYPES = - Maps.newConcurrentMap(); + private static final Map> + CACHED_FIELD_TYPES = Maps.newConcurrentMap(); public static List getFieldTypes( - Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { + TypeDescriptor typeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { return CACHED_FIELD_TYPES.computeIfAbsent( - ClassWithSchema.create(clazz, schema), c -> fieldValueTypeSupplier.get(clazz, schema)); + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> fieldValueTypeSupplier.get(typeDescriptor, schema)); } // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map> + CACHED_GETTERS = Maps.newConcurrentMap(); + + public static List getGetters( + Class clazz, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier, + TypeConversionsFactory typeConversionsFactory) { + return getGetters( + TypeDescriptor.of(clazz), schema, fieldValueTypeSupplier, typeConversionsFactory); + } /** * Return the list of {@link FieldValueGetter}s for a Java Bean class @@ -131,21 +153,22 @@ public static List getFieldTypes( *

The returned list is ordered by the order of fields in the schema. */ public static List getGetters( - Class clazz, + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_GETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); return types.stream() .map(t -> createGetter(t, typeConversionsFactory)) .collect(Collectors.toList()); }); } - public static FieldValueGetter createGetter( + public static FieldValueGetter createGetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { DynamicType.Builder builder = ByteBuddyUtils.subclassGetterInterface( @@ -186,8 +209,17 @@ private static DynamicType.Builder implementGetterMethods( // The list of setters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_SETTERS = - Maps.newConcurrentMap(); + private static final Map> + CACHED_SETTERS = Maps.newConcurrentMap(); + + public static List getSetters( + Class clazz, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier, + TypeConversionsFactory typeConversionsFactory) { + return getSetters( + TypeDescriptor.of(clazz), schema, fieldValueTypeSupplier, typeConversionsFactory); + } /** * Return the list of {@link FieldValueSetter}s for a Java Bean class @@ -195,14 +227,15 @@ private static DynamicType.Builder implementGetterMethods( *

The returned list is ordered by the order of fields in the schema. */ public static List getSetters( - Class clazz, + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_SETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); return types.stream() .map(t -> createSetter(t, typeConversionsFactory)) .collect(Collectors.toList()); @@ -250,21 +283,22 @@ private static DynamicType.Builder implementSetterMethods( // The list of constructors for a class is cached, so we only create the classes the first time // getConstructor is called. - public static final Map CACHED_CREATORS = - Maps.newConcurrentMap(); + public static final Map + CACHED_CREATORS = Maps.newConcurrentMap(); public static SchemaUserTypeCreator getConstructorCreator( - Class clazz, + TypeDescriptor typeDescriptor, Constructor constructor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + ReflectUtils.TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); return createConstructorCreator( - clazz, constructor, schema, types, typeConversionsFactory); + typeDescriptor.getRawType(), constructor, schema, types, typeConversionsFactory); }); } @@ -302,16 +336,18 @@ public static SchemaUserTypeCreator createConstructorCreator( } public static SchemaUserTypeCreator getStaticCreator( - Class clazz, + TypeDescriptor typeDescriptor, Method creator, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + ReflectUtils.TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); - return createStaticCreator(clazz, creator, schema, types, typeConversionsFactory); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return createStaticCreator( + typeDescriptor.getRawType(), creator, schema, types, typeConversionsFactory); }); } @@ -371,13 +407,21 @@ public ByteCodeAppender appender(final Target implementationTarget) { // this + method parameters. int numLocals = 1 + instrumentedMethod.getParameters().size(); + StackManipulation cast = + typeInformation + .getRawType() + .isAssignableFrom(typeInformation.getMethod().getReturnType()) + ? StackManipulation.Trivial.INSTANCE + : TypeCasting.to(TypeDescription.ForLoadedType.of(typeInformation.getRawType())); + // StackManipulation that will read the value from the class field. StackManipulation readValue = new StackManipulation.Compound( // Method param is offset 1 (offset 0 is the this parameter). MethodVariableAccess.REFERENCE.loadFrom(1), // Invoke the getter - MethodInvocation.invoke(new ForLoadedMethod(typeInformation.getMethod()))); + MethodInvocation.invoke(new ForLoadedMethod(typeInformation.getMethod())), + cast); StackManipulation stackManipulation = new StackManipulation.Compound( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 93875a20707f..2d49a68d82a8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -30,6 +30,7 @@ import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; import net.bytebuddy.description.field.FieldDescription.ForLoadedField; +import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.description.type.TypeDescription.ForLoadedType; import net.bytebuddy.dynamic.DynamicType; import net.bytebuddy.dynamic.scaffold.InstrumentedType; @@ -59,7 +60,7 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.StaticFactoryMethodInstruction; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversion; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; -import org.apache.beam.sdk.schemas.utils.ReflectUtils.ClassWithSchema; +import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; @@ -71,26 +72,34 @@ "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) }) public class POJOUtils { + public static Schema schemaFromPojoClass( + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); + } + public static Schema schemaFromPojoClass( Class clazz, FieldValueTypeSupplier fieldValueTypeSupplier) { - return StaticSchemaInference.schemaFromClass(clazz, fieldValueTypeSupplier); + return schemaFromPojoClass(TypeDescriptor.of(clazz), fieldValueTypeSupplier); } // Static ByteBuddy instance used by all helpers. private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); - private static final Map> CACHED_FIELD_TYPES = - Maps.newConcurrentMap(); + private static final Map> + CACHED_FIELD_TYPES = Maps.newConcurrentMap(); public static List getFieldTypes( - Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { + TypeDescriptor typeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { return CACHED_FIELD_TYPES.computeIfAbsent( - ClassWithSchema.create(clazz, schema), c -> fieldValueTypeSupplier.get(clazz, schema)); + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> fieldValueTypeSupplier.get(typeDescriptor, schema)); } // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = + private static final Map> CACHED_GETTERS = Maps.newConcurrentMap(); public static List getGetters( @@ -98,18 +107,31 @@ public static List getGetters( Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { + return getGetters( + TypeDescriptor.of(clazz), schema, fieldValueTypeSupplier, typeConversionsFactory); + } + + public static List getGetters( + TypeDescriptor typeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier, + TypeConversionsFactory typeConversionsFactory) { // Return the getters ordered by their position in the schema. return CACHED_GETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); List getters = types.stream() .map(t -> createGetter(t, typeConversionsFactory)) .collect(Collectors.toList()); if (getters.size() != schema.getFieldCount()) { throw new RuntimeException( - "Was not able to generate getters for schema: " + schema + " class: " + clazz); + "Was not able to generate getters for schema: " + + schema + + " class: " + + typeDescriptor); } return getters; }); @@ -117,24 +139,25 @@ public static List getGetters( // The list of constructors for a class is cached, so we only create the classes the first time // getConstructor is called. - public static final Map CACHED_CREATORS = + public static final Map CACHED_CREATORS = Maps.newConcurrentMap(); public static SchemaUserTypeCreator getSetFieldCreator( - Class clazz, + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); - return createSetFieldCreator(clazz, schema, types, typeConversionsFactory); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return createSetFieldCreator(typeDescriptor, schema, types, typeConversionsFactory); }); } private static SchemaUserTypeCreator createSetFieldCreator( - Class clazz, + TypeDescriptor typeDescriptor, Schema schema, List types, TypeConversionsFactory typeConversionsFactory) { @@ -144,17 +167,19 @@ private static SchemaUserTypeCreator createSetFieldCreator( try { DynamicType.Builder builder = BYTE_BUDDY - .with(new InjectPackageStrategy(clazz)) + .with(new InjectPackageStrategy(typeDescriptor.getRawType())) .subclass(SchemaUserTypeCreator.class) .method(ElementMatchers.named("create")) - .intercept(new SetFieldCreateInstruction(fields, clazz, typeConversionsFactory)); + .intercept( + new SetFieldCreateInstruction( + fields, typeDescriptor.getRawType(), typeConversionsFactory)); return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader(clazz.getClassLoader()), - getClassLoadingStrategy(clazz)) + ReflectHelpers.findClassLoader(typeDescriptor.getRawType().getClassLoader()), + getClassLoadingStrategy(typeDescriptor.getRawType())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -165,23 +190,26 @@ private static SchemaUserTypeCreator createSetFieldCreator( | InvocationTargetException e) { throw new RuntimeException( String.format( - "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must have a zero-argument constructor, or a constructor annotated with @SchemaCreate.", - clazz, schema)); + "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must" + + " have a zero-argument constructor, or a constructor annotated with" + + " @SchemaCreate.", + typeDescriptor, schema)); } } public static SchemaUserTypeCreator getConstructorCreator( - Class clazz, + TypeDescriptor typeDescriptor, Constructor constructor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); return createConstructorCreator( - clazz, constructor, schema, types, typeConversionsFactory); + typeDescriptor.getRawType(), constructor, schema, types, typeConversionsFactory); }); } @@ -220,16 +248,18 @@ public static SchemaUserTypeCreator createConstructorCreator( } public static SchemaUserTypeCreator getStaticCreator( - Class clazz, + TypeDescriptor typeDescriptor, Method creator, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); - return createStaticCreator(clazz, creator, schema, types, typeConversionsFactory); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return createStaticCreator( + typeDescriptor.getRawType(), creator, schema, types, typeConversionsFactory); }); } @@ -286,11 +316,8 @@ public static SchemaUserTypeCreator createStaticCreator( ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, field.getDeclaringClass(), - typeConversionsFactory - .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getType()))); - builder = - implementGetterMethods(builder, field, typeInformation.getName(), typeConversionsFactory); + typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); + builder = implementGetterMethods(builder, typeInformation, typeConversionsFactory); try { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) @@ -311,20 +338,19 @@ public static SchemaUserTypeCreator createStaticCreator( private static DynamicType.Builder implementGetterMethods( DynamicType.Builder builder, - Field field, - String name, + FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) - .intercept(FixedValue.reference(name)) + .intercept(FixedValue.reference(typeInformation.getName())) .method(ElementMatchers.named("get")) - .intercept(new ReadFieldInstruction(field, typeConversionsFactory)); + .intercept(new ReadFieldInstruction(typeInformation, typeConversionsFactory)); } // The list of setters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_SETTERS = + private static final Map> CACHED_SETTERS = Maps.newConcurrentMap(); public static List getSetters( @@ -332,11 +358,21 @@ public static List getSetters( Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { + return getSetters( + TypeDescriptor.of(clazz), schema, fieldValueTypeSupplier, typeConversionsFactory); + } + + public static List getSetters( + TypeDescriptor typeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier, + TypeConversionsFactory typeConversionsFactory) { // Return the setters, ordered by their position in the schema. return CACHED_SETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + TypeDescriptorWithSchema.create(typeDescriptor, schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); return types.stream() .map(t -> createSetter(t, typeConversionsFactory)) .collect(Collectors.toList()); @@ -404,11 +440,12 @@ private static DynamicType.Builder implementSetterMethods( // Implements a method to read a public field out of an object. static class ReadFieldInstruction implements Implementation { // Field that will be read. - private final Field field; + private final FieldValueTypeInformation typeInformation; private final TypeConversionsFactory typeConversionsFactory; - ReadFieldInstruction(Field field, TypeConversionsFactory typeConversionsFactory) { - this.field = field; + ReadFieldInstruction( + FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { + this.typeInformation = typeInformation; this.typeConversionsFactory = typeConversionsFactory; } @@ -423,19 +460,25 @@ public ByteCodeAppender appender(final Target implementationTarget) { // this + method parameters. int numLocals = 1 + instrumentedMethod.getParameters().size(); + StackManipulation cast = + typeInformation.getRawType().isAssignableFrom(typeInformation.getField().getType()) + ? StackManipulation.Trivial.INSTANCE + : TypeCasting.to(TypeDescription.ForLoadedType.of(typeInformation.getRawType())); + // StackManipulation that will read the value from the class field. StackManipulation readValue = new StackManipulation.Compound( // Method param is offset 1 (offset 0 is the this parameter). MethodVariableAccess.REFERENCE.loadFrom(1), // Read the field from the object. - FieldAccess.forField(new ForLoadedField(field)).read()); + FieldAccess.forField(new ForLoadedField(typeInformation.getField())).read(), + cast); StackManipulation stackManipulation = new StackManipulation.Compound( typeConversionsFactory .createGetterConversions(readValue) - .convert(TypeDescriptor.of(field.getGenericType())), + .convert(TypeDescriptor.of(typeInformation.getField().getGenericType())), MethodReturn.REFERENCE); StackManipulation.Size size = stackManipulation.apply(methodVisitor, implementationContext); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index f3888a5ed443..da875de8b610 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -63,6 +63,18 @@ public static ClassWithSchema create(Class clazz, Schema schema) { } } + /** Represents a type descriptor and a schema. */ + @AutoValue + public abstract static class TypeDescriptorWithSchema { + public abstract TypeDescriptor getTypeDescriptor(); + + public abstract Schema getSchema(); + + public static TypeDescriptorWithSchema create(TypeDescriptor typeDescriptor, Schema schema) { + return new AutoValue_ReflectUtils_TypeDescriptorWithSchema(typeDescriptor, schema); + } + } + private static final Map, List> DECLARED_METHODS = Maps.newConcurrentMap(); private static final Map, Method> ANNOTATED_CONSTRUCTORS = Maps.newConcurrentMap(); private static final Map, List> DECLARED_FIELDS = Maps.newConcurrentMap(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java index 72d79adb8288..5f8beacde581 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java @@ -85,25 +85,25 @@ enum MethodType { * public getter methods, or special annotations on the class. */ public static Schema schemaFromClass( - Class clazz, FieldValueTypeSupplier fieldValueTypeSupplier) { - return schemaFromClass(clazz, fieldValueTypeSupplier, new HashMap()); + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>()); } private static Schema schemaFromClass( - Class clazz, + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier, - Map alreadyVisitedSchemas) { - if (alreadyVisitedSchemas.containsKey(clazz)) { - Schema existingSchema = alreadyVisitedSchemas.get(clazz); + Map, Schema> alreadyVisitedSchemas) { + if (alreadyVisitedSchemas.containsKey(typeDescriptor)) { + Schema existingSchema = alreadyVisitedSchemas.get(typeDescriptor); if (existingSchema == null) { throw new IllegalArgumentException( - "Cannot infer schema with a circular reference. Class: " + clazz.getTypeName()); + "Cannot infer schema with a circular reference. Type: " + typeDescriptor); } return existingSchema; } - alreadyVisitedSchemas.put(clazz, null); + alreadyVisitedSchemas.put(typeDescriptor, null); Schema.Builder builder = Schema.builder(); - for (FieldValueTypeInformation type : fieldValueTypeSupplier.get(clazz)) { + for (FieldValueTypeInformation type : fieldValueTypeSupplier.get(typeDescriptor)) { Schema.FieldType fieldType = fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas); Schema.Field f = @@ -116,21 +116,21 @@ private static Schema schemaFromClass( builder.addFields(f); } Schema generatedSchema = builder.build(); - alreadyVisitedSchemas.replace(clazz, generatedSchema); + alreadyVisitedSchemas.replace(typeDescriptor, generatedSchema); return generatedSchema; } /** Map a Java field type to a Beam Schema FieldType. */ public static Schema.FieldType fieldFromType( TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) { - return fieldFromType(type, fieldValueTypeSupplier, new HashMap()); + return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>()); } // TODO(https://github.com/apache/beam/issues/21567): support type inference for logical types private static Schema.FieldType fieldFromType( - TypeDescriptor type, + TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier, - Map alreadyVisitedSchemas) { + Map, Schema> alreadyVisitedSchemas) { FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType()); if (primitiveType != null) { return primitiveType; @@ -154,7 +154,7 @@ private static Schema.FieldType fieldFromType( fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas)); } } else if (type.isSubtypeOf(TypeDescriptor.of(Map.class))) { - TypeDescriptor> map = type.getSupertype(Map.class); + TypeDescriptor> map = ((TypeDescriptor) type).getSupertype(Map.class); if (map.getType() instanceof ParameterizedType) { ParameterizedType ptype = (ParameterizedType) map.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); @@ -179,7 +179,7 @@ private static Schema.FieldType fieldFromType( } else if (type.isSubtypeOf(TypeDescriptor.of(ByteBuffer.class))) { return FieldType.BYTES; } else if (type.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { - TypeDescriptor> iterable = type.getSupertype(Iterable.class); + TypeDescriptor> iterable = ((TypeDescriptor) type).getSupertype(Iterable.class); if (iterable.getType() instanceof ParameterizedType) { ParameterizedType ptype = (ParameterizedType) iterable.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); @@ -198,8 +198,7 @@ private static Schema.FieldType fieldFromType( throw new RuntimeException("Cannot infer schema from unparameterized collection."); } } else { - return FieldType.row( - schemaFromClass(type.getRawType(), fieldValueTypeSupplier, alreadyVisitedSchemas)); + return FieldType.row(schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas)); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index ee3852d70bbe..640983997175 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -837,9 +837,11 @@ public int nextFieldId() { @Internal public Row withFieldValueGetters( - Factory> fieldValueGetterFactory, Object getterTarget) { + TypeDescriptor getterTargetType, + Factory> fieldValueGetterFactory, + Object getterTarget) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + return new RowWithGetters(getterTargetType, schema, fieldValueGetterFactory, getterTarget); } public Row build() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index cb4d83550577..2e30b355a93e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -45,13 +45,22 @@ public class RowWithGetters extends Row { private final Object getterTarget; private final List getters; + private final TypeDescriptor getterTargetType; private @Nullable Map cache = null; RowWithGetters( - Schema schema, Factory> getterFactory, Object getterTarget) { + TypeDescriptor getterTargetType, + Schema schema, + Factory> getterFactory, + Object getterTarget) { super(schema); this.getterTarget = getterTarget; - this.getters = getterFactory.create(getterTarget.getClass(), schema); + this.getterTargetType = getterTargetType; + this.getters = getterFactory.create(getterTargetType, schema); + } + + public TypeDescriptor getGetterTargetType() { + return getterTargetType; } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java index 045662d1680c..5bc8fce562e2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TypeDescriptor.java @@ -190,6 +190,10 @@ public final TypeDescriptor getSupertype(Class superclass) return new SimpleTypeDescriptor<>(token.getSupertype(superclass)); } + public final TypeDescriptor getSubtype(Class subclass) { + return new SimpleTypeDescriptor<>(token.getSubtype(subclass)); + } + /** Returns true if this type is known to be an array type. */ public final boolean isArray() { return token.isArray(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java index d0ee623dea7c..66157721054e 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.assertSchemaEquivalent; import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.equivalentTo; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertArrayEquals; @@ -28,6 +29,8 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; @@ -37,9 +40,13 @@ import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; import org.apache.beam.sdk.schemas.annotations.SchemaFieldNumber; import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Test; @@ -70,7 +77,7 @@ public class AutoValueSchemaTest { .build(); static final Schema OUTER_SCHEMA = Schema.builder().addRowField("inner", SIMPLE_SCHEMA).build(); - private Row createSimpleRow(String name) { + private static Row createSimpleRow(String name) { return Row.withSchema(SIMPLE_SCHEMA) .addValues( name, @@ -348,6 +355,48 @@ abstract static class Builder { } } + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class GenericAutoValue { + public abstract T getT(); + + public static GenericAutoValue create(T t) { + return new AutoValue_AutoValueSchemaTest_GenericAutoValue<>(t); + } + } + + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class GenericAutoValueWithBuilder { + public abstract T getT(); + + GenericAutoValueWithBuilder() {} + + public static Builder builder() { + return new AutoValue_AutoValueSchemaTest_GenericAutoValueWithBuilder.Builder<>(); + } + + @AutoValue.Builder + abstract static class Builder { + public abstract Builder setT(T t); + + public abstract GenericAutoValueWithBuilder build(); + } + } + + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class GenericAutoValueWithCreator { + public abstract T getT(); + + GenericAutoValueWithCreator() {} + + @SchemaCreate + public static GenericAutoValueWithCreator create(T t) { + return new AutoValue_AutoValueSchemaTest_GenericAutoValueWithCreator<>(t); + } + } + private void verifyRow(Row row) { assertEquals("string", row.getString("str")); assertEquals((byte) 1, (Object) row.getByte("aByte")); @@ -385,6 +434,361 @@ public void testSchema() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(SIMPLE_SCHEMA, schema); } + @Test + public void testGenericAutoValueSchema() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema actual = registry.getSchema(new TypeDescriptor>() {}); + Schema expected = Schema.builder().addRowField("t", SIMPLE_SCHEMA).build(); + assertSchemaEquivalent(expected, actual); + } + + @Test + public void testGenericAutoValueToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction, Row> toRow = + registry.getToRowFunction(new TypeDescriptor>() {}); + Row row = + toRow.apply( + GenericAutoValue.create( + new AutoValue_AutoValueSchemaTest_SimpleAutoValue( + "string", + (byte) 1, + (short) 2, + 3, + 4L, + true, + DATE, + BYTE_ARRAY, + ByteBuffer.wrap(BYTE_ARRAY), + DATE.toInstant(), + BigDecimal.ONE, + STRING_BUILDER))); + + verifyRow(row.getRow("t")); + } + + @Test + public void testGenericAutoValueFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction> fromRow = + registry.getFromRowFunction(new TypeDescriptor>() {}); + + Row row = + Row.withSchema(Schema.builder().addRowField("t", SIMPLE_SCHEMA).build()) + .withFieldValue("t", createSimpleRow("string")) + .build(); + GenericAutoValue actual = fromRow.apply(row); + verifyAutoValue(actual.getT()); + } + + @Test + public void testGenericAutoValueWithCreatorFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction> fromRow = + registry.getFromRowFunction( + new TypeDescriptor>() {}); + + Row row = + Row.withSchema(Schema.builder().addRowField("t", SIMPLE_SCHEMA).build()) + .withFieldValue("t", createSimpleRow("string")) + .build(); + GenericAutoValueWithCreator actual = fromRow.apply(row); + verifyAutoValue(actual.getT()); + } + + @Test + public void testGenericAutoValueWithBuilderFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction> fromRow = + registry.getFromRowFunction( + new TypeDescriptor>() {}); + + Row row = + Row.withSchema(Schema.builder().addRowField("t", SIMPLE_SCHEMA).build()) + .withFieldValue("t", createSimpleRow("string")) + .build(); + GenericAutoValueWithBuilder actual = fromRow.apply(row); + verifyAutoValue(actual.getT()); + } + + @Test + public void testGenericAutoValueBuilderOfMapOfCreatorsFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithBuilder>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithBuilder< + Map>>>() {}); + + Schema mapValueSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder() + .addMapField("t", FieldType.STRING, FieldType.row(mapValueSchema)) + .build()) + .withFieldValue( + "t", + ImmutableMap.builder() + .put("k1", Row.withSchema(mapValueSchema).withFieldValue("t", "v1").build()) + .put("k2", Row.withSchema(mapValueSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithBuilder>> actual = + fromRow.apply(row); + GenericAutoValueWithCreator genericAutoValue1 = + GenericAutoValueWithCreator.create("v1"); + GenericAutoValueWithCreator genericAutoValue2 = + GenericAutoValueWithCreator.create("v2"); + + assertEquals(genericAutoValue1, actual.getT().get("k1")); + assertEquals(genericAutoValue2, actual.getT().get("k2")); + } + + @Test + public void testGenericAutoValueCreatorOfMapOfBuildersFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithCreator>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithCreator< + Map>>>() {}); + + Schema mapValueSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder() + .addMapField("t", FieldType.STRING, FieldType.row(mapValueSchema)) + .build()) + .withFieldValue( + "t", + ImmutableMap.builder() + .put("k1", Row.withSchema(mapValueSchema).withFieldValue("t", "v1").build()) + .put("k2", Row.withSchema(mapValueSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithCreator>> actual = + fromRow.apply(row); + GenericAutoValueWithBuilder genericAutoValue1 = + GenericAutoValueWithBuilder.builder().setT("v1").build(); + GenericAutoValueWithBuilder genericAutoValue2 = + GenericAutoValueWithBuilder.builder().setT("v2").build(); + + assertEquals(genericAutoValue1, actual.getT().get("k1")); + assertEquals(genericAutoValue2, actual.getT().get("k2")); + } + + @Test + public void testGenericAutoValueBuilderOfListOfCreatorsFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithBuilder>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithBuilder>>>() {}); + + Schema listElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(listElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithBuilder>> actual = + fromRow.apply(row); + GenericAutoValueWithCreator genericAutoValue1 = + GenericAutoValueWithCreator.create("v1"); + GenericAutoValueWithCreator genericAutoValue2 = + GenericAutoValueWithCreator.create("v2"); + + assertEquals(genericAutoValue1, actual.getT().get(0)); + assertEquals(genericAutoValue2, actual.getT().get(1)); + } + + @Test + public void testGenericAutoValueCreatorOfListOfBuildersFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithCreator>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithCreator>>>() {}); + + Schema listElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(listElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithCreator>> actual = + fromRow.apply(row); + GenericAutoValueWithBuilder genericAutoValue1 = + GenericAutoValueWithBuilder.builder().setT("v1").build(); + GenericAutoValueWithBuilder genericAutoValue2 = + GenericAutoValueWithBuilder.builder().setT("v2").build(); + + assertEquals(genericAutoValue1, actual.getT().get(0)); + assertEquals(genericAutoValue2, actual.getT().get(1)); + } + + @Test + public void testGenericAutoValueBuilderOfArrayOfCreatorsFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction[]>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithBuilder[]>>() {}); + + Schema arrayElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(arrayElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + .add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithBuilder[]> actual = fromRow.apply(row); + GenericAutoValueWithCreator genericAutoValue1 = + GenericAutoValueWithCreator.create("v1"); + GenericAutoValueWithCreator genericAutoValue2 = + GenericAutoValueWithCreator.create("v2"); + + assertEquals(genericAutoValue1, actual.getT()[0]); + assertEquals(genericAutoValue2, actual.getT()[1]); + } + + @Test + public void testGenericAutoValueCreatorOfArrayOfBuildersFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction[]>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithCreator[]>>() {}); + + Schema arrayElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(arrayElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + .add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithCreator[]> actual = fromRow.apply(row); + GenericAutoValueWithBuilder genericAutoValue1 = + GenericAutoValueWithBuilder.builder().setT("v1").build(); + GenericAutoValueWithBuilder genericAutoValue2 = + GenericAutoValueWithBuilder.builder().setT("v2").build(); + + assertEquals(genericAutoValue1, actual.getT()[0]); + assertEquals(genericAutoValue2, actual.getT()[1]); + } + + @Test + public void testGenericAutoValueWithMapToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction>>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor>>>() {}); + + GenericAutoValue genericAutoValue1 = GenericAutoValue.create("v1"); + GenericAutoValue genericAutoValue2 = GenericAutoValue.create("v2"); + + Row row = + toRow.apply( + GenericAutoValue.create( + ImmutableMap.of("k1", genericAutoValue1, "k2", genericAutoValue2))); + + assertEquals("v1", row.getMap("t").get("k1").getString("t")); + assertEquals("v2", row.getMap("t").get("k2").getString("t")); + } + + @Test + public void testGenericAutoValueWithListToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction>>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor>>>() {}); + + GenericAutoValue genericAutoValue1 = GenericAutoValue.create("v1"); + GenericAutoValue genericAutoValue2 = GenericAutoValue.create("v2"); + + Row row = + toRow.apply( + GenericAutoValue.create(ImmutableList.of(genericAutoValue1, genericAutoValue2))); + Row[] genericAutoValueRows = row.getArray("t").toArray(new Row[0]); + + assertEquals("v1", genericAutoValueRows[0].getString("t")); + assertEquals("v2", genericAutoValueRows[1].getString("t")); + } + + @Test + public void testGenericAutoValueWithArrayToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction[]>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor[]>>() {}); + + GenericAutoValue genericAutoValue1 = GenericAutoValue.create("v1"); + GenericAutoValue genericAutoValue2 = GenericAutoValue.create("v2"); + + @SuppressWarnings("unchecked") + Row row = + toRow.apply( + GenericAutoValue.create(new GenericAutoValue[] {genericAutoValue1, genericAutoValue2})); + Row[] genericAutoValueRows = row.getArray("t").toArray(new Row[0]); + + assertEquals("v1", genericAutoValueRows[0].getString("t")); + assertEquals("v2", genericAutoValueRows[1].getString("t")); + } + + @Test + public void testNestedGenericAutoValueToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction>>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor>>>() {}); + + Row row = + toRow.apply( + GenericAutoValue.create(GenericAutoValue.create(GenericAutoValue.create("v1")))); + + assertEquals("v1", row.getRow("t").getRow("t").getString("t")); + } + @Test public void testToRowConstructor() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -402,6 +806,7 @@ public void testToRowConstructor() throws NoSuchSchemaException { DATE.toInstant(), BigDecimal.ONE, STRING_BUILDER); + Row row = registry.getToRowFunction(SimpleAutoValue.class).apply(value); verifyRow(row); } @@ -444,6 +849,7 @@ public void testToRowConstructorMemoized() throws NoSuchSchemaException { DATE.toInstant(), BigDecimal.ONE, STRING_BUILDER); + Row row = registry.getToRowFunction(MemoizedAutoValue.class).apply(value); verifyRow(row); } @@ -571,6 +977,7 @@ public void testToRowNestedConstructor() throws NoSuchSchemaException { DATE.toInstant(), BigDecimal.ONE, STRING_BUILDER); + AutoValueOuter outer = new AutoValue_AutoValueSchemaTest_AutoValueOuter(inner); Row row = registry.getToRowFunction(AutoValueOuter.class).apply(outer); verifyRow(row.getRow("inner")); @@ -675,6 +1082,7 @@ static SimpleAutoValueWithStaticFactory create( Instant instant, BigDecimal bigDecimal, StringBuilder stringBuilder) { + return new AutoValue_AutoValueSchemaTest_SimpleAutoValueWithStaticFactory( str, aByte, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 5313feb5c6c0..1d888302e9bd 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.assertSchemaEquivalent; import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.equivalentTo; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.ALL_NULLABLE_BEAN_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.ANNOTATED_SIMPLE_BEAN_SCHEMA; @@ -32,6 +33,7 @@ import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.PRIMITIVE_ARRAY_BEAN_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.RENAMED_FIELDS_AND_SETTERS_BEAM_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.SIMPLE_BEAN_SCHEMA; +import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.genericBeanSchema; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -56,6 +58,7 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.BeanWithCaseFormat; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.BeanWithNoCreateOption; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.BeanWithRenamedFieldsAndSetters; +import org.apache.beam.sdk.schemas.utils.TestJavaBeans.GenericBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.IterableBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.MismatchingNullableBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.NestedArrayBean; @@ -68,9 +71,11 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBeanWithAnnotations; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Ints; import org.joda.time.DateTime; import org.junit.Ignore; @@ -131,6 +136,16 @@ private Row createSimpleRow(String name) { .build(); } + private GenericBean createGeneric(T t) { + GenericBean genericBean = new GenericBean<>(); + genericBean.setT(t); + return genericBean; + } + + private Row createGenericRow(Schema.FieldType tFieldType, Object tFieldValue) { + return Row.withSchema(genericBeanSchema(tFieldType)).withFieldValue("t", tFieldValue).build(); + } + @Test public void testSchema() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -138,14 +153,9 @@ public void testSchema() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(SIMPLE_BEAN_SCHEMA, schema); } - @Test - public void testToRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - SimpleBean bean = createSimple("string"); - Row row = registry.getToRowFunction(SimpleBean.class).apply(bean); - + private static void verifyRow(String expectedStrField, Row row) { assertEquals(12, row.getFieldCount()); - assertEquals("string", row.getString("str")); + assertEquals(expectedStrField, row.getString("str")); assertEquals((byte) 1, (Object) row.getByte("aByte")); assertEquals((short) 2, (Object) row.getInt16("aShort")); assertEquals((int) 3, (Object) row.getInt32("anInt")); @@ -159,13 +169,8 @@ public void testToRow() throws NoSuchSchemaException { assertEquals("stringbuilder", row.getString("stringBuilder")); } - @Test - public void testFromRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - Row row = createSimpleRow("string"); - - SimpleBean bean = registry.getFromRowFunction(SimpleBean.class).apply(row); - assertEquals("string", bean.getStr()); + private static void verifySimpleBean(String expectedStrField, SimpleBean bean) { + assertEquals(expectedStrField, bean.getStr()); assertEquals((byte) 1, bean.getaByte()); assertEquals((short) 2, bean.getaShort()); assertEquals((int) 3, bean.getAnInt()); @@ -179,6 +184,23 @@ public void testFromRow() throws NoSuchSchemaException { assertEquals("stringbuilder", bean.getStringBuilder().toString()); } + @Test + public void testToRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SimpleBean bean = createSimple("string"); + Row row = registry.getToRowFunction(SimpleBean.class).apply(bean); + verifyRow("string", row); + } + + @Test + public void testFromRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = createSimpleRow("string"); + + SimpleBean bean = registry.getFromRowFunction(SimpleBean.class).apply(row); + verifySimpleBean("string", bean); + } + @Test public void testNullableToRow() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -625,4 +647,121 @@ public void testSetterConstructionWithRenamedFields() throws NoSuchSchemaExcepti assertEquals( registry.getFromRowFunction(BeanWithCaseFormat.class).apply(row), beanWithCaseFormat); } + + @Test + public void testGenericBeamSchema() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema actual = registry.getSchema(new TypeDescriptor>() {}); + Schema expected = genericBeanSchema(Schema.FieldType.row(SIMPLE_BEAN_SCHEMA)); + + assertSchemaEquivalent(expected, actual); + } + + @Test + public void testGenericBeamSchemaToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + GenericBean> genericBean = + createGeneric(createGeneric(createSimple("string"))); + + Row row = + registry + .getToRowFunction(new TypeDescriptor>>() {}) + .apply(genericBean); + + verifyRow("string", row.getRow("t").getRow("t")); + } + + @Test + public void testGenericBeamSchemaFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema nestedSchema = genericBeanSchema(Schema.FieldType.row(SIMPLE_BEAN_SCHEMA)); + Row row = + createGenericRow( + Schema.FieldType.row(nestedSchema), + createGenericRow(Schema.FieldType.row(SIMPLE_BEAN_SCHEMA), createSimpleRow("string"))); + GenericBean> actual = + registry + .getFromRowFunction(new TypeDescriptor>>() {}) + .apply(row); + + verifySimpleBean("string", actual.getT().getT()); + } + + @Test + public void testGenericBeamSchemaMapToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction( + new TypeDescriptor>>>() {}) + .apply( + createGeneric( + ImmutableMap.>builder() + .put("k1", createGeneric("v1")) + .put("k2", createGeneric("v2")) + .build())); + + assertEquals("v1", row.getMap("t").get("k1").getString("t")); + assertEquals("v2", row.getMap("t").get("k2").getString("t")); + } + + @Test + public void testGenericBeamSchemaMapFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType mapValueFieldType = + Schema.FieldType.row(genericBeanSchema(Schema.FieldType.STRING)); + GenericBean>> actual = + registry + .getFromRowFunction( + new TypeDescriptor>>>() {}) + .apply( + createGenericRow( + Schema.FieldType.map(Schema.FieldType.STRING, mapValueFieldType), + ImmutableMap.builder() + .put("k1", createGenericRow(Schema.FieldType.STRING, "v1")) + .put("k2", createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + + assertEquals("v1", actual.getT().get("k1").getT()); + assertEquals("v2", actual.getT().get("k2").getT()); + } + + @Test + public void testGenericBeamSchemaIterableToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction(new TypeDescriptor>>>() {}) + .apply( + createGeneric( + ImmutableList.>builder() + .add(createGeneric("v1")) + .add(createGeneric("v2")) + .build())); + + Row[] rows = Streams.stream(row.getIterable("t")).toArray(Row[]::new); + + assertEquals("v1", rows[0].getString("t")); + assertEquals("v2", rows[1].getString("t")); + } + + @Test + public void testGenericBeamSchemaIterableFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType elementFieldType = + Schema.FieldType.row(genericBeanSchema(Schema.FieldType.STRING)); + GenericBean>> actual = + registry + .getFromRowFunction(new TypeDescriptor>>>() {}) + .apply( + createGenericRow( + Schema.FieldType.array(elementFieldType), + ImmutableList.builder() + .add(createGenericRow(Schema.FieldType.STRING, "v1")) + .add(createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + GenericBean[] beans = Streams.stream(actual.getT()).toArray(GenericBean[]::new); + assertEquals("v1", beans[0].getT()); + assertEquals("v2", beans[1].getT()); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 11bef79b26f7..d4ab2e1ac4f5 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.assertSchemaEquivalent; import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.equivalentTo; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.ANNOTATED_SIMPLE_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.CASE_FORMAT_POJO_SCHEMA; @@ -34,6 +35,7 @@ import static org.apache.beam.sdk.schemas.utils.TestPOJOs.PRIMITIVE_ARRAY_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.SIMPLE_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.SIMPLE_POJO_WITH_DESCRIPTION_SCHEMA; +import static org.apache.beam.sdk.schemas.utils.TestPOJOs.genericPOJOSchema; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertArrayEquals; @@ -56,6 +58,7 @@ import org.apache.beam.sdk.schemas.utils.TestPOJOs; import org.apache.beam.sdk.schemas.utils.TestPOJOs.AnnotatedSimplePojo; import org.apache.beam.sdk.schemas.utils.TestPOJOs.FirstCircularNestedPOJO; +import org.apache.beam.sdk.schemas.utils.TestPOJOs.GenericPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedArrayPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedArraysPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedMapPOJO; @@ -76,9 +79,11 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Ints; import org.joda.time.DateTime; import org.joda.time.Instant; @@ -180,6 +185,10 @@ private Row createAnnotatedRow(String name) { .build(); } + private static Row createGenericRow(FieldType tFieldType, Object tFieldValue) { + return Row.withSchema(genericPOJOSchema(tFieldType)).withFieldValue("t", tFieldValue).build(); + } + @Test public void testSchema() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -187,14 +196,9 @@ public void testSchema() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(SIMPLE_POJO_SCHEMA, schema); } - @Test - public void testToRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - SimplePOJO pojo = createSimple("string"); - Row row = registry.getToRowFunction(SimplePOJO.class).apply(pojo); - + private static void verifySimpleRow(String expectedStrField, Row row) { assertEquals(12, row.getFieldCount()); - assertEquals("string", row.getString("str")); + assertEquals(expectedStrField, row.getString("str")); assertEquals((byte) 1, (Object) row.getByte("aByte")); assertEquals((short) 2, (Object) row.getInt16("aShort")); assertEquals((int) 3, (Object) row.getInt32("anInt")); @@ -208,13 +212,8 @@ public void testToRow() throws NoSuchSchemaException { assertEquals("stringbuilder", row.getString("stringBuilder")); } - @Test - public void testFromRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - Row row = createSimpleRow("string"); - - SimplePOJO pojo = registry.getFromRowFunction(SimplePOJO.class).apply(row); - assertEquals("string", pojo.str); + private static void verifySimplePOJO(String expectedStrField, SimplePOJO pojo) { + assertEquals(expectedStrField, pojo.str); assertEquals((byte) 1, pojo.aByte); assertEquals((short) 2, pojo.aShort); assertEquals((int) 3, pojo.anInt); @@ -228,6 +227,23 @@ public void testFromRow() throws NoSuchSchemaException { assertEquals("stringbuilder", pojo.stringBuilder.toString()); } + @Test + public void testToRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SimplePOJO pojo = createSimple("string"); + Row row = registry.getToRowFunction(SimplePOJO.class).apply(pojo); + verifySimpleRow("string", row); + } + + @Test + public void testFromRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = createSimpleRow("string"); + + SimplePOJO pojo = registry.getFromRowFunction(SimplePOJO.class).apply(row); + verifySimplePOJO("string", pojo); + } + @Test public void testNullableSchema() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -781,4 +797,118 @@ public void testCircularNestedPOJOThrows() throws NoSuchSchemaException { thrown.getMessage(), containsString("TestPOJOs$FirstCircularNestedPOJO")); } + + @Test + public void testGenericPOJOSchema() throws Exception { + Schema actual = + SchemaRegistry.createDefault() + .getSchema(new TypeDescriptor>>() {}); + Schema expected = + genericPOJOSchema(FieldType.row(genericPOJOSchema(FieldType.row(SIMPLE_POJO_SCHEMA)))); + assertSchemaEquivalent(expected, actual); + } + + @Test + public void testGenericPOJOToRow() throws Exception { + Row row = + SchemaRegistry.createDefault() + .getToRowFunction(new TypeDescriptor>>() {}) + .apply(new GenericPOJO<>(new GenericPOJO<>(createSimple("string")))); + + verifySimpleRow("string", row.getRow("t").getRow("t")); + } + + @Test + public void testGenericPOJOFromRow() throws Exception { + FieldType innerGenericPOJOFieldType = + FieldType.row(genericPOJOSchema(FieldType.row(SIMPLE_POJO_SCHEMA))); + GenericPOJO> actualPOJO = + SchemaRegistry.createDefault() + .getFromRowFunction(new TypeDescriptor>>() {}) + .apply( + createGenericRow( + innerGenericPOJOFieldType, + createGenericRow( + FieldType.row(SIMPLE_POJO_SCHEMA), createSimpleRow("string")))); + + verifySimplePOJO("string", actualPOJO.t.t); + } + + @Test + public void testGenericPOJOSchemaMapToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction( + new TypeDescriptor>>>() {}) + .apply( + new GenericPOJO<>( + ImmutableMap.>builder() + .put("k1", new GenericPOJO<>("v1")) + .put("k2", new GenericPOJO<>("v2")) + .build())); + + assertEquals("v1", row.getMap("t").get("k1").getString("t")); + assertEquals("v2", row.getMap("t").get("k2").getString("t")); + } + + @Test + public void testGenericPOJOSchemaMapFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType mapValueFieldType = + Schema.FieldType.row(genericPOJOSchema(Schema.FieldType.STRING)); + GenericPOJO>> actual = + registry + .getFromRowFunction( + new TypeDescriptor>>>() {}) + .apply( + createGenericRow( + Schema.FieldType.map(Schema.FieldType.STRING, mapValueFieldType), + ImmutableMap.builder() + .put("k1", createGenericRow(Schema.FieldType.STRING, "v1")) + .put("k2", createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + + assertEquals("v1", actual.t.get("k1").t); + assertEquals("v2", actual.t.get("k2").t); + } + + @Test + public void testGenericBeamSchemaIterableToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction(new TypeDescriptor>>>() {}) + .apply( + new GenericPOJO<>( + ImmutableList.>builder() + .add(new GenericPOJO<>("v1")) + .add(new GenericPOJO<>("v2")) + .build())); + + Row[] rows = Streams.stream(row.getIterable("t")).toArray(Row[]::new); + + assertEquals("v1", rows[0].getString("t")); + assertEquals("v2", rows[1].getString("t")); + } + + @Test + public void testGenericBeamSchemaIterableFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType elementFieldType = + Schema.FieldType.row(genericPOJOSchema(Schema.FieldType.STRING)); + GenericPOJO>> actual = + registry + .getFromRowFunction(new TypeDescriptor>>>() {}) + .apply( + createGenericRow( + Schema.FieldType.array(elementFieldType), + ImmutableList.builder() + .add(createGenericRow(Schema.FieldType.STRING, "v1")) + .add(createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + GenericPOJO[] pojos = Streams.stream(actual.t).toArray(GenericPOJO[]::new); + assertEquals("v1", pojos[0].t); + assertEquals("v2", pojos[1].t); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java index b5ad6f989d9e..8d91be530f8c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java @@ -1397,4 +1397,39 @@ public void setValue(@Nullable Float value) { Schema.Field.nullable("value", FieldType.FLOAT) .withDescription("This value is the value stored in the object as a float.")) .build(); + + @DefaultSchema(JavaBeanSchema.class) + public static class GenericBean { + @Nullable private T t; + + @Nullable + public T getT() { + return t; + } + + public void setT(@Nullable T t) { + this.t = t; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GenericBean that = (GenericBean) o; + return Objects.equals(t, that.t); + } + + @Override + public int hashCode() { + return Objects.hashCode(t); + } + } + + public static Schema genericBeanSchema(FieldType genericFieldType) { + return Schema.builder().addNullableField("t", genericFieldType).build(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java index 789de02adee8..5ab348b0183b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java @@ -1274,4 +1274,18 @@ public int hashCode() { Schema.Field.nullable("str", FieldType.STRING) .withDescription("a simple string that is part of this field")) .build(); + + @DefaultSchema(JavaFieldSchema.class) + public static class GenericPOJO { + public @Nullable T t; + + @SchemaCreate + public GenericPOJO(@Nullable T t) { + this.t = t; + } + } + + public static Schema genericPOJOSchema(FieldType tFieldType) { + return Schema.builder().addNullableField("t", tFieldType).build(); + } } diff --git a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java index 88e6fefcf9d3..12d13a1d2f40 100644 --- a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java +++ b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java @@ -46,6 +46,7 @@ import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -291,7 +292,7 @@ private FieldVectorListValueGetterFactory(List fieldVectors) { } @Override - public List create(Class clazz, Schema schema) { + public List create(TypeDescriptor typeDescriptor, Schema schema) { return this.fieldVectors.stream() .map( (fieldVector) -> { @@ -484,7 +485,9 @@ public Row next() { throw new IllegalStateException("There are no more Rows."); } Row result = - Row.withSchema(schema).withFieldValueGetters(this.fieldValueGetters, this.currRowIndex); + Row.withSchema(schema) + .withFieldValueGetters( + TypeDescriptor.of(Integer.class), this.fieldValueGetters, this.currRowIndex); this.currRowIndex += 1; return result; } diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java index e3bf24621cd5..b0146728fbfe 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java @@ -44,18 +44,20 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters(Class targetClass, Schema schema) { - return AvroUtils.getGetters(targetClass, schema); + public List fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { + return AvroUtils.getGetters(targetTypeDescriptor.getRawType(), schema); } @Override public List fieldValueTypeInformations( - Class targetClass, Schema schema) { - return AvroUtils.getFieldTypes(targetClass, schema); + TypeDescriptor targetTypeDescriptor, Schema schema) { + return AvroUtils.getFieldTypes(targetTypeDescriptor, schema); } @Override - public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { - return AvroUtils.getCreator(targetClass, schema); + public SchemaUserTypeCreator schemaTypeCreator( + TypeDescriptor targetTypeDescriptor, Schema schema) { + return AvroUtils.getCreator(targetTypeDescriptor, schema); } } diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 7622132c7e27..4df0c6808522 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -808,20 +808,20 @@ public static SchemaCoder schemaCoder(AvroCoder avroCoder) { private static final class AvroSpecificRecordFieldValueTypeSupplier implements FieldValueTypeSupplier { @Override - public List get(Class clazz) { + public List get(TypeDescriptor typeDescriptor) { throw new RuntimeException("Unexpected call."); } @Override - public List get(Class clazz, Schema schema) { + public List get(TypeDescriptor typeDescriptor, Schema schema) { Map mapping = getMapping(schema); - List methods = ReflectUtils.getMethods(clazz); + List methods = ReflectUtils.getMethods(typeDescriptor.getRawType()); List types = Lists.newArrayList(); for (int i = 0; i < methods.size(); ++i) { Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -864,13 +864,15 @@ private Map getMapping(Schema schema) { private static final class AvroPojoFieldValueTypeSupplier implements FieldValueTypeSupplier { @Override - public List get(Class clazz) { - List classFields = ReflectUtils.getFields(clazz); + public List get(TypeDescriptor typeDescriptor) { + List classFields = + ReflectUtils.getFields(typeDescriptor.getRawType()); Map types = Maps.newHashMap(); for (int i = 0; i < classFields.size(); ++i) { java.lang.reflect.Field f = classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { - FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); + FieldValueTypeInformation typeInformation = + FieldValueTypeInformation.forField(typeDescriptor, f, i); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); @@ -883,12 +885,13 @@ public List get(Class clazz) { } /** Get field types for an AVRO-generated SpecificRecord or a POJO. */ - public static List getFieldTypes(Class clazz, Schema schema) { - if (TypeDescriptor.of(clazz).isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { + public static List getFieldTypes( + TypeDescriptor typeDescriptor, Schema schema) { + if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { return JavaBeanUtils.getFieldTypes( - clazz, schema, new AvroSpecificRecordFieldValueTypeSupplier()); + typeDescriptor, schema, new AvroSpecificRecordFieldValueTypeSupplier()); } else { - return POJOUtils.getFieldTypes(clazz, schema, new AvroPojoFieldValueTypeSupplier()); + return POJOUtils.getFieldTypes(typeDescriptor, schema, new AvroPojoFieldValueTypeSupplier()); } } @@ -907,12 +910,17 @@ public static List getGetters(Class clazz, Schema schem } /** Get an object creator for an AVRO-generated SpecificRecord. */ - public static SchemaUserTypeCreator getCreator(Class clazz, Schema schema) { - if (TypeDescriptor.of(clazz).isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { - return AvroByteBuddyUtils.getCreator((Class) clazz, schema); + public static SchemaUserTypeCreator getCreator( + TypeDescriptor typeDescriptor, Schema schema) { + if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { + return AvroByteBuddyUtils.getCreator( + (Class) typeDescriptor.getRawType(), schema); } else { return POJOUtils.getSetFieldCreator( - clazz, schema, new AvroPojoFieldValueTypeSupplier(), new AvroTypeConversionFactory()); + typeDescriptor, + schema, + new AvroPojoFieldValueTypeSupplier(), + new AvroTypeConversionFactory()); } } diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index bb2e267bae23..3e61ec318802 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -468,22 +468,23 @@ public TypeConversion createSetterConversions(StackManipulati *

The returned list is ordered by the order of fields in the schema. */ public static List getGetters( - Class clazz, + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { - Multimap methods = ReflectUtils.getMethodsMap(clazz); + Multimap methods = ReflectUtils.getMethodsMap(typeDescriptor.getRawType()); return CACHED_GETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), + ClassWithSchema.create(typeDescriptor.getRawType(), schema), c -> { - List types = fieldValueTypeSupplier.get(clazz, schema); + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); return types.stream() .map( t -> createGetter( t, typeConversionsFactory, - clazz, + typeDescriptor, methods, schema.getField(t.getName()), fieldValueTypeSupplier)) @@ -949,7 +950,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { private static FieldValueGetter createGetter( FieldValueTypeInformation fieldValueTypeInformation, TypeConversionsFactory typeConversionsFactory, - Class clazz, + TypeDescriptor typeDescriptor, Multimap methods, Field field, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -965,7 +966,7 @@ private static FieldValueGetter createGetter( // Create a map of case enum value to getter. This must be sorted, so store in a TreeMap. TreeMap> oneOfGetters = Maps.newTreeMap(); Map oneOfFieldTypes = - fieldValueTypeSupplier.get(clazz, oneOfType.getOneOfSchema()).stream() + fieldValueTypeSupplier.get(typeDescriptor, oneOfType.getOneOfSchema()).stream() .collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f)); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { int protoFieldIndex = getFieldNumber(oneOfField); @@ -973,14 +974,18 @@ private static FieldValueGetter createGetter( createGetter( oneOfFieldTypes.get(oneOfField.getName()), typeConversionsFactory, - clazz, + typeDescriptor, methods, oneOfField, fieldValueTypeSupplier); oneOfGetters.put(protoFieldIndex, oneOfFieldGetter); } return createOneOfGetter( - fieldValueTypeInformation, oneOfGetters, clazz, oneOfType, caseMethod); + fieldValueTypeInformation, + oneOfGetters, + typeDescriptor.getRawType(), + oneOfType, + caseMethod); } else { return JavaBeanUtils.createGetter(fieldValueTypeInformation, typeConversionsFactory); } @@ -1017,34 +1022,41 @@ static Method getProtoGetter(Multimap methods, String name, Fiel public static @Nullable SchemaUserTypeCreator getBuilderCreator( - Class protoClass, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { - Class builderClass = getProtoGeneratedBuilder(protoClass); + TypeDescriptor protoTypeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { + Class builderClass = getProtoGeneratedBuilder(protoTypeDescriptor.getRawType()); if (builderClass == null) { return null; } Multimap methods = ReflectUtils.getMethodsMap(builderClass); List> setters = schema.getFields().stream() - .map(f -> getProtoFieldValueSetter(f, methods, builderClass)) + .map(f -> getProtoFieldValueSetter(protoTypeDescriptor, f, methods, builderClass)) .collect(Collectors.toList()); - return createBuilderCreator(protoClass, builderClass, setters, schema); + return createBuilderCreator(protoTypeDescriptor.getRawType(), builderClass, setters, schema); } private static FieldValueSetter getProtoFieldValueSetter( - Field field, Multimap methods, Class builderClass) { + TypeDescriptor typeDescriptor, + Field field, + Multimap methods, + Class builderClass) { if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); TreeMap> oneOfSetters = Maps.newTreeMap(); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { - FieldValueSetter setter = getProtoFieldValueSetter(oneOfField, methods, builderClass); + FieldValueSetter setter = + getProtoFieldValueSetter(typeDescriptor, oneOfField, methods, builderClass); oneOfSetters.put(getFieldNumber(oneOfField), setter); } return createOneOfSetter(field.getName(), oneOfSetters, builderClass); } else { Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( - FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), + FieldValueTypeInformation.forSetter( + typeDescriptor, method, protoSetterPrefix(field.getType())), new ProtoTypeConversionsFactory()); } } diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index 1b3d42e35536..885d0d6e52f3 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -53,13 +53,13 @@ public class ProtoMessageSchema extends GetterBasedSchemaProvider { private static final class ProtoClassFieldValueTypeSupplier implements FieldValueTypeSupplier { @Override - public List get(Class clazz) { + public List get(TypeDescriptor typeDescriptor) { throw new RuntimeException("Unexpected call."); } @Override - public List get(Class clazz, Schema schema) { - Multimap methods = ReflectUtils.getMethodsMap(clazz); + public List get(TypeDescriptor typeDescriptor, Schema schema) { + Multimap methods = ReflectUtils.getMethodsMap(typeDescriptor.getRawType()); List types = Lists.newArrayListWithCapacity(schema.getFieldCount()); for (int i = 0; i < schema.getFieldCount(); ++i) { @@ -72,7 +72,8 @@ public List get(Class clazz, Schema schema) { Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. types.add( @@ -82,7 +83,9 @@ public List get(Class clazz, Schema schema) { } else { // This is a simple field. Add the getter. Method method = getProtoGetter(methods, field.getName(), field.getType()); - types.add(FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + types.add( + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } } return types; @@ -96,9 +99,10 @@ public List get(Class clazz, Schema schema) { } @Override - public List fieldValueGetters(Class targetClass, Schema schema) { + public List fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return ProtoByteBuddyUtils.getGetters( - targetClass, + targetTypeDescriptor, schema, new ProtoClassFieldValueTypeSupplier(), new ProtoTypeConversionsFactory()); @@ -106,17 +110,19 @@ public List fieldValueGetters(Class targetClass, Schema sch @Override public List fieldValueTypeInformations( - Class targetClass, Schema schema) { - return JavaBeanUtils.getFieldTypes(targetClass, schema, new ProtoClassFieldValueTypeSupplier()); + TypeDescriptor targetTypeDescriptor, Schema schema) { + return JavaBeanUtils.getFieldTypes( + targetTypeDescriptor, schema, new ProtoClassFieldValueTypeSupplier()); } @Override - public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { + public SchemaUserTypeCreator schemaTypeCreator( + TypeDescriptor targetTypeDescriptor, Schema schema) { SchemaUserTypeCreator creator = ProtoByteBuddyUtils.getBuilderCreator( - targetClass, schema, new ProtoClassFieldValueTypeSupplier()); + targetTypeDescriptor, schema, new ProtoClassFieldValueTypeSupplier()); if (creator == null) { - throw new RuntimeException("Cannot create creator for " + targetClass); + throw new RuntimeException("Cannot create creator for " + targetTypeDescriptor); } return creator; } @@ -149,7 +155,8 @@ public static SimpleFunction getRowToProtoBytesFn(Class claz private void checkForDynamicType(TypeDescriptor typeDescriptor) { if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { throw new RuntimeException( - "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoDynamicMessageSchema instead."); + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use" + + " ProtoDynamicMessageSchema instead."); } } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java index 65812d72df1d..ffb750f9f8b1 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java @@ -32,6 +32,7 @@ import java.util.Objects; import java.util.Set; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.SdkBuilderSetter; import org.apache.beam.sdk.io.aws2.schemas.AwsTypes.ConverterFactory; import org.apache.beam.sdk.schemas.CachingFactory; @@ -75,9 +76,10 @@ public class AwsSchemaProvider extends GetterBasedSchemaProvider { @SuppressWarnings("rawtypes") @Override - public List fieldValueGetters(Class clazz, Schema schema) { + public List fieldValueGetters(TypeDescriptor typeDescriptor, Schema schema) { ConverterFactory fromAws = ConverterFactory.fromAws(); - Map> sdkFields = sdkFieldsByName((Class) clazz); + Map> sdkFields = + sdkFieldsByName((Class) typeDescriptor.getRawType()); List getters = new ArrayList<>(schema.getFieldCount()); for (String field : schema.getFieldNames()) { SdkField sdkField = checkStateNotNull(sdkFields.get(field), "Unknown field"); @@ -91,7 +93,7 @@ public List fieldValueGetters(Class clazz, Schema schema) { @Override public SerializableFunction fromRowFunction(TypeDescriptor type) { checkState(SdkPojo.class.isAssignableFrom(type.getRawType()), "Unsupported type %s", type); - return FromRowFactory.create(type.getRawType()); + return FromRowFactory.create(type); } private static class FromRowWithBuilder @@ -114,7 +116,7 @@ public T apply(Row row) { } } SdkBuilder builder = sdkBuilder(cls); - List setters = factory.create(cls, row.getSchema()); + List setters = factory.create(TypeDescriptor.of(cls), row.getSchema()); for (SdkBuilderSetter set : setters) { if (!row.getSchema().hasField(set.name())) { continue; @@ -150,14 +152,19 @@ private static class FromRowFactory implements Factory(new SettersFactory()); @SuppressWarnings("nullness") // schema nullable for this factory - static SerializableFunction create(Class clazz) { - checkState(SdkPojo.class.isAssignableFrom(clazz), "Unsupported clazz %s", clazz); - return (SerializableFunction) new FromRowFactory().cachingFactory.create(clazz, null); + static SerializableFunction create(TypeDescriptor typeDescriptor) { + checkState( + SdkPojo.class.isAssignableFrom(typeDescriptor.getRawType()), + "Unsupported type descriptor %s", + typeDescriptor); + return (SerializableFunction) + new FromRowFactory().cachingFactory.create(typeDescriptor, null); } @Override - public SerializableFunction create(Class clazz, Schema ignored) { - return new FromRowWithBuilder<>((Class) clazz, settersFactory); + public SerializableFunction create(TypeDescriptor typeDescriptor, Schema ignored) { + return new FromRowWithBuilder<>( + (Class) typeDescriptor.getRawType(), settersFactory); } private class SettersFactory implements Factory> { @@ -168,8 +175,9 @@ private SettersFactory() { } @Override - public List create(Class clazz, Schema schema) { - Map> fields = sdkFieldsByName((Class) clazz); + public List create(TypeDescriptor typeDescriptor, Schema schema) { + Map> fields = + sdkFieldsByName((Class) typeDescriptor.getRawType()); checkForUnknownFields(schema, fields); List setters = new ArrayList<>(schema.getFieldCount()); @@ -192,12 +200,17 @@ private void checkForUnknownFields(Schema schema, Map> field } @Override - public List fieldValueTypeInformations(Class cls, Schema schema) { - throw new UnsupportedOperationException("FieldValueTypeInformation not available"); + public List fieldValueTypeInformations( + TypeDescriptor typeDescriptor, Schema schema) { + List> sdkFieldList = sdkFields((Class) typeDescriptor.getRawType()); + + return sdkFieldList.stream() + .map(AwsTypes::fieldValueTypeInformationFor) + .collect(Collectors.toList()); } @Override - public SchemaUserTypeCreator schemaTypeCreator(Class cls, Schema schema) { + public SchemaUserTypeCreator schemaTypeCreator(TypeDescriptor typeDescriptor, Schema schema) { throw new UnsupportedOperationException("SchemaUserTypeCreator not available"); } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java index f5647c040526..f5b06d3cd1c9 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java @@ -27,17 +27,20 @@ import static software.amazon.awssdk.core.protocol.MarshallingType.SDK_POJO; import java.io.Serializable; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Ascii; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -90,6 +93,20 @@ private static FieldType fieldType(SdkField field, Set> seen) { String.format("Type %s of field %s is unknown.", type, normalizedNameOf(field))); } + static FieldValueTypeInformation fieldValueTypeInformationFor(SdkField sdkField) { + TypeDescriptor type = TypeDescriptor.of(sdkField.marshallingType().getTargetClass()); + return FieldValueTypeInformation.builder() + .setName(normalizedNameOf(sdkField)) + .setType(type) + .setRawType(sdkField.marshallingType().getClass()) + .setElementType(FieldValueTypeInformation.getIterableComponentType(type)) + .setMapKeyType(FieldValueTypeInformation.getMapKeyType(type)) + .setMapValueType(FieldValueTypeInformation.getMapValueType(type)) + .setOneOfTypes(Collections.emptyMap()) + .setNullable(true) + .build(); + } + private static Schema schemaFor(List> fields, Set> seen) { Schema.Builder builder = Schema.builder(); for (SdkField sdkField : fields) { @@ -210,7 +227,8 @@ private static class ToAws extends ConverterFactory { @Override @SuppressWarnings("nullness") // schema nullable for this factory protected SerializableFunction pojoTypeConverter(SdkField field) { - return fromRowFactory.create(targetClassOf(field.constructor().get()), null); + return fromRowFactory.create( + TypeDescriptor.of(targetClassOf(field.constructor().get())), null); } } diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 5ee78590f679..a72daf1052bb 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -203,17 +203,17 @@ private Schema.Field beamField(FieldMetaData fieldDescriptor) { @SuppressWarnings("rawtypes") @Override public @NonNull List fieldValueGetters( - @NonNull Class targetClass, @NonNull Schema schema) { - return schemaFieldDescriptors(targetClass, schema).keySet().stream() + @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { + return schemaFieldDescriptors(targetTypeDescriptor.getRawType(), schema).keySet().stream() .map(FieldExtractor::new) .collect(Collectors.toList()); } @Override public @NonNull List fieldValueTypeInformations( - @NonNull Class targetClass, @NonNull Schema schema) { - return schemaFieldDescriptors(targetClass, schema).values().stream() - .map(descriptor -> fieldValueTypeInfo(targetClass, descriptor.fieldName)) + @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { + return schemaFieldDescriptors(targetTypeDescriptor.getRawType(), schema).values().stream() + .map(descriptor -> fieldValueTypeInfo(targetTypeDescriptor, descriptor.fieldName)) .collect(Collectors.toList()); } @@ -223,27 +223,29 @@ Map thriftFieldDescriptors(Class targetClass) { return (Map) FieldMetaData.getStructMetaDataMap((Class) targetClass); } - private FieldValueTypeInformation fieldValueTypeInfo(Class type, String fieldName) { - if (TUnion.class.isAssignableFrom(type)) { + private FieldValueTypeInformation fieldValueTypeInfo(TypeDescriptor type, String fieldName) { + if (TUnion.class.isAssignableFrom(type.getRawType())) { final List factoryMethods = - Stream.of(type.getDeclaredMethods()) + Stream.of(type.getRawType().getDeclaredMethods()) .filter(m -> m.getName().equals(fieldName)) .filter(m -> m.getModifiers() == (Modifier.PUBLIC | Modifier.STATIC)) .filter(m -> m.getParameterCount() == 1) - .filter(m -> m.getReturnType() == type) + .filter(m -> m.getReturnType() == type.getRawType()) .collect(Collectors.toList()); if (factoryMethods.isEmpty()) { throw new IllegalArgumentException( String.format( - "No suitable static factory method: %s.%s(...)", type.getName(), fieldName)); + "No suitable static factory method: %s.%s(...)", + type.getRawType().getName(), fieldName)); } if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return FieldValueTypeInformation.forSetter(factoryMethods.get(0), ""); + return FieldValueTypeInformation.forSetter(type, factoryMethods.get(0), ""); } else { try { - return FieldValueTypeInformation.forField(type.getDeclaredField(fieldName), 0); + return FieldValueTypeInformation.forField( + type, type.getRawType().getDeclaredField(fieldName), 0); } catch (NoSuchFieldException e) { throw new IllegalArgumentException(e); } @@ -252,10 +254,11 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field @Override public @NonNull SchemaUserTypeCreator schemaTypeCreator( - @NonNull Class targetClass, @NonNull Schema schema) { + @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { final Map fieldDescriptors = - schemaFieldDescriptors(targetClass, schema); - return params -> restoreThriftObject(targetClass, fieldDescriptors, params); + schemaFieldDescriptors(targetTypeDescriptor.getRawType(), schema); + return params -> + restoreThriftObject(targetTypeDescriptor.getRawType(), fieldDescriptors, params); } @SuppressWarnings("nullness")