From 4f4853e2a43354299e8a0c5822f00b3888cb485d Mon Sep 17 00:00:00 2001 From: pablo rodriguez defino Date: Tue, 22 Oct 2024 14:48:20 -0700 Subject: [PATCH] Support Map in BQ for StorageWrites API for Beam Rows (#32512) --- .../bigquery/BeamRowToStorageApiProto.java | 84 ++++++++-- .../BeamRowToStorageApiProtoTest.java | 152 +++++++++++++++++- .../io/gcp/bigquery/BigQueryUtilsTest.java | 12 ++ 3 files changed, 229 insertions(+), 19 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java index 7a5aa2408d2e..d7ca787feea3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java @@ -31,7 +31,6 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -221,11 +220,18 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { case ITERABLE: @Nullable FieldType elementType = field.getType().getCollectionElementType(); if (elementType == null) { - throw new RuntimeException("Unexpected null element type!"); + throw new RuntimeException("Unexpected null element type on " + field.getName()); } + TypeName containedTypeName = + Preconditions.checkNotNull( + elementType.getTypeName(), + "Null type name found in contained type at " + field.getName()); Preconditions.checkState( - !Preconditions.checkNotNull(elementType.getTypeName()).isCollectionType(), - "Nested arrays not supported by BigQuery."); + !(containedTypeName.isCollectionType() || containedTypeName.isMapType()), + "Nested container types are not 
supported by BigQuery. Field " + + field.getName() + + " contains a type " + + containedTypeName.name()); TableFieldSchema elementFieldSchema = fieldDescriptorFromBeamField(Field.of(field.getName(), elementType)); builder = builder.setType(elementFieldSchema.getType()); @@ -244,7 +250,24 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { builder = builder.setType(type); break; case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + @Nullable FieldType keyType = field.getType().getMapKeyType(); + @Nullable FieldType valueType = field.getType().getMapValueType(); + if (keyType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's key on " + field.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's value on " + field.getName()); + } + + builder = + builder + .setType(TableFieldSchema.Type.STRUCT) + .addFields(fieldDescriptorFromBeamField(Field.of("key", keyType))) + .addFields(fieldDescriptorFromBeamField(Field.of("value", valueType))) + .setMode(TableFieldSchema.Mode.REPEATED); + break; default: @Nullable TableFieldSchema.Type primitiveType = PRIMITIVE_TYPES.get(field.getType().getTypeName()); @@ -289,25 +312,34 @@ private static Object toProtoValue( case ROW: return messageFromBeamRow(fieldDescriptor.getMessageType(), (Row) value, null, -1); case ARRAY: - List list = (List) value; - @Nullable FieldType arrayElementType = beamFieldType.getCollectionElementType(); - if (arrayElementType == null) { - throw new RuntimeException("Unexpected null element type!"); - } - return list.stream() - .map(v -> toProtoValue(fieldDescriptor, arrayElementType, v)) - .collect(Collectors.toList()); case ITERABLE: Iterable iterable = (Iterable) value; @Nullable FieldType iterableElementType = beamFieldType.getCollectionElementType(); if (iterableElementType == null) { - throw new RuntimeException("Unexpected null element type!"); + 
throw new RuntimeException("Unexpected null element type: " + fieldDescriptor.getName()); } + return StreamSupport.stream(iterable.spliterator(), false) .map(v -> toProtoValue(fieldDescriptor, iterableElementType, v)) .collect(Collectors.toList()); case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + Map map = (Map) value; + @Nullable FieldType keyType = beamFieldType.getMapKeyType(); + @Nullable FieldType valueType = beamFieldType.getMapValueType(); + if (keyType == null) { + throw new RuntimeException("Unexpected null for key type: " + fieldDescriptor.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null for value type: " + fieldDescriptor.getName()); + } + + return map.entrySet().stream() + .map( + (Map.Entry entry) -> + mapEntryToProtoValue( + fieldDescriptor.getMessageType(), keyType, valueType, entry)) + .collect(Collectors.toList()); default: return scalarToProtoValue(beamFieldType, value); } @@ -337,6 +369,28 @@ static Object scalarToProtoValue(FieldType beamFieldType, Object value) { } } + static Object mapEntryToProtoValue( + Descriptor descriptor, + FieldType keyFieldType, + FieldType valueFieldType, + Map.Entry entryValue) { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + FieldDescriptor keyFieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName("key")); + @Nullable Object key = toProtoValue(keyFieldDescriptor, keyFieldType, entryValue.getKey()); + if (key != null) { + builder.setField(keyFieldDescriptor, key); + } + FieldDescriptor valueFieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName("value")); + @Nullable + Object value = toProtoValue(valueFieldDescriptor, valueFieldType, entryValue.getValue()); + if (value != null) { + builder.setField(valueFieldDescriptor, value); + } + return builder.build(); + } + static ByteString serializeBigDecimalToNumeric(BigDecimal o) { return serializeBigDecimal(o, NUMERIC_SCALE, 
MAX_NUMERIC_VALUE, MIN_NUMERIC_VALUE, "Numeric"); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java index 4013f0018553..d8c580a0cd18 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import com.google.protobuf.ByteString; import com.google.protobuf.DescriptorProtos.DescriptorProto; @@ -36,8 +37,11 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -284,12 +288,14 @@ public class BeamRowToStorageApiProtoTest { .addField("nested", FieldType.row(BASE_SCHEMA).withNullable(true)) .addField("nestedArray", FieldType.array(FieldType.row(BASE_SCHEMA))) .addField("nestedIterable", FieldType.iterable(FieldType.row(BASE_SCHEMA))) + .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.row(BASE_SCHEMA))) .build(); private static final Row NESTED_ROW = Row.withSchema(NESTED_SCHEMA) .withFieldValue("nested", BASE_ROW) .withFieldValue("nestedArray", ImmutableList.of(BASE_ROW, BASE_ROW)) .withFieldValue("nestedIterable", ImmutableList.of(BASE_ROW, BASE_ROW)) + .withFieldValue("nestedMap", ImmutableMap.of("key1", BASE_ROW, "key2", BASE_ROW)) .build(); @Test @@ -347,12 +353,12 
@@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); - assertEquals(3, types.size()); + assertEquals(4, types.size()); Map nestedTypes = descriptor.getNestedTypeList().stream() .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); - assertEquals(3, nestedTypes.size()); + assertEquals(4, nestedTypes.size()); assertEquals(Type.TYPE_MESSAGE, types.get("nested")); assertEquals(Label.LABEL_OPTIONAL, typeLabels.get("nested")); String nestedTypeName1 = typeNames.get("nested"); @@ -379,6 +385,87 @@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); assertEquals(expectedBaseTypes, nestedTypes3); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmap")); + String nestedTypeName4 = typeNames.get("nestedmap"); + // expects 2 fields in the nested map, key and value + assertEquals(2, nestedTypes.get(nestedTypeName4).getFieldList().size()); + Supplier> stream = + () -> nestedTypes.get(nestedTypeName4).getFieldList().stream(); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("key"))); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("value"))); + + Map nestedTypes4 = + nestedTypes.get(nestedTypeName4).getNestedTypeList().stream() + .flatMap(vdesc -> vdesc.getFieldList().stream()) + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes4); + } + + @Test + public void testParticularMapsFromSchemas() { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.array(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + DescriptorProto descriptor = + 
TableRowToStorageApiProto.descriptorSchemaFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema((nestedMapSchemaVariations)), + true, + false); + + Map types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map typeNames = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName)); + Map typeLabels = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); + + Map nestedTypes = + descriptor.getNestedTypeList().stream() + .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); + assertEquals(2, nestedTypes.size()); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmultimap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmultimap")); + String nestedMultiMapName = typeNames.get("nestedmultimap"); + // expects 2 fields for the nested array of maps, key and value + assertEquals(2, nestedTypes.get(nestedMultiMapName).getFieldList().size()); + Supplier> stream = + () -> nestedTypes.get(nestedMultiMapName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); + assertTrue( + stream + .get() + .filter(fdp -> fdp.getName().equals("value")) + .filter(fdp -> fdp.getLabel().equals(Label.LABEL_REPEATED)) + .count() + == 1); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmapnullable")); + // even though the field is marked as optional in the row, we should see repeated in proto + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmapnullable")); + String nestedMapNullableName = typeNames.get("nestedmapnullable"); + // expects 2 fields in the nullable maps, key and value + assertEquals(2, 
nestedTypes.get(nestedMapNullableName).getFieldList().size()); + stream = () -> nestedTypes.get(nestedMapNullableName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); } private void assertBaseRecord(DynamicMessage msg) { @@ -395,7 +482,7 @@ public void testMessageFromTableRow() throws Exception { BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(NESTED_SCHEMA), true, false); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, null, -1); - assertEquals(3, msg.getAllFields().size()); + assertEquals(4, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() @@ -404,6 +491,63 @@ public void testMessageFromTableRow() throws Exception { assertBaseRecord(nestedMsg); } + @Test + public void testMessageFromTableRowForArraysAndMaps() throws Exception { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField("nestedArrayNullable", FieldType.array(FieldType.STRING).withNullable(true)) + .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.STRING)) + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.iterable(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + Row nestedRow = + Row.withSchema(nestedMapSchemaVariations) + .withFieldValue("nestedArrayNullable", null) + .withFieldValue("nestedMap", ImmutableMap.of("key1", "value1")) + .withFieldValue( + "nestedMultiMap", + ImmutableMap.of("multikey1", ImmutableList.of("multivalue1", "multivalue2"))) + .withFieldValue("nestedMapNullable", null) + .build(); + + Descriptor descriptor = + TableRowToStorageApiProto.getDescriptorFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(nestedMapSchemaVariations), + true, + false); + DynamicMessage msg = + 
BeamRowToStorageApiProto.messageFromBeamRow(descriptor, nestedRow, null, -1); + + Map fieldDescriptors = + descriptor.getFields().stream() + .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity())); + + DynamicMessage nestedMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmap"), 0); + String value = + (String) + nestedMapEntryMsg.getField( + fieldDescriptors.get("nestedmap").getMessageType().findFieldByName("value")); + assertEquals("value1", value); + + DynamicMessage nestedMultiMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmultimap"), 0); + List values = + (List) + nestedMultiMapEntryMsg.getField( + fieldDescriptors.get("nestedmultimap").getMessageType().findFieldByName("value")); + assertTrue(values.size() == 2); + assertEquals("multivalue1", values.get(0)); + + assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedarraynullable")) == 0); + assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedmapnullable")) == 0); + } + @Test public void testCdcFields() throws Exception { Descriptor descriptor = @@ -413,7 +557,7 @@ public void testCdcFields() throws Exception { assertNotNull(descriptor.findFieldByName(StorageApiCDC.CHANGE_SQN_COLUMN)); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, "UPDATE", 42); - assertEquals(5, msg.getAllFields().size()); + assertEquals(6, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java index e26348b7b478..8b65e58a4601 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java +++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java @@ -698,6 +698,18 @@ public void testToTableSchema_map() { assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); } + @Test + public void testToTableSchema_map_array() { + TableSchema schema = toTableSchema(MAP_ARRAY_TYPE); + + assertThat(schema.getFields().size(), equalTo(1)); + TableFieldSchema field = schema.getFields().get(0); + assertThat(field.getName(), equalTo("map")); + assertThat(field.getType(), equalTo(StandardSQLTypeName.STRUCT.toString())); + assertThat(field.getMode(), equalTo(Mode.REPEATED.toString())); + assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); + } + @Test public void testToTableRow_flat() { TableRow row = toTableRow().apply(FLAT_ROW);