Skip to content

Commit

Permalink
Adding support for Beam Schema Rows with BQ DIRECT_READ (#22926)
Browse files Browse the repository at this point in the history
* Adding support for Beam Schema Rows with BQ DIRECT_READ

* Fixing for trimmed-out fields

* refactor

* Fix NPE in FakeJobService

* Addressing comments
  • Loading branch information
pabloem authored Sep 2, 2022
1 parent a045577 commit 948af30
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1068,16 +1068,20 @@ public PCollection<T> expand(PBegin input) {
checkArgument(getParseFn() != null, "A parseFn is required");

// if both toRowFn and fromRowFn values are set, enable Beam schema support
boolean beamSchemaEnabled = false;
Pipeline p = input.getPipeline();
final BigQuerySourceDef sourceDef = createSourceDef();

Schema beamSchema = null;
if (getTypeDescriptor() != null && getToBeamRowFn() != null && getFromBeamRowFn() != null) {
beamSchemaEnabled = true;
BigQueryOptions bqOptions = p.getOptions().as(BigQueryOptions.class);
beamSchema = sourceDef.getBeamSchema(bqOptions);
beamSchema = getFinalSchema(beamSchema, getSelectedFields());
}

Pipeline p = input.getPipeline();
final Coder<T> coder = inferCoder(p.getCoderRegistry());

if (getMethod() == TypedRead.Method.DIRECT_READ) {
return expandForDirectRead(input, coder);
return expandForDirectRead(input, coder, beamSchema);
}

checkArgument(
Expand All @@ -1090,7 +1094,6 @@ public PCollection<T> expand(PBegin input) {
"Invalid BigQueryIO.Read: Specifies row restriction, "
+ "which only applies when using Method.DIRECT_READ");

final BigQuerySourceDef sourceDef = createSourceDef();
final PCollectionView<String> jobIdTokenView;
PCollection<String> jobIdTokenCollection;
PCollection<T> rows;
Expand Down Expand Up @@ -1221,33 +1224,60 @@ void cleanup(PassThroughThenCleanup.ContextContainer c) throws Exception {

rows = rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView));

if (beamSchemaEnabled) {
BigQueryOptions bqOptions = p.getOptions().as(BigQueryOptions.class);
Schema beamSchema = sourceDef.getBeamSchema(bqOptions);
SerializableFunction<T, Row> toBeamRow = getToBeamRowFn().apply(beamSchema);
SerializableFunction<Row, T> fromBeamRow = getFromBeamRowFn().apply(beamSchema);

rows.setSchema(beamSchema, getTypeDescriptor(), toBeamRow, fromBeamRow);
if (beamSchema != null) {
rows.setSchema(
beamSchema,
getTypeDescriptor(),
getToBeamRowFn().apply(beamSchema),
getFromBeamRowFn().apply(beamSchema));
}
return rows;
}

private PCollection<T> expandForDirectRead(PBegin input, Coder<T> outputCoder) {
private static Schema getFinalSchema(
Schema beamSchema, ValueProvider<List<String>> selectedFields) {
List<Schema.Field> flds =
beamSchema.getFields().stream()
.filter(
field -> {
if (selectedFields != null
&& selectedFields.isAccessible()
&& selectedFields.get() != null) {
return selectedFields.get().contains(field.getName());
} else {
return true;
}
})
.collect(Collectors.toList());
return Schema.builder().addFields(flds).build();
}

private PCollection<T> expandForDirectRead(
PBegin input, Coder<T> outputCoder, Schema beamSchema) {
ValueProvider<TableReference> tableProvider = getTableProvider();
Pipeline p = input.getPipeline();
if (tableProvider != null) {
// No job ID is required. Read directly from BigQuery storage.
return p.apply(
org.apache.beam.sdk.io.Read.from(
BigQueryStorageTableSource.create(
tableProvider,
getFormat(),
getSelectedFields(),
getRowRestriction(),
getParseFn(),
outputCoder,
getBigQueryServices(),
getProjectionPushdownApplied())));
PCollection<T> rows =
p.apply(
org.apache.beam.sdk.io.Read.from(
BigQueryStorageTableSource.create(
tableProvider,
getFormat(),
getSelectedFields(),
getRowRestriction(),
getParseFn(),
outputCoder,
getBigQueryServices(),
getProjectionPushdownApplied())));
if (beamSchema != null) {
rows.setSchema(
beamSchema,
getTypeDescriptor(),
getToBeamRowFn().apply(beamSchema),
getFromBeamRowFn().apply(beamSchema));
}
return rows;
}

checkArgument(
Expand Down Expand Up @@ -1437,6 +1467,13 @@ void cleanup(ContextContainer c) throws Exception {
}
};

if (beamSchema != null) {
rows.setSchema(
beamSchema,
getTypeDescriptor(),
getToBeamRowFn().apply(beamSchema),
getFromBeamRowFn().apply(beamSchema));
}
return rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ private static Object toBeamRowFieldValue(Field field, Object bqValue) {
return null;
} else {
throw new IllegalArgumentException(
"Received null value for non-nullable field " + field.getName());
"Received null value for non-nullable field \"" + field.getName() + "\"");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ public Job getJob(JobReference jobRef) {
"Job %s failed: %s", job.job.getConfiguration(), e.toString())));
List<ResourceId> sourceFiles =
filesForLoadJobs.get(jobRef.getProjectId(), jobRef.getJobId());
FileSystems.delete(sourceFiles);
if (sourceFiles != null) {
FileSystems.delete(sourceFiles);
}
}
return JSON_FACTORY.fromString(JSON_FACTORY.toString(job.job), Job.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.io.gcp.bigquery;

import static org.junit.Assert.assertEquals;

import com.google.cloud.bigquery.storage.v1.DataFormat;
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
Expand All @@ -32,6 +34,7 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.transforms.Convert;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestPipelineOptions;
Expand Down Expand Up @@ -123,6 +126,39 @@ public void testBigQueryStorageRead1GArrow() throws Exception {
runBigQueryIOStorageReadPipeline();
}

@Test
public void testBigQueryStorageReadWithAvro() throws Exception {
storageReadWithSchema(DataFormat.AVRO);
}

@Test
public void testBigQueryStorageReadWithArrow() throws Exception {
storageReadWithSchema(DataFormat.ARROW);
}

private void storageReadWithSchema(DataFormat format) {
setUpTestEnvironment("multi_field", format);

Schema multiFieldSchema =
Schema.builder()
.addNullableField("string_field", FieldType.STRING)
.addNullableField("int_field", FieldType.INT64)
.build();

Pipeline p = Pipeline.create(options);
PCollection<Row> tableContents =
p.apply(
"Read",
BigQueryIO.readTableRowsWithSchema()
.from(options.getInputTable())
.withMethod(Method.DIRECT_READ)
.withFormat(options.getDataFormat()))
.apply(Convert.toRows());
PAssert.thatSingleton(tableContents.apply(Count.globally())).isEqualTo(options.getNumRecords());
assertEquals(tableContents.getSchema(), multiFieldSchema);
p.run().waitUntilFinish();
}

/**
* Tests a pipeline where {@link
* org.apache.beam.runners.core.construction.graph.ProjectionPushdownOptimizer} may do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.transforms.Convert;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.PTransform;
Expand Down Expand Up @@ -1481,6 +1482,82 @@ public void testReadFromBigQueryIOWithTrimmedSchema() throws Exception {
p.run();
}

@Test
public void testReadFromBigQueryIOWithBeamSchema() throws Exception {
fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null);
TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table");
Table table = new Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA);
fakeDatasetService.createTable(table);

CreateReadSessionRequest expectedCreateReadSessionRequest =
CreateReadSessionRequest.newBuilder()
.setParent("projects/project-id")
.setReadSession(
ReadSession.newBuilder()
.setTable("projects/foo.com:project/datasets/dataset/tables/table")
.setReadOptions(
ReadSession.TableReadOptions.newBuilder().addSelectedFields("name"))
.setDataFormat(DataFormat.AVRO))
.setMaxStreamCount(10)
.build();

ReadSession readSession =
ReadSession.newBuilder()
.setName("readSessionName")
.setAvroSchema(AvroSchema.newBuilder().setSchema(TRIMMED_AVRO_SCHEMA_STRING))
.addStreams(ReadStream.newBuilder().setName("streamName"))
.setDataFormat(DataFormat.AVRO)
.build();

ReadRowsRequest expectedReadRowsRequest =
ReadRowsRequest.newBuilder().setReadStream("streamName").build();

List<GenericRecord> records =
Lists.newArrayList(
createRecord("A", TRIMMED_AVRO_SCHEMA),
createRecord("B", TRIMMED_AVRO_SCHEMA),
createRecord("C", TRIMMED_AVRO_SCHEMA),
createRecord("D", TRIMMED_AVRO_SCHEMA));

List<ReadRowsResponse> readRowsResponses =
Lists.newArrayList(
createResponse(TRIMMED_AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.50),
createResponse(TRIMMED_AVRO_SCHEMA, records.subList(2, 4), 0.5, 0.75));

StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable());
when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest))
.thenReturn(readSession);
when(fakeStorageClient.readRows(expectedReadRowsRequest, ""))
.thenReturn(new FakeBigQueryServerStream<>(readRowsResponses));

PCollection<Row> output =
p.apply(
BigQueryIO.readTableRowsWithSchema()
.from("foo.com:project:dataset.table")
.withMethod(Method.DIRECT_READ)
.withSelectedFields(Lists.newArrayList("name"))
.withFormat(DataFormat.AVRO)
.withTestServices(
new FakeBigQueryServices()
.withDatasetService(fakeDatasetService)
.withStorageClient(fakeStorageClient)))
.apply(Convert.toRows());

org.apache.beam.sdk.schemas.Schema beamSchema =
org.apache.beam.sdk.schemas.Schema.of(
org.apache.beam.sdk.schemas.Schema.Field.of(
"name", org.apache.beam.sdk.schemas.Schema.FieldType.STRING));
PAssert.that(output)
.containsInAnyOrder(
ImmutableList.of(
Row.withSchema(beamSchema).addValue("A").build(),
Row.withSchema(beamSchema).addValue("B").build(),
Row.withSchema(beamSchema).addValue("C").build(),
Row.withSchema(beamSchema).addValue("D").build()));

p.run();
}

@Test
public void testReadFromBigQueryIOArrow() throws Exception {
fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null);
Expand Down

0 comments on commit 948af30

Please sign in to comment.