Skip to content

Commit

Permalink
Support Map in BQ for StorageWrites API for Beam Rows (#32512)
Browse files Browse the repository at this point in the history
  • Loading branch information
prodriguezdefino authored Oct 22, 2024
1 parent ad6f4ee commit 4f4853e
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand Down Expand Up @@ -221,11 +220,18 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) {
case ITERABLE:
@Nullable FieldType elementType = field.getType().getCollectionElementType();
if (elementType == null) {
throw new RuntimeException("Unexpected null element type!");
throw new RuntimeException("Unexpected null element type on " + field.getName());
}
TypeName containedTypeName =
Preconditions.checkNotNull(
elementType.getTypeName(),
"Null type name found in contained type at " + field.getName());
Preconditions.checkState(
!Preconditions.checkNotNull(elementType.getTypeName()).isCollectionType(),
"Nested arrays not supported by BigQuery.");
!(containedTypeName.isCollectionType() || containedTypeName.isMapType()),
"Nested container types are not supported by BigQuery. Field "
+ field.getName()
+ " contains a type "
+ containedTypeName.name());
TableFieldSchema elementFieldSchema =
fieldDescriptorFromBeamField(Field.of(field.getName(), elementType));
builder = builder.setType(elementFieldSchema.getType());
Expand All @@ -244,7 +250,24 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) {
builder = builder.setType(type);
break;
case MAP:
throw new RuntimeException("Map types not supported by BigQuery.");
@Nullable FieldType keyType = field.getType().getMapKeyType();
@Nullable FieldType valueType = field.getType().getMapValueType();
if (keyType == null) {
throw new RuntimeException(
"Unexpected null element type for the map's key on " + field.getName());
}
if (valueType == null) {
throw new RuntimeException(
"Unexpected null element type for the map's value on " + field.getName());
}

builder =
builder
.setType(TableFieldSchema.Type.STRUCT)
.addFields(fieldDescriptorFromBeamField(Field.of("key", keyType)))
.addFields(fieldDescriptorFromBeamField(Field.of("value", valueType)))
.setMode(TableFieldSchema.Mode.REPEATED);
break;
default:
@Nullable
TableFieldSchema.Type primitiveType = PRIMITIVE_TYPES.get(field.getType().getTypeName());
Expand Down Expand Up @@ -289,25 +312,34 @@ private static Object toProtoValue(
case ROW:
return messageFromBeamRow(fieldDescriptor.getMessageType(), (Row) value, null, -1);
case ARRAY:
List<Object> list = (List<Object>) value;
@Nullable FieldType arrayElementType = beamFieldType.getCollectionElementType();
if (arrayElementType == null) {
throw new RuntimeException("Unexpected null element type!");
}
return list.stream()
.map(v -> toProtoValue(fieldDescriptor, arrayElementType, v))
.collect(Collectors.toList());
case ITERABLE:
Iterable<Object> iterable = (Iterable<Object>) value;
@Nullable FieldType iterableElementType = beamFieldType.getCollectionElementType();
if (iterableElementType == null) {
throw new RuntimeException("Unexpected null element type!");
throw new RuntimeException("Unexpected null element type: " + fieldDescriptor.getName());
}

return StreamSupport.stream(iterable.spliterator(), false)
.map(v -> toProtoValue(fieldDescriptor, iterableElementType, v))
.collect(Collectors.toList());
case MAP:
throw new RuntimeException("Map types not supported by BigQuery.");
Map<Object, Object> map = (Map<Object, Object>) value;
@Nullable FieldType keyType = beamFieldType.getMapKeyType();
@Nullable FieldType valueType = beamFieldType.getMapValueType();
if (keyType == null) {
throw new RuntimeException("Unexpected null for key type: " + fieldDescriptor.getName());
}
if (valueType == null) {
throw new RuntimeException(
"Unexpected null for value type: " + fieldDescriptor.getName());
}

return map.entrySet().stream()
.map(
(Map.Entry<Object, Object> entry) ->
mapEntryToProtoValue(
fieldDescriptor.getMessageType(), keyType, valueType, entry))
.collect(Collectors.toList());
default:
return scalarToProtoValue(beamFieldType, value);
}
Expand Down Expand Up @@ -337,6 +369,28 @@ static Object scalarToProtoValue(FieldType beamFieldType, Object value) {
}
}

static Object mapEntryToProtoValue(
Descriptor descriptor,
FieldType keyFieldType,
FieldType valueFieldType,
Map.Entry<Object, Object> entryValue) {
DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor);
FieldDescriptor keyFieldDescriptor =
Preconditions.checkNotNull(descriptor.findFieldByName("key"));
@Nullable Object key = toProtoValue(keyFieldDescriptor, keyFieldType, entryValue.getKey());
if (key != null) {
builder.setField(keyFieldDescriptor, key);
}
FieldDescriptor valueFieldDescriptor =
Preconditions.checkNotNull(descriptor.findFieldByName("value"));
@Nullable
Object value = toProtoValue(valueFieldDescriptor, valueFieldType, entryValue.getValue());
if (value != null) {
builder.setField(valueFieldDescriptor, value);
}
return builder.build();
}

static ByteString serializeBigDecimalToNumeric(BigDecimal o) {
return serializeBigDecimal(o, NUMERIC_SCALE, MAX_NUMERIC_VALUE, MIN_NUMERIC_VALUE, "Numeric");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import com.google.protobuf.ByteString;
import com.google.protobuf.DescriptorProtos.DescriptorProto;
Expand All @@ -36,8 +37,11 @@
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
Expand Down Expand Up @@ -284,12 +288,14 @@ public class BeamRowToStorageApiProtoTest {
.addField("nested", FieldType.row(BASE_SCHEMA).withNullable(true))
.addField("nestedArray", FieldType.array(FieldType.row(BASE_SCHEMA)))
.addField("nestedIterable", FieldType.iterable(FieldType.row(BASE_SCHEMA)))
.addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.row(BASE_SCHEMA)))
.build();
private static final Row NESTED_ROW =
Row.withSchema(NESTED_SCHEMA)
.withFieldValue("nested", BASE_ROW)
.withFieldValue("nestedArray", ImmutableList.of(BASE_ROW, BASE_ROW))
.withFieldValue("nestedIterable", ImmutableList.of(BASE_ROW, BASE_ROW))
.withFieldValue("nestedMap", ImmutableMap.of("key1", BASE_ROW, "key2", BASE_ROW))
.build();

@Test
Expand Down Expand Up @@ -347,12 +353,12 @@ public void testNestedFromSchema() {
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel));

assertEquals(3, types.size());
assertEquals(4, types.size());

Map<String, DescriptorProto> nestedTypes =
descriptor.getNestedTypeList().stream()
.collect(Collectors.toMap(DescriptorProto::getName, Functions.identity()));
assertEquals(3, nestedTypes.size());
assertEquals(4, nestedTypes.size());
assertEquals(Type.TYPE_MESSAGE, types.get("nested"));
assertEquals(Label.LABEL_OPTIONAL, typeLabels.get("nested"));
String nestedTypeName1 = typeNames.get("nested");
Expand All @@ -379,6 +385,87 @@ public void testNestedFromSchema() {
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType));
assertEquals(expectedBaseTypes, nestedTypes3);

assertEquals(Type.TYPE_MESSAGE, types.get("nestedmap"));
assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmap"));
String nestedTypeName4 = typeNames.get("nestedmap");
// expects 2 fields in the nested map, key and value
assertEquals(2, nestedTypes.get(nestedTypeName4).getFieldList().size());
Supplier<Stream<FieldDescriptorProto>> stream =
() -> nestedTypes.get(nestedTypeName4).getFieldList().stream();
assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("key")));
assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("value")));

Map<String, Type> nestedTypes4 =
nestedTypes.get(nestedTypeName4).getNestedTypeList().stream()
.flatMap(vdesc -> vdesc.getFieldList().stream())
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType));
assertEquals(expectedBaseTypes, nestedTypes4);
}

@Test
public void testParticularMapsFromSchemas() {
Schema nestedMapSchemaVariations =
Schema.builder()
.addField(
"nestedMultiMap",
FieldType.map(FieldType.STRING, FieldType.array(FieldType.STRING)))
.addField(
"nestedMapNullable",
FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true))
.build();

DescriptorProto descriptor =
TableRowToStorageApiProto.descriptorSchemaFromTableSchema(
BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema((nestedMapSchemaVariations)),
true,
false);

Map<String, Type> types =
descriptor.getFieldList().stream()
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType));
Map<String, String> typeNames =
descriptor.getFieldList().stream()
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName));
Map<String, Label> typeLabels =
descriptor.getFieldList().stream()
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel));

Map<String, DescriptorProto> nestedTypes =
descriptor.getNestedTypeList().stream()
.collect(Collectors.toMap(DescriptorProto::getName, Functions.identity()));
assertEquals(2, nestedTypes.size());

assertEquals(Type.TYPE_MESSAGE, types.get("nestedmultimap"));
assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmultimap"));
String nestedMultiMapName = typeNames.get("nestedmultimap");
// expects 2 fields for the nested array of maps, key and value
assertEquals(2, nestedTypes.get(nestedMultiMapName).getFieldList().size());
Supplier<Stream<FieldDescriptorProto>> stream =
() -> nestedTypes.get(nestedMultiMapName).getFieldList().stream();
assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1);
assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1);
assertTrue(
stream
.get()
.filter(fdp -> fdp.getName().equals("value"))
.filter(fdp -> fdp.getLabel().equals(Label.LABEL_REPEATED))
.count()
== 1);

assertEquals(Type.TYPE_MESSAGE, types.get("nestedmapnullable"));
// even though the field is marked as optional in the row we will should see repeated in proto
assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmapnullable"));
String nestedMapNullableName = typeNames.get("nestedmapnullable");
// expects 2 fields in the nullable maps, key and value
assertEquals(2, nestedTypes.get(nestedMapNullableName).getFieldList().size());
stream = () -> nestedTypes.get(nestedMapNullableName).getFieldList().stream();
assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1);
assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1);
}

private void assertBaseRecord(DynamicMessage msg) {
Expand All @@ -395,7 +482,7 @@ public void testMessageFromTableRow() throws Exception {
BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(NESTED_SCHEMA), true, false);
DynamicMessage msg =
BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, null, -1);
assertEquals(3, msg.getAllFields().size());
assertEquals(4, msg.getAllFields().size());

Map<String, FieldDescriptor> fieldDescriptors =
descriptor.getFields().stream()
Expand All @@ -404,6 +491,63 @@ public void testMessageFromTableRow() throws Exception {
assertBaseRecord(nestedMsg);
}

@Test
public void testMessageFromTableRowForArraysAndMaps() throws Exception {
Schema nestedMapSchemaVariations =
Schema.builder()
.addField("nestedArrayNullable", FieldType.array(FieldType.STRING).withNullable(true))
.addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.STRING))
.addField(
"nestedMultiMap",
FieldType.map(FieldType.STRING, FieldType.iterable(FieldType.STRING)))
.addField(
"nestedMapNullable",
FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true))
.build();

Row nestedRow =
Row.withSchema(nestedMapSchemaVariations)
.withFieldValue("nestedArrayNullable", null)
.withFieldValue("nestedMap", ImmutableMap.of("key1", "value1"))
.withFieldValue(
"nestedMultiMap",
ImmutableMap.of("multikey1", ImmutableList.of("multivalue1", "multivalue2")))
.withFieldValue("nestedMapNullable", null)
.build();

Descriptor descriptor =
TableRowToStorageApiProto.getDescriptorFromTableSchema(
BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(nestedMapSchemaVariations),
true,
false);
DynamicMessage msg =
BeamRowToStorageApiProto.messageFromBeamRow(descriptor, nestedRow, null, -1);

Map<String, FieldDescriptor> fieldDescriptors =
descriptor.getFields().stream()
.collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity()));

DynamicMessage nestedMapEntryMsg =
(DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmap"), 0);
String value =
(String)
nestedMapEntryMsg.getField(
fieldDescriptors.get("nestedmap").getMessageType().findFieldByName("value"));
assertEquals("value1", value);

DynamicMessage nestedMultiMapEntryMsg =
(DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmultimap"), 0);
List<String> values =
(List<String>)
nestedMultiMapEntryMsg.getField(
fieldDescriptors.get("nestedmultimap").getMessageType().findFieldByName("value"));
assertTrue(values.size() == 2);
assertEquals("multivalue1", values.get(0));

assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedarraynullable")) == 0);
assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedmapnullable")) == 0);
}

@Test
public void testCdcFields() throws Exception {
Descriptor descriptor =
Expand All @@ -413,7 +557,7 @@ public void testCdcFields() throws Exception {
assertNotNull(descriptor.findFieldByName(StorageApiCDC.CHANGE_SQN_COLUMN));
DynamicMessage msg =
BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, "UPDATE", 42);
assertEquals(5, msg.getAllFields().size());
assertEquals(6, msg.getAllFields().size());

Map<String, FieldDescriptor> fieldDescriptors =
descriptor.getFields().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,18 @@ public void testToTableSchema_map() {
assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE));
}

@Test
public void testToTableSchema_map_array() {
TableSchema schema = toTableSchema(MAP_ARRAY_TYPE);

assertThat(schema.getFields().size(), equalTo(1));
TableFieldSchema field = schema.getFields().get(0);
assertThat(field.getName(), equalTo("map"));
assertThat(field.getType(), equalTo(StandardSQLTypeName.STRUCT.toString()));
assertThat(field.getMode(), equalTo(Mode.REPEATED.toString()));
assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE));
}

@Test
public void testToTableRow_flat() {
TableRow row = toTableRow().apply(FLAT_ROW);
Expand Down

0 comments on commit 4f4853e

Please sign in to comment.