From 1366db3f38cbc63321cf0954b3544dd0a81c2c0a Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Tue, 8 Oct 2024 15:35:04 -0700 Subject: [PATCH] Enable schema inference in the presence of generic type parameters. --- .../beam/sdk/schemas/AutoValueSchema.java | 8 +- .../schemas/FieldValueTypeInformation.java | 87 +++++----- .../beam/sdk/schemas/JavaBeanSchema.java | 12 +- .../beam/sdk/schemas/JavaFieldSchema.java | 10 +- .../beam/sdk/schemas/SchemaProvider.java | 3 +- .../beam/sdk/schemas/SchemaRegistry.java | 39 ++--- .../transforms/providers/JavaRowUdf.java | 3 +- .../sdk/schemas/utils/AutoValueUtils.java | 19 ++- .../sdk/schemas/utils/ByteBuddyUtils.java | 53 ++++--- .../sdk/schemas/utils/ConvertHelpers.java | 4 +- .../beam/sdk/schemas/utils/JavaBeanUtils.java | 8 +- .../beam/sdk/schemas/utils/POJOUtils.java | 18 ++- .../beam/sdk/schemas/utils/ReflectUtils.java | 79 ++++++++-- .../schemas/utils/StaticSchemaInference.java | 89 +++++------ .../beam/sdk/schemas/AutoValueSchemaTest.java | 149 ++++++++++++++++++ .../beam/sdk/schemas/JavaBeanSchemaTest.java | 38 +++++ .../beam/sdk/schemas/JavaFieldSchemaTest.java | 120 ++++++++++++++ .../sdk/schemas/utils/JavaBeanUtilsTest.java | 33 +++- .../beam/sdk/schemas/utils/POJOUtilsTest.java | 36 +++-- .../beam/sdk/schemas/utils/TestJavaBeans.java | 56 +++++++ .../beam/sdk/schemas/utils/TestPOJOs.java | 121 +++++++++++++- .../schemas/utils/AvroByteBuddyUtils.java | 6 +- .../avro/schemas/utils/AvroUtils.java | 10 +- .../protobuf/ProtoByteBuddyUtils.java | 4 +- .../protobuf/ProtoMessageSchema.java | 8 +- .../python/PythonExternalTransform.java | 4 +- .../beam/sdk/io/thrift/ThriftSchema.java | 5 +- 27 files changed, 825 insertions(+), 197 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index 5ccfe39b92af..c369eefeb65c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -19,8 +19,10 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.AutoValueUtils; @@ -61,8 +63,9 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -143,7 +146,8 @@ public SchemaUserTypeCreator schemaTypeCreator( @Override public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor, AbstractGetterTypeSupplier.INSTANCE); + typeDescriptor, AbstractGetterTypeSupplier.INSTANCE, boundTypes); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 750709192c08..ee68f7087b9e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -24,6 +24,7 @@ import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Method; +import java.lang.reflect.Type; import java.util.Arrays; import java.util.Collections; import java.util.Map; @@ -125,8 +126,10 @@ public static FieldValueTypeInformation forOneOf( .build(); } - public static FieldValueTypeInformation forField(Field field, int index) { - TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); + public static FieldValueTypeInformation forField( + Field field, int index, Map boundTypes) { + TypeDescriptor type = + TypeDescriptor.of(ReflectUtils.resolveType(field.getGenericType(), boundTypes)); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -134,9 +137,9 @@ public static FieldValueTypeInformation forField(Field field, int index) { .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field)) - .setMapKeyType(getMapKeyType(field)) - .setMapValueType(getMapValueType(field)) + .setElementType(getIterableComponentType(field, boundTypes)) + .setMapKeyType(getMapKeyType(field, boundTypes)) + .setMapValueType(getMapValueType(field, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -184,7 +187,8 @@ public static String getNameOverride( return fieldDescription.value(); } - public static FieldValueTypeInformation forGetter(Method method, int index) { + public static FieldValueTypeInformation forGetter( + Method method, int index, Map boundTypes) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -194,7 +198,8 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { throw new RuntimeException("Getter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericReturnType()); + TypeDescriptor type = + TypeDescriptor.of(ReflectUtils.resolveType(method.getGenericReturnType(), boundTypes)); boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -203,9 +208,9 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type)) - .setMapKeyType(getMapKeyType(type)) - .setMapValueType(getMapValueType(type)) + .setElementType(getIterableComponentType(type, boundTypes)) + .setMapKeyType(getMapKeyType(type, boundTypes)) + .setMapValueType(getMapValueType(type, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(method)) .build(); @@ -252,11 +257,13 @@ private static boolean isNullableAnnotation(Annotation annotation) { return annotation.annotationType().getSimpleName().equals("Nullable"); } - public static FieldValueTypeInformation forSetter(Method method) { - return forSetter(method, "set"); + public static FieldValueTypeInformation forSetter( + Method method, Map boundParameters) { + return forSetter(method, "set", boundParameters); } - public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { + public static FieldValueTypeInformation forSetter( + Method method, String setterPrefix, Map boundTypes) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -264,7 +271,9 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr throw new RuntimeException("Setter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericParameterTypes()[0]); + TypeDescriptor type = + TypeDescriptor.of( + ReflectUtils.resolveType(method.getGenericParameterTypes()[0], boundTypes)); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -272,9 +281,9 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type)) - .setMapKeyType(getMapKeyType(type)) - .setMapValueType(getMapValueType(type)) + .setElementType(getIterableComponentType(type, boundTypes)) + .setMapKeyType(getMapKeyType(type, boundTypes)) + .setMapValueType(getMapValueType(type, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -283,13 +292,15 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation getIterableComponentType(Field field) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType())); + private static FieldValueTypeInformation getIterableComponentType( + Field field, Map boundTypes) { + return getIterableComponentType(TypeDescriptor.of(field.getGenericType()), boundTypes); } - static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { + static @Nullable FieldValueTypeInformation getIterableComponentType( + TypeDescriptor valueType, Map boundTypes) { // TODO: Figure out nullable elements. - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType, boundTypes); if (componentType == null) { return null; } @@ -299,41 +310,43 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { .setNullable(false) .setType(componentType) .setRawType(componentType.getRawType()) - .setElementType(getIterableComponentType(componentType)) - .setMapKeyType(getMapKeyType(componentType)) - .setMapValueType(getMapValueType(componentType)) + .setElementType(getIterableComponentType(componentType, boundTypes)) + .setMapKeyType(getMapKeyType(componentType, boundTypes)) + .setMapValueType(getMapValueType(componentType, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .build(); } // If the Field is a map type, returns the key type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType())); + private static @Nullable FieldValueTypeInformation getMapKeyType( + Field field, Map boundTypes) { + return getMapKeyType(TypeDescriptor.of(field.getGenericType()), boundTypes); } private static @Nullable FieldValueTypeInformation getMapKeyType( - TypeDescriptor typeDescriptor) { - return getMapType(typeDescriptor, 0); + TypeDescriptor typeDescriptor, Map boundTypes) { + return getMapType(typeDescriptor, 0, boundTypes); } // If the Field is a map type, returns the value type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapValueType(Field field) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1); + private static @Nullable FieldValueTypeInformation getMapValueType( + Field field, Map boundTypes) { + return getMapType(TypeDescriptor.of(field.getGenericType()), 1, boundTypes); } private static @Nullable FieldValueTypeInformation getMapValueType( - TypeDescriptor typeDescriptor) { - return getMapType(typeDescriptor, 1); + TypeDescriptor typeDescriptor, Map boundTypes) { + return getMapType(typeDescriptor, 1, boundTypes); } // If the Field is a map type, returns the key or value type (0 is key type, 1 is value). // Otherwise returns a null reference. @SuppressWarnings("unchecked") private static @Nullable FieldValueTypeInformation getMapType( - TypeDescriptor valueType, int index) { - TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); + TypeDescriptor valueType, int index, Map boundTypes) { + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index, boundTypes); if (mapType == null) { return null; } @@ -342,9 +355,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { .setNullable(false) .setType(mapType) .setRawType(mapType.getRawType()) - .setElementType(getIterableComponentType(mapType)) - .setMapKeyType(getMapKeyType(mapType)) - .setMapValueType(getMapValueType(mapType)) + .setElementType(getIterableComponentType(mapType, boundTypes)) + .setMapKeyType(getMapKeyType(mapType, boundTypes)) + .setMapValueType(getMapValueType(mapType, boundTypes)) .setOneOfTypes(Collections.emptyMap()) .build(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index a9cf01c52057..ad71576670bf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -19,8 +19,10 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -67,8 +69,9 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -111,10 +114,11 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(FieldValueTypeInformation::forSetter) + .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) .map( t -> { if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { @@ -156,8 +160,10 @@ public boolean equals(@Nullable Object obj) { @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); Schema schema = - JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE); + JavaBeanUtils.schemaFromJavaBeanClass( + typeDescriptor, GetterTypeSupplier.INSTANCE, boundTypes); // If there are no creator methods, then validate that we have setters for every field. // Otherwise, we will have no way of creating instances of the class. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index 21f07c47b47f..da0f59c8ee96 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -21,8 +21,10 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -62,9 +64,11 @@ public List get(TypeDescriptor typeDescriptor) { ReflectUtils.getFields(typeDescriptor.getRawType()).stream() .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); + List types = Lists.newArrayListWithCapacity(fields.size()); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i)); + types.add(FieldValueTypeInformation.forField(fields.get(i), i, boundTypes)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -111,7 +115,9 @@ private static void validateFieldNumbers(List types) @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE); + Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); + return POJOUtils.schemaFromPojoClass( + typeDescriptor, JavaFieldTypeSupplier.INSTANCE, boundTypes); } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java index 37b4952e529c..b7e3cdf60c18 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java @@ -38,8 +38,7 @@ public interface SchemaProvider extends Serializable { * Given a type, return a function that converts that type to a {@link Row} object If no schema * exists, returns null. */ - @Nullable - SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); + @Nullable SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); /** * Given a type, returns a function that converts from a {@link Row} object to that type. If no diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java index 679a1fcf54fc..5d8b7aab6193 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java @@ -76,13 +76,12 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid providers.put(typeDescriptor, schemaProvider); } - @Override - public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { + private @Nullable SchemaProvider schemaProviderFor(TypeDescriptor typeDescriptor) { TypeDescriptor type = typeDescriptor; do { SchemaProvider schemaProvider = providers.get(type); if (schemaProvider != null) { - return schemaProvider.schemaFor(type); + return schemaProvider; } Class superClass = type.getRawType().getSuperclass(); if (superClass == null || superClass.equals(Object.class)) { @@ -92,38 +91,24 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid } while (true); } + @Override + public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { + @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); + return schemaProvider != null ? schemaProvider.schemaFor(typeDescriptor) : null; + } + @Override public @Nullable SerializableFunction toRowFunction( TypeDescriptor typeDescriptor) { - TypeDescriptor type = typeDescriptor; - do { - SchemaProvider schemaProvider = providers.get(type); - if (schemaProvider != null) { - return (SerializableFunction) schemaProvider.toRowFunction(type); - } - Class superClass = type.getRawType().getSuperclass(); - if (superClass == null || superClass.equals(Object.class)) { - return null; - } - type = TypeDescriptor.of(superClass); - } while (true); + @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); + return schemaProvider != null ? schemaProvider.toRowFunction(typeDescriptor) : null; } @Override public @Nullable SerializableFunction fromRowFunction( TypeDescriptor typeDescriptor) { - TypeDescriptor type = typeDescriptor; - do { - SchemaProvider schemaProvider = providers.get(type); - if (schemaProvider != null) { - return (SerializableFunction) schemaProvider.fromRowFunction(type); - } - Class superClass = type.getRawType().getSuperclass(); - if (superClass == null || superClass.equals(Object.class)) { - return null; - } - type = TypeDescriptor.of(superClass); - } while (true); + @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); + return schemaProvider != null ? schemaProvider.fromRowFunction(typeDescriptor) : null; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java index 54e2a595fa71..c3a71bbb454b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java @@ -160,7 +160,8 @@ public FunctionAndType(Type outputType, Function function) { public FunctionAndType(TypeDescriptor outputType, Function function) { this( - StaticSchemaInference.fieldFromType(outputType, new EmptyFieldValueTypeSupplier()), + StaticSchemaInference.fieldFromType( + outputType, new EmptyFieldValueTypeSupplier(), Collections.emptyMap()), function); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index d7fddd8abfed..80a894f20cab 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -63,6 +63,7 @@ import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for managing AutoValue schemas. */ @@ -161,7 +162,8 @@ private static boolean matchConstructor( // Verify that constructor parameters match (name and type) the inferred schema. for (Parameter parameter : constructor.getParameters()) { FieldValueTypeInformation type = typeMap.get(parameter.getName()); - if (type == null || type.getRawType() != parameter.getType()) { + ; + if (type == null || !type.getRawType().equals(parameter.getType())) { valid = false; break; } @@ -178,7 +180,7 @@ private static boolean matchConstructor( } name = name.substring(0, name.length() - 1); FieldValueTypeInformation type = typeMap.get(name); - if (type == null || type.getRawType() != parameter.getType()) { + if (type == null || !type.getRawType().equals(parameter.getType())) { return false; } } @@ -196,11 +198,12 @@ private static boolean matchConstructor( return null; } - Map setterTypes = - ReflectUtils.getMethods(builderClass).stream() - .filter(ReflectUtils::isSetter) - .map(FieldValueTypeInformation::forSetter) - .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); + Map boundTypes = ReflectUtils.getAllBoundTypes(TypeDescriptor.of(builderClass)); + Map setterTypes = Maps.newHashMap(); + ReflectUtils.getMethods(builderClass).stream() + .filter(ReflectUtils::isSetter) + .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) + .forEach(fv -> setterTypes.putIfAbsent(fv.getName(), fv)); List setterMethods = Lists.newArrayList(); // The builder methods to call in order. @@ -321,7 +324,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Duplication.SINGLE, typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(parameter.getType())), + .convert(TypeDescriptor.of(parameter.getParameterizedType())), MethodInvocation.invoke(new ForLoadedMethod(setterMethod)), Removal.SINGLE); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index c2b33c2d2315..65adc33a1bab 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -344,19 +344,22 @@ protected Type convertArray(TypeDescriptor type) { @Override protected Type convertCollection(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertList(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createIterableType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @@ -687,7 +690,8 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); @@ -707,7 +711,8 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { @@ -726,7 +731,8 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { @@ -745,8 +751,8 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); - final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0, Collections.emptyMap()); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1, Collections.emptyMap()); Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType); Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType); @@ -1038,8 +1044,9 @@ protected StackManipulation convertIterable(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type); + .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + final TypeDescriptor iterableElementType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); if (!iterableElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = new ForLoadedType( @@ -1060,8 +1067,9 @@ protected StackManipulation convertCollection(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + final TypeDescriptor collectionElementType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); if (!collectionElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -1083,8 +1091,9 @@ protected StackManipulation convertList(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + final TypeDescriptor collectionElementType = + ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); StackManipulation readTrasformedValue = readValue; if (!collectionElementType.hasUnresolvedParameters()) { @@ -1113,11 +1122,17 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { Type rowKeyType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0)); - final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0); + getFactory() + .createTypeConversion(false) + .convert(ReflectUtils.getMapType(type, 0, Collections.emptyMap())); + final TypeDescriptor keyElementType = + ReflectUtils.getMapType(type, 0, Collections.emptyMap()); Type rowValueType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1)); - final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1); + getFactory() + .createTypeConversion(false) + .convert(ReflectUtils.getMapType(type, 1, Collections.emptyMap())); + final TypeDescriptor valueElementType = + ReflectUtils.getMapType(type, 1, Collections.emptyMap()); StackManipulation readTrasformedValue = readValue; if (!keyElementType.hasUnresolvedParameters() @@ -1475,7 +1490,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Parameter parameter = parameters.get(i); ForLoadedType convertedType = new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(parameter.getType()))); + (Class) convertType.convert(TypeDescriptor.of(parameter.getParameterizedType()))); // The instruction to read the parameter. Use the fieldMapping to reorder parameters as // necessary. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java index 7f2403035d97..2132e6024ea3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Type; +import java.util.Collections; import java.util.ServiceLoader; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -148,7 +149,8 @@ public static SerializableFunction getConvertPrimitive( TypeDescriptor outputTypeDescriptor, TypeConversionsFactory typeConversionsFactory) { FieldType expectedFieldType = - StaticSchemaInference.fieldFromType(outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE); + StaticSchemaInference.fieldFromType( + outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE, Collections.emptyMap()); if (!expectedFieldType.equals(fieldType)) { throw new IllegalArgumentException( "Element argument type " diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 911f79f6eeed..f2ba796d30de 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -22,6 +22,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.lang.reflect.Type; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -64,8 +65,11 @@ public class JavaBeanUtils { /** Create a {@link Schema} for a Java Bean class. */ public static Schema schemaFromJavaBeanClass( - TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { - return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); + TypeDescriptor typeDescriptor, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return StaticSchemaInference.schemaFromClass( + typeDescriptor, fieldValueTypeSupplier, boundTypes); } private static final String CONSTRUCTOR_HELP_STRING = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 571b9c690900..a36e91de3fc4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -73,8 +73,11 @@ public class POJOUtils { public static Schema schemaFromPojoClass( - TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { - return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); + TypeDescriptor typeDescriptor, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return StaticSchemaInference.schemaFromClass( + typeDescriptor, fieldValueTypeSupplier, boundTypes); } // Static ByteBuddy instance used by all helpers. @@ -301,7 +304,7 @@ public static SchemaUserTypeCreator createStaticCreator( field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getType()))); + .convert(TypeDescriptor.of(field.getGenericType()))); builder = implementGetterMethods(builder, field, typeInformation.getName(), typeConversionsFactory); try { @@ -383,7 +386,7 @@ private static FieldValueSetter createSetter( field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getType()))); + .convert(TypeDescriptor.of(field.getGenericType()))); builder = implementSetterMethods(builder, field, typeConversionsFactory); try { return builder @@ -491,7 +494,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readField) - .convert(TypeDescriptor.of(field.getType())), + .convert(TypeDescriptor.of(field.getGenericType())), // Now update the field and return void. FieldAccess.forField(new ForLoadedField(field)).write(), MethodReturn.VOID); @@ -546,7 +549,8 @@ public ByteCodeAppender appender(final Target implementationTarget) { Field field = fields.get(i); ForLoadedType convertedType = - new ForLoadedType((Class) convertType.convert(TypeDescriptor.of(field.getType()))); + new ForLoadedType( + (Class) convertType.convert(TypeDescriptor.of(field.getGenericType()))); // The instruction to read the parameter. StackManipulation readParameter = @@ -563,7 +567,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(field.getType())), + .convert(TypeDescriptor.of(field.getGenericType())), // Now update the field. FieldAccess.forField(new ForLoadedField(field)).write()); stackManipulation = new StackManipulation.Compound(stackManipulation, updateField); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index 4349a04c28ad..04eee4b19c09 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -26,16 +26,17 @@ import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; import java.security.InvalidParameterException; import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; @@ -88,14 +89,23 @@ public static List getMethods(Class clazz) { return DECLARED_METHODS.computeIfAbsent( clazz, c -> { - return Arrays.stream(c.getDeclaredMethods()) - .filter( - m -> !m.isBridge()) // Covariant overloads insert bridge functions, which we must - // ignore. - .filter(m -> !Modifier.isPrivate(m.getModifiers())) - .filter(m -> !Modifier.isProtected(m.getModifiers())) - .filter(m -> !Modifier.isStatic(m.getModifiers())) - .collect(Collectors.toList()); + List methods = Lists.newArrayList(); + do { + if (c.getPackage() != null && c.getPackage().getName().startsWith("java.")) { + break; // skip java built-in classes + } + Arrays.stream(c.getDeclaredMethods()) + .filter( + m -> + !m.isBridge()) // Covariant overloads insert bridge functions, which we must + // ignore. + .filter(m -> !Modifier.isPrivate(m.getModifiers())) + .filter(m -> !Modifier.isProtected(m.getModifiers())) + .filter(m -> !Modifier.isStatic(m.getModifiers())) + .forEach(methods::add); + c = c.getSuperclass(); + } while (c != null); + return methods; }); } @@ -201,7 +211,8 @@ public static String stripSetterPrefix(String method) { } /** For an array T[] or a subclass of Iterable, return a TypeDescriptor describing T. */ - public static @Nullable TypeDescriptor getIterableComponentType(TypeDescriptor valueType) { + public static @Nullable TypeDescriptor getIterableComponentType( + TypeDescriptor valueType, Map boundTypes) { TypeDescriptor componentType = null; if (valueType.isArray()) { Type component = valueType.getComponentType().getType(); @@ -215,7 +226,7 @@ public static String stripSetterPrefix(String method) { ParameterizedType ptype = (ParameterizedType) collection.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); checkArgument(params.length == 1); - componentType = TypeDescriptor.of(params[0]); + componentType = TypeDescriptor.of(resolveType(params[0], boundTypes)); } else { throw new RuntimeException("Collection parameter is not parameterized!"); } @@ -223,14 +234,15 @@ public static String stripSetterPrefix(String method) { return componentType; } - public static TypeDescriptor getMapType(TypeDescriptor valueType, int index) { + public static TypeDescriptor getMapType( + TypeDescriptor valueType, int index, Map boundTypes) { TypeDescriptor mapType = null; if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) { TypeDescriptor> map = valueType.getSupertype(Map.class); if (map.getType() instanceof ParameterizedType) { ParameterizedType ptype = (ParameterizedType) map.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - mapType = TypeDescriptor.of(params[index]); + mapType = TypeDescriptor.of(resolveType(params[index], boundTypes)); } else { throw new RuntimeException("Map type is not parameterized! " + map); } @@ -243,4 +255,45 @@ public static TypeDescriptor boxIfPrimitive(TypeDescriptor typeDescriptor) { ? TypeDescriptor.of(Primitives.wrap(typeDescriptor.getRawType())) : typeDescriptor; } + + public static Map getAllBoundTypes(TypeDescriptor typeDescriptor) { + Map boundParameters = Maps.newHashMap(); + TypeDescriptor currentType = typeDescriptor; + do { + if (currentType.getType() instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) currentType.getType(); + TypeVariable[] typeVariables = currentType.getRawType().getTypeParameters(); + Type[] typeArguments = parameterizedType.getActualTypeArguments(); + ; + if (typeArguments.length != typeVariables.length) { + throw new RuntimeException("Unmatching arguments lengths"); + } + for (int i = 0; i < typeVariables.length; ++i) { + boundParameters.put(typeVariables[i], typeArguments[i]); + } + } + Type superClass = currentType.getRawType().getGenericSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + break; + } + currentType = TypeDescriptor.of(superClass); + } while (true); + return boundParameters; + } + + public static Type resolveType(Type type, Map boundTypes) { + TypeDescriptor typeDescriptor = TypeDescriptor.of(type); + if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Iterable.class)) + || typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { + // Don't resolve these as we special case map and interable. + return type; + } + + if (type instanceof TypeVariable) { + TypeVariable typeVariable = (TypeVariable) type; + return Preconditions.checkArgumentNotNull(boundTypes.get(typeVariable)); + } else { + return type; + } + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java index 196ee6f86593..33934144322b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java @@ -19,7 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.Arrays; @@ -33,6 +33,7 @@ import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.ReadableInstant; @@ -85,14 +86,17 @@ enum MethodType { * public getter methods, or special annotations on the class. */ public static Schema schemaFromClass( - TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { - return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>()); + TypeDescriptor typeDescriptor, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>(), boundTypes); } private static Schema schemaFromClass( TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas) { + Map, Schema> alreadyVisitedSchemas, + Map boundTypes) { if (alreadyVisitedSchemas.containsKey(typeDescriptor)) { Schema existingSchema = alreadyVisitedSchemas.get(typeDescriptor); if (existingSchema == null) { @@ -106,7 +110,7 @@ private static Schema schemaFromClass( Schema.Builder builder = Schema.builder(); for (FieldValueTypeInformation type : fieldValueTypeSupplier.get(typeDescriptor)) { Schema.FieldType fieldType = - fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas); + fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes); Schema.Field f = type.isNullable() ? Schema.Field.nullable(type.getName(), fieldType) @@ -123,15 +127,18 @@ private static Schema schemaFromClass( /** Map a Java field type to a Beam Schema FieldType. */ public static Schema.FieldType fieldFromType( - TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) { - return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>()); + TypeDescriptor type, + FieldValueTypeSupplier fieldValueTypeSupplier, + Map boundTypes) { + return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>(), boundTypes); } // TODO(https://github.com/apache/beam/issues/21567): support type inference for logical types private static Schema.FieldType fieldFromType( TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas) { + Map, Schema> alreadyVisitedSchemas, + Map boundTypes) { FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType()); if (primitiveType != null) { return primitiveType; @@ -152,27 +159,25 @@ private static Schema.FieldType fieldFromType( } else { // Otherwise this is an array type. return FieldType.array( - fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas)); + fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); } } else if (type.isSubtypeOf(TypeDescriptor.of(Map.class))) { - TypeDescriptor> map = type.getSupertype(Map.class); - if (map.getType() instanceof ParameterizedType) { - ParameterizedType ptype = (ParameterizedType) map.getType(); - java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - checkArgument(params.length == 2); - FieldType keyType = - fieldFromType( - TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas); - FieldType valueType = - fieldFromType( - TypeDescriptor.of(params[1]), fieldValueTypeSupplier, alreadyVisitedSchemas); - checkArgument( - keyType.getTypeName().isPrimitiveType(), - "Only primitive types can be map keys. type: " + keyType.getTypeName()); - return FieldType.map(keyType, valueType); - } else { - throw new RuntimeException("Cannot infer schema from unparameterized map."); - } + FieldType keyType = + fieldFromType( + ReflectUtils.getMapType(type, 0, boundTypes), + fieldValueTypeSupplier, + alreadyVisitedSchemas, + boundTypes); + FieldType valueType = + fieldFromType( + ReflectUtils.getMapType(type, 1, boundTypes), + fieldValueTypeSupplier, + alreadyVisitedSchemas, + boundTypes); + checkArgument( + keyType.getTypeName().isPrimitiveType(), + "Only primitive types can be map keys. type: " + keyType.getTypeName()); + return FieldType.map(keyType, valueType); } else if (type.isSubtypeOf(TypeDescriptor.of(CharSequence.class))) { return FieldType.STRING; } else if (type.isSubtypeOf(TypeDescriptor.of(ReadableInstant.class))) { @@ -180,26 +185,22 @@ private static Schema.FieldType fieldFromType( } else if (type.isSubtypeOf(TypeDescriptor.of(ByteBuffer.class))) { return FieldType.BYTES; } else if (type.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { - TypeDescriptor> iterable = type.getSupertype(Iterable.class); - if (iterable.getType() instanceof ParameterizedType) { - ParameterizedType ptype = (ParameterizedType) iterable.getType(); - java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - checkArgument(params.length == 1); - // TODO: should this be AbstractCollection? - if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { - return FieldType.array( - fieldFromType( - TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); - } else { - return FieldType.iterable( - fieldFromType( - TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); - } + FieldType elementType = + fieldFromType( + Preconditions.checkArgumentNotNull( + ReflectUtils.getIterableComponentType(type, boundTypes)), + fieldValueTypeSupplier, + alreadyVisitedSchemas, + boundTypes); + // TODO: should this be AbstractCollection? + if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { + return FieldType.array(elementType); } else { - throw new RuntimeException("Cannot infer schema from unparameterized collection."); + return FieldType.iterable(elementType); } } else { - return FieldType.row(schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas)); + return FieldType.row( + schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java index d0ee623dea7c..49fd2bfe2259 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java @@ -28,6 +28,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Map; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; @@ -39,6 +40,7 @@ import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; import org.joda.time.DateTime; import org.joda.time.Instant; @@ -886,4 +888,151 @@ public void testSchema_SchemaFieldDescription() throws NoSuchSchemaException { assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("lng"), schema.getField("lng")); assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("str"), schema.getField("str")); } + + @AutoValue + @DefaultSchema(AutoValueSchema.class) + abstract static class ParameterizedAutoValue { + abstract W getValue1(); + + abstract T getValue2(); + + abstract V getValue3(); + + abstract X getValue4(); + } + + @Test + public void testAutoValueWithTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @DefaultSchema(AutoValueSchema.class) + abstract static class ParameterizedAutoValueSubclass + extends ParameterizedAutoValue { + abstract T getValue5(); + } + + @Test + public void testAutoValueWithInheritedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .addInt16Field("value5") + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @AutoValue + @DefaultSchema(AutoValueSchema.class) + abstract static class NestedParameterizedCollectionAutoValue { + abstract Iterable getNested(); + + abstract Map getMap(); + } + + @Test + public void testAutoValueWithNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + NestedParameterizedCollectionAutoValue< + ParameterizedAutoValue, String>> + typeDescriptor = + new TypeDescriptor< + NestedParameterizedCollectionAutoValue< + ParameterizedAutoValue, String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.row(expectedInnerSchema)) + .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testAutoValueWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + NestedParameterizedCollectionAutoValue< + Iterable>, String>> + typeDescriptor = + new TypeDescriptor< + NestedParameterizedCollectionAutoValue< + Iterable>, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) + .addMapField( + "map", FieldType.STRING, FieldType.iterable(FieldType.row(expectedInnerSchema))) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @AutoValue + @DefaultSchema(AutoValueSchema.class) + abstract static class NestedParameterizedAutoValue { + abstract T getNested(); + } + + @Test + public void testAutoValueWithNestedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + NestedParameterizedAutoValue< + ParameterizedAutoValue>> + typeDescriptor = + new TypeDescriptor< + NestedParameterizedAutoValue< + ParameterizedAutoValue>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder().addRowField("nested", expectedInnerSchema).build(); + assertTrue(expectedSchema.equivalent(schema)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 5313feb5c6c0..4cfc64f2f722 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -68,6 +68,7 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBeanWithAnnotations; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -625,4 +626,41 @@ public void testSetterConstructionWithRenamedFields() throws NoSuchSchemaExcepti assertEquals( registry.getFromRowFunction(BeanWithCaseFormat.class).apply(row), beanWithCaseFormat); } + + @Test + public void testRegisterBeamWithTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> + typeDescriptor = + new TypeDescriptor< + TestJavaBeans.SimpleParameterizedBean>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_BEAN_SCHEMA) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testRegisterBeanWithInheritedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_BEAN_SCHEMA) + .addInt16Field("value5") + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 11bef79b26f7..70bc3030924b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -76,6 +76,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -781,4 +782,123 @@ public void testCircularNestedPOJOThrows() throws NoSuchSchemaException { thrown.getMessage(), containsString("TestPOJOs$FirstCircularNestedPOJO")); } + + @Test + public void testPojoWithTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.SimpleParameterizedPOJO>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithInheritedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor> typeDescriptor = + new TypeDescriptor>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .addInt16Field("value5") + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + TestPOJOs.SimpleParameterizedPOJO, String>> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + TestPOJOs.SimpleParameterizedPOJO, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.row(expectedInnerSchema)) + .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + Iterable>, + String>> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.NestedParameterizedCollectionPOJO< + Iterable>, + String>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder() + .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) + .addMapField( + "map", FieldType.STRING, FieldType.iterable(FieldType.row(expectedInnerSchema))) + .build(); + assertTrue(expectedSchema.equivalent(schema)); + } + + @Test + public void testPojoWithNestedTypeParameter() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + TypeDescriptor< + TestPOJOs.NestedParameterizedPOJO< + TestPOJOs.SimpleParameterizedPOJO>> + typeDescriptor = + new TypeDescriptor< + TestPOJOs.NestedParameterizedPOJO< + TestPOJOs.SimpleParameterizedPOJO>>() {}; + Schema schema = registry.getSchema(typeDescriptor); + + final Schema expectedInnerSchema = + Schema.builder() + .addBooleanField("value1") + .addStringField("value2") + .addInt64Field("value3") + .addRowField("value4", SIMPLE_POJO_SCHEMA) + .build(); + final Schema expectedSchema = + Schema.builder().addRowField("nested", expectedInnerSchema).build(); + assertTrue(expectedSchema.equivalent(schema)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java index 021e39b84849..e0a45c2c82fe 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java @@ -34,6 +34,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -65,7 +66,9 @@ public class JavaBeanUtilsTest { public void testNullable() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -74,7 +77,9 @@ public void testNullable() { public void testSimpleBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(SIMPLE_BEAN_SCHEMA, schema); } @@ -82,7 +87,9 @@ public void testSimpleBean() { public void testNestedBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_BEAN_SCHEMA, schema); } @@ -90,7 +97,9 @@ public void testNestedBean() { public void testPrimitiveArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_BEAN_SCHEMA, schema); } @@ -98,7 +107,9 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_BEAN_SCHEMA, schema); } @@ -106,7 +117,9 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_BEAN_SCHEMA, schema); } @@ -114,7 +127,9 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_BEAN_SCHEMA, schema); } @@ -122,7 +137,9 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + GetterTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_BEAN_SCHEMA, schema); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java index 723353ed8d15..46c098dddaeb 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java @@ -35,6 +35,7 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -71,7 +72,9 @@ public class POJOUtilsTest { public void testNullables() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -80,7 +83,9 @@ public void testNullables() { public void testSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); assertEquals(SIMPLE_POJO_SCHEMA, schema); } @@ -88,7 +93,9 @@ public void testSimplePOJO() { public void testNestedPOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_SCHEMA, schema); } @@ -97,7 +104,8 @@ public void testNestedPOJOWithSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE); + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_WITH_SIMPLE_POJO_SCHEMA, schema); } @@ -105,7 +113,9 @@ public void testNestedPOJOWithSimplePOJO() { public void testPrimitiveArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_POJO_SCHEMA, schema); } @@ -113,7 +123,9 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_POJO_SCHEMA, schema); } @@ -121,7 +133,9 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_POJO_SCHEMA, schema); } @@ -129,7 +143,9 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_POJO_SCHEMA, schema); } @@ -137,7 +153,9 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); + new TypeDescriptor() {}, + JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap()); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_POJO_SCHEMA, schema); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java index b5ad6f989d9e..9d11fce34148 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java @@ -1397,4 +1397,60 @@ public void setValue(@Nullable Float value) { Schema.Field.nullable("value", FieldType.FLOAT) .withDescription("This value is the value stored in the object as a float.")) .build(); + + @DefaultSchema(JavaBeanSchema.class) + public static class SimpleParameterizedBean { + @Nullable private W value1; + @Nullable private T value2; + @Nullable private V value3; + @Nullable private X value4; + + public W getValue1() { + return value1; + } + + public void setValue1(W value1) { + this.value1 = value1; + } + + public T getValue2() { + return value2; + } + + public void setValue2(T value2) { + this.value2 = value2; + } + + public V getValue3() { + return value3; + } + + public void setValue3(V value3) { + this.value3 = value3; + } + + public X getValue4() { + return value4; + } + + public void setValue4(X value4) { + this.value4 = value4; + } + } + + @DefaultSchema(JavaBeanSchema.class) + public static class SimpleParameterizedBeanSubclass + extends SimpleParameterizedBean { + @Nullable private T value5; + + public SimpleParameterizedBeanSubclass() {} + + public T getValue5() { + return value5; + } + + public void setValue5(T value5) { + this.value5 = value5; + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java index 789de02adee8..ce7409365d09 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java @@ -495,6 +495,125 @@ public int hashCode() { .addStringField("stringBuilder") .build(); + @DefaultSchema(JavaFieldSchema.class) + public static class SimpleParameterizedPOJO { + public W value1; + public T value2; + public V value3; + public X value4; + + public SimpleParameterizedPOJO() {} + + public SimpleParameterizedPOJO(W value1, T value2, V value3, X value4) { + this.value1 = value1; + this.value2 = value2; + this.value3 = value3; + this.value4 = value4; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SimpleParameterizedPOJO)) { + return false; + } + SimpleParameterizedPOJO that = (SimpleParameterizedPOJO) o; + return Objects.equals(value1, that.value1) + && Objects.equals(value2, that.value2) + && Objects.equals(value3, that.value3) + && Objects.equals(value4, that.value4); + } + + @Override + public int hashCode() { + return Objects.hash(value1, value2, value3, value4); + } + } + + @DefaultSchema(JavaFieldSchema.class) + public static class SimpleParameterizedPOJOSubclass + extends SimpleParameterizedPOJO { + public T value5; + + public SimpleParameterizedPOJOSubclass() {} + + public SimpleParameterizedPOJOSubclass(T value5) { + this.value5 = value5; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SimpleParameterizedPOJOSubclass)) { + return false; + } + SimpleParameterizedPOJOSubclass that = (SimpleParameterizedPOJOSubclass) o; + return Objects.equals(value5, that.value5); + } + + @Override + public int hashCode() { + return Objects.hash(value4); + } + } + + @DefaultSchema(JavaFieldSchema.class) + public static class NestedParameterizedCollectionPOJO { + public Iterable nested; + public Map map; + + public NestedParameterizedCollectionPOJO(Iterable nested, Map map) { + this.nested = nested; + this.map = map; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NestedParameterizedCollectionPOJO)) { + return false; + } + NestedParameterizedCollectionPOJO that = (NestedParameterizedCollectionPOJO) o; + return Objects.equals(nested, that.nested) && Objects.equals(map, that.map); + } + + @Override + public int hashCode() { + return Objects.hash(nested, map); + } + } + + @DefaultSchema(JavaFieldSchema.class) + public static class NestedParameterizedPOJO { + public T nested; + + public NestedParameterizedPOJO(T nested) { + this.nested = nested; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NestedParameterizedPOJO)) { + return false; + } + NestedParameterizedPOJO that = (NestedParameterizedPOJO) o; + return Objects.equals(nested, that.nested); + } + + @Override + public int hashCode() { + return Objects.hash(nested); + } + } /** A POJO containing a nested class. * */ @DefaultSchema(JavaFieldSchema.class) public static class NestedPOJO { @@ -887,7 +1006,7 @@ public boolean equals(@Nullable Object o) { if (this == o) { return true; } - if (!(o instanceof PojoWithNestedArray)) { + if (!(o instanceof PojoWithIterable)) { return false; } PojoWithIterable that = (PojoWithIterable) o; diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java index 0a82663c1771..1a530a3f6ca5 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java @@ -78,8 +78,8 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc // Generate a method call to create and invoke the SpecificRecord's constructor. . MethodCall construct = MethodCall.construct(baseConstructor); - for (int i = 0; i < baseConstructor.getParameterTypes().length; ++i) { - Class baseType = baseConstructor.getParameterTypes()[i]; + for (int i = 0; i < baseConstructor.getGenericParameterTypes().length; ++i) { + Type baseType = baseConstructor.getGenericParameterTypes()[i]; construct = construct.with(readAndConvertParameter(baseType, i), baseType); } @@ -110,7 +110,7 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc } private static StackManipulation readAndConvertParameter( - Class constructorParameterType, int index) { + Type constructorParameterType, int index) { TypeConversionsFactory typeConversionsFactory = new AvroUtils.AvroTypeConversionFactory(); // The types in the AVRO-generated constructor might be the types returned by Beam's Row class, diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 1b1c45969307..1324d254e44e 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -814,6 +814,9 @@ public List get(TypeDescriptor typeDescriptor) { @Override public List get(TypeDescriptor typeDescriptor, Schema schema) { + Map boundTypes = + ReflectUtils.getAllBoundTypes(typeDescriptor); + Map mapping = getMapping(schema); List methods = ReflectUtils.getMethods(typeDescriptor.getRawType()); List types = Lists.newArrayList(); @@ -821,7 +824,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i); + FieldValueTypeInformation.forGetter(method, i, boundTypes); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -865,13 +868,16 @@ private Map getMapping(Schema schema) { private static final class AvroPojoFieldValueTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { + Map boundTypes = + ReflectUtils.getAllBoundTypes(typeDescriptor); List classFields = ReflectUtils.getFields(typeDescriptor.getRawType()); Map types = Maps.newHashMap(); for (int i = 0; i < classFields.size(); ++i) { java.lang.reflect.Field f = classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { - FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); + FieldValueTypeInformation typeInformation = + FieldValueTypeInformation.forField(f, i, boundTypes); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index d159e9de44a8..fcfc40403b43 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -39,6 +39,7 @@ import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -1045,7 +1046,8 @@ FieldValueSetter getProtoFieldValueSetter( } else { Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( - FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), + FieldValueTypeInformation.forSetter( + method, protoSetterPrefix(field.getType()), Collections.emptyMap()), new ProtoTypeConversionsFactory()); } } diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index faf3ad407af5..4b8d51abdea6 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -23,6 +23,7 @@ import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; import java.lang.reflect.Method; +import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.beam.sdk.extensions.protobuf.ProtoByteBuddyUtils.ProtoTypeConversionsFactory; @@ -72,7 +73,8 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) + .withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. types.add( @@ -82,7 +84,9 @@ public List get(TypeDescriptor typeDescriptor, Sch } else { // This is a simple field. Add the getter. Method method = getProtoGetter(methods, field.getName(), field.getType()); - types.add(FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + types.add( + FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) + .withName(field.getName())); } } return types; diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java index d5f1745a9a2c..64f600903d87 100644 --- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java +++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java @@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -389,7 +390,8 @@ private Schema generateSchemaDirectly( fieldName, StaticSchemaInference.fieldFromType( TypeDescriptor.of(field.getClass()), - JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE)); + JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE, + Collections.emptyMap())); } counter++; diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 5f4e195f227f..73b3709da832 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -242,10 +242,11 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return FieldValueTypeInformation.forSetter(factoryMethods.get(0), ""); + return FieldValueTypeInformation.forSetter(factoryMethods.get(0), "", Collections.emptyMap()); } else { try { - return FieldValueTypeInformation.forField(type.getDeclaredField(fieldName), 0); + return FieldValueTypeInformation.forField( + type.getDeclaredField(fieldName), 0, Collections.emptyMap()); } catch (NoSuchFieldException e) { throw new IllegalArgumentException(e); }