Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schema inference parameterized types #32757

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.annotations.SchemaIgnore;
import org.apache.beam.sdk.schemas.utils.AutoValueUtils;
Expand Down Expand Up @@ -61,8 +63,9 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());
List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(methods.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < methods.size(); ++i) {
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i));
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -143,7 +146,8 @@ public SchemaUserTypeCreator schemaTypeCreator(

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return JavaBeanUtils.schemaFromJavaBeanClass(
typeDescriptor, AbstractGetterTypeSupplier.INSTANCE);
typeDescriptor, AbstractGetterTypeSupplier.INSTANCE, boundTypes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldName;
Expand All @@ -44,6 +46,7 @@
"nullness", // TODO(https://github.com/apache/beam/issues/20497)
"rawtypes"
})
@Internal
public abstract class FieldValueTypeInformation implements Serializable {
/** Optionally returns the field index. */
public abstract @Nullable Integer getNumber();
Expand Down Expand Up @@ -125,18 +128,20 @@ public static FieldValueTypeInformation forOneOf(
.build();
}

public static FieldValueTypeInformation forField(Field field, int index) {
TypeDescriptor<?> type = TypeDescriptor.of(field.getGenericType());
public static FieldValueTypeInformation forField(
Field field, int index, Map<Type, Type> boundTypes) {
ahmedabu98 marked this conversation as resolved.
Show resolved Hide resolved
TypeDescriptor<?> type =
TypeDescriptor.of(ReflectUtils.resolveType(field.getGenericType(), boundTypes));
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(getNameOverride(field.getName(), field))
.setNumber(getNumberOverride(index, field))
.setNullable(hasNullableAnnotation(field))
.setType(type)
.setRawType(type.getRawType())
.setField(field)
.setElementType(getIterableComponentType(field))
.setMapKeyType(getMapKeyType(field))
.setMapValueType(getMapValueType(field))
.setElementType(getIterableComponentType(field, boundTypes))
.setMapKeyType(getMapKeyType(field, boundTypes))
.setMapValueType(getMapValueType(field, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.setDescription(getFieldDescription(field))
.build();
Expand Down Expand Up @@ -184,7 +189,8 @@ public static <T extends AnnotatedElement & Member> String getNameOverride(
return fieldDescription.value();
}

public static FieldValueTypeInformation forGetter(Method method, int index) {
public static FieldValueTypeInformation forGetter(
Method method, int index, Map<Type, Type> boundTypes) {
String name;
if (method.getName().startsWith("get")) {
name = ReflectUtils.stripPrefix(method.getName(), "get");
Expand All @@ -194,7 +200,8 @@ public static FieldValueTypeInformation forGetter(Method method, int index) {
throw new RuntimeException("Getter has wrong prefix " + method.getName());
}

TypeDescriptor<?> type = TypeDescriptor.of(method.getGenericReturnType());
TypeDescriptor<?> type =
TypeDescriptor.of(ReflectUtils.resolveType(method.getGenericReturnType(), boundTypes));
boolean nullable = hasNullableReturnType(method);
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(getNameOverride(name, method))
Expand All @@ -203,9 +210,9 @@ public static FieldValueTypeInformation forGetter(Method method, int index) {
.setType(type)
.setRawType(type.getRawType())
.setMethod(method)
.setElementType(getIterableComponentType(type))
.setMapKeyType(getMapKeyType(type))
.setMapValueType(getMapValueType(type))
.setElementType(getIterableComponentType(type, boundTypes))
.setMapKeyType(getMapKeyType(type, boundTypes))
.setMapValueType(getMapValueType(type, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.setDescription(getFieldDescription(method))
.build();
Expand Down Expand Up @@ -252,29 +259,33 @@ private static boolean isNullableAnnotation(Annotation annotation) {
return annotation.annotationType().getSimpleName().equals("Nullable");
}

public static FieldValueTypeInformation forSetter(Method method) {
return forSetter(method, "set");
public static FieldValueTypeInformation forSetter(
Method method, Map<Type, Type> boundParameters) {
return forSetter(method, "set", boundParameters);
}

public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) {
public static FieldValueTypeInformation forSetter(
Method method, String setterPrefix, Map<Type, Type> boundTypes) {
String name;
if (method.getName().startsWith(setterPrefix)) {
name = ReflectUtils.stripPrefix(method.getName(), setterPrefix);
} else {
throw new RuntimeException("Setter has wrong prefix " + method.getName());
}

TypeDescriptor<?> type = TypeDescriptor.of(method.getGenericParameterTypes()[0]);
TypeDescriptor<?> type =
TypeDescriptor.of(
ReflectUtils.resolveType(method.getGenericParameterTypes()[0], boundTypes));
boolean nullable = hasSingleNullableParameter(method);
return new AutoValue_FieldValueTypeInformation.Builder()
.setName(name)
.setNullable(nullable)
.setType(type)
.setRawType(type.getRawType())
.setMethod(method)
.setElementType(getIterableComponentType(type))
.setMapKeyType(getMapKeyType(type))
.setMapValueType(getMapValueType(type))
.setElementType(getIterableComponentType(type, boundTypes))
.setMapKeyType(getMapKeyType(type, boundTypes))
.setMapValueType(getMapValueType(type, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}
Expand All @@ -283,13 +294,15 @@ public FieldValueTypeInformation withName(String name) {
return toBuilder().setName(name).build();
}

private static FieldValueTypeInformation getIterableComponentType(Field field) {
return getIterableComponentType(TypeDescriptor.of(field.getGenericType()));
private static FieldValueTypeInformation getIterableComponentType(
Field field, Map<Type, Type> boundTypes) {
return getIterableComponentType(TypeDescriptor.of(field.getGenericType()), boundTypes);
}

static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor<?> valueType) {
static @Nullable FieldValueTypeInformation getIterableComponentType(
TypeDescriptor<?> valueType, Map<Type, Type> boundTypes) {
// TODO: Figure out nullable elements.
TypeDescriptor<?> componentType = ReflectUtils.getIterableComponentType(valueType);
TypeDescriptor<?> componentType = ReflectUtils.getIterableComponentType(valueType, boundTypes);
if (componentType == null) {
return null;
}
Expand All @@ -299,41 +312,43 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) {
.setNullable(false)
.setType(componentType)
.setRawType(componentType.getRawType())
.setElementType(getIterableComponentType(componentType))
.setMapKeyType(getMapKeyType(componentType))
.setMapValueType(getMapValueType(componentType))
.setElementType(getIterableComponentType(componentType, boundTypes))
.setMapKeyType(getMapKeyType(componentType, boundTypes))
.setMapValueType(getMapValueType(componentType, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}

// If the Field is a map type, returns the key type, otherwise returns a null reference.

private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) {
return getMapKeyType(TypeDescriptor.of(field.getGenericType()));
private static @Nullable FieldValueTypeInformation getMapKeyType(
Field field, Map<Type, Type> boundTypes) {
return getMapKeyType(TypeDescriptor.of(field.getGenericType()), boundTypes);
}

private static @Nullable FieldValueTypeInformation getMapKeyType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 0);
TypeDescriptor<?> typeDescriptor, Map<Type, Type> boundTypes) {
return getMapType(typeDescriptor, 0, boundTypes);
}

// If the Field is a map type, returns the value type, otherwise returns a null reference.

private static @Nullable FieldValueTypeInformation getMapValueType(Field field) {
return getMapType(TypeDescriptor.of(field.getGenericType()), 1);
private static @Nullable FieldValueTypeInformation getMapValueType(
Field field, Map<Type, Type> boundTypes) {
return getMapType(TypeDescriptor.of(field.getGenericType()), 1, boundTypes);
}

private static @Nullable FieldValueTypeInformation getMapValueType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 1);
TypeDescriptor<?> typeDescriptor, Map<Type, Type> boundTypes) {
return getMapType(typeDescriptor, 1, boundTypes);
}

// If the Field is a map type, returns the key or value type (0 is key type, 1 is value).
// Otherwise returns a null reference.
@SuppressWarnings("unchecked")
private static @Nullable FieldValueTypeInformation getMapType(
TypeDescriptor<?> valueType, int index) {
TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index);
TypeDescriptor<?> valueType, int index, Map<Type, Type> boundTypes) {
TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index, boundTypes);
if (mapType == null) {
return null;
}
Expand All @@ -342,9 +357,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) {
.setNullable(false)
.setType(mapType)
.setRawType(mapType.getRawType())
.setElementType(getIterableComponentType(mapType))
.setMapKeyType(getMapKeyType(mapType))
.setMapValueType(getMapValueType(mapType))
.setElementType(getIterableComponentType(mapType, boundTypes))
.setMapKeyType(getMapKeyType(mapType, boundTypes))
.setMapValueType(getMapValueType(mapType, boundTypes))
.setOneOfTypes(Collections.emptyMap())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldName;
Expand Down Expand Up @@ -67,8 +69,9 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());
List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(methods.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < methods.size(); ++i) {
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i));
types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -111,10 +114,11 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier {

@Override
public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream()
.filter(ReflectUtils::isSetter)
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.map(FieldValueTypeInformation::forSetter)
.map(m -> FieldValueTypeInformation.forSetter(m, boundTypes))
.map(
t -> {
if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) {
Expand Down Expand Up @@ -156,8 +160,10 @@ public boolean equals(@Nullable Object obj) {

@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
Schema schema =
JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE);
JavaBeanUtils.schemaFromJavaBeanClass(
typeDescriptor, GetterTypeSupplier.INSTANCE, boundTypes);

// If there are no creator methods, then validate that we have setters for every field.
// Otherwise, we will have no way of creating instances of the class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -62,9 +64,11 @@ public List<FieldValueTypeInformation> get(TypeDescriptor<?> typeDescriptor) {
ReflectUtils.getFields(typeDescriptor.getRawType()).stream()
.filter(m -> !m.isAnnotationPresent(SchemaIgnore.class))
.collect(Collectors.toList());

List<FieldValueTypeInformation> types = Lists.newArrayListWithCapacity(fields.size());
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
for (int i = 0; i < fields.size(); ++i) {
types.add(FieldValueTypeInformation.forField(fields.get(i), i));
types.add(FieldValueTypeInformation.forField(fields.get(i), i, boundTypes));
}
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber));
validateFieldNumbers(types);
Expand Down Expand Up @@ -111,7 +115,9 @@ private static void validateFieldNumbers(List<FieldValueTypeInformation> types)

@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE);
Map<Type, Type> boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor);
return POJOUtils.schemaFromPojoClass(
typeDescriptor, JavaFieldTypeSupplier.INSTANCE, boundTypes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ public interface SchemaProvider extends Serializable {
* Given a type, returns a function that converts that type to a {@link Row} object. If no schema
* exists, returns null.
*/
@Nullable
<T> SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T> typeDescriptor);
<T> @Nullable SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T> typeDescriptor);

/**
* Given a type, returns a function that converts from a {@link Row} object to that type. If no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid
providers.put(typeDescriptor, schemaProvider);
}

@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
private <T> @Nullable SchemaProvider schemaProviderFor(TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return schemaProvider.schemaFor(type);
return schemaProvider;
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
Expand All @@ -92,38 +91,24 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid
} while (true);
}

/**
 * Looks up the provider via {@code schemaProviderFor} and asks it for the schema of the requested
 * type. Returns null when no provider is found.
 */
@Override
public <T> @Nullable Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
  @Nullable SchemaProvider provider = schemaProviderFor(typeDescriptor);
  if (provider == null) {
    // No provider was found for this type, so there is no schema to report.
    return null;
  }
  // The schema is requested for the caller's original descriptor.
  return provider.schemaFor(typeDescriptor);
}

@Override
public <T> @Nullable SerializableFunction<T, Row> toRowFunction(
TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return (SerializableFunction<T, Row>) schemaProvider.toRowFunction(type);
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
return null;
}
type = TypeDescriptor.of(superClass);
} while (true);
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.toRowFunction(typeDescriptor) : null;
}

@Override
public <T> @Nullable SerializableFunction<Row, T> fromRowFunction(
TypeDescriptor<T> typeDescriptor) {
TypeDescriptor<?> type = typeDescriptor;
do {
SchemaProvider schemaProvider = providers.get(type);
if (schemaProvider != null) {
return (SerializableFunction<Row, T>) schemaProvider.fromRowFunction(type);
}
Class<?> superClass = type.getRawType().getSuperclass();
if (superClass == null || superClass.equals(Object.class)) {
return null;
}
type = TypeDescriptor.of(superClass);
} while (true);
@Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor);
return schemaProvider != null ? schemaProvider.fromRowFunction(typeDescriptor) : null;
}
}

Expand Down
Loading
Loading