add file loads translation and tests; add test checks that the correct transform is chosen
ahmedabu98 committed Nov 7, 2024
1 parent 01a01f7 commit 697c0b8
Showing 5 changed files with 163 additions and 25 deletions.
@@ -26,6 +26,8 @@
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils;
+ import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+ import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
@@ -55,7 +57,7 @@ public class BigQueryFileLoadsWriteSchemaTransformProvider

@Override
protected SchemaTransform from(BigQueryWriteConfiguration configuration) {
- return new BigQueryWriteSchemaTransform(configuration);
+ return new BigQueryFileLoadsSchemaTransform(configuration);
}

@Override
@@ -73,13 +75,13 @@ public List<String> outputCollectionNames() {
return Collections.emptyList();
}

- protected static class BigQueryWriteSchemaTransform extends SchemaTransform {
+ public static class BigQueryFileLoadsSchemaTransform extends SchemaTransform {
/** An instance of {@link BigQueryServices} used for testing. */
private BigQueryServices testBigQueryServices = null;

private final BigQueryWriteConfiguration configuration;

- BigQueryWriteSchemaTransform(BigQueryWriteConfiguration configuration) {
+ BigQueryFileLoadsSchemaTransform(BigQueryWriteConfiguration configuration) {
configuration.validate();
this.configuration = configuration;
}
@@ -126,5 +128,19 @@ BigQueryIO.Write<Row> toWrite() {
void setTestBigQueryServices(BigQueryServices testBigQueryServices) {
this.testBigQueryServices = testBigQueryServices;
}

+ public Row getConfigurationRow() {
+   try {
+     // To stay consistent with our SchemaTransform configuration naming conventions,
+     // we sort lexicographically
+     return SchemaRegistry.createDefault()
+         .getToRowFunction(BigQueryWriteConfiguration.class)
+         .apply(configuration)
+         .sorted()
+         .toSnakeCase();
+   } catch (NoSuchSchemaException e) {
+     throw new RuntimeException(e);
+   }
+ }
}
}
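The new getConfigurationRow() method above is what the translator in the next file uses to serialize a configured transform back into a Beam Row. Here is a minimal round-trip sketch mirroring what the new unit tests exercise; the configuration field values are illustrative, and a real setup may require additional fields:

import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsWriteSchemaTransformProvider;
import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryFileLoadsSchemaTransform;
import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryFileLoadsSchemaTransformTranslator;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.values.Row;

public class FileLoadsTranslationRoundTrip {
  public static void main(String[] args) {
    BigQueryFileLoadsWriteSchemaTransformProvider provider =
        new BigQueryFileLoadsWriteSchemaTransformProvider();

    // Build a config Row against the provider's schema (field values are illustrative).
    Row configRow =
        Row.withSchema(provider.configurationSchema())
            .withFieldValue("table", "project:dataset.table")
            .withFieldValue("create_disposition", "create_never")
            .withFieldValue("write_disposition", "write_append")
            .build();

    // The provider builds the transform; the translator round-trips it through a Row.
    BigQueryFileLoadsSchemaTransform transform =
        (BigQueryFileLoadsSchemaTransform) provider.from(configRow);
    BigQueryFileLoadsSchemaTransformTranslator translator =
        new BigQueryFileLoadsSchemaTransformTranslator();
    Row translated = translator.toConfigRow(transform); // delegates to getConfigurationRow()
    BigQueryFileLoadsSchemaTransform restored =
        translator.fromConfigRow(translated, PipelineOptionsFactory.create());

    System.out.println(configRow.equals(restored.getConfigurationRow())); // expect: true
  }
}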
@@ -15,15 +15,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
- package org.apache.beam.sdk.io.gcp.bigquery;
+ package org.apache.beam.sdk.io.gcp.bigquery.providers;

import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform;
+ import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryFileLoadsSchemaTransform;
import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform;

import com.google.auto.service.AutoService;
import java.util.Map;
- import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider;
- import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation;
import org.apache.beam.sdk.transforms.PTransform;
@@ -61,6 +60,20 @@ public Row toConfigRow(BigQueryStorageWriteApiSchemaTransform transform) {
}
}

+ public static class BigQueryFileLoadsSchemaTransformTranslator
+     extends SchemaTransformTranslation.SchemaTransformPayloadTranslator<
+         BigQueryFileLoadsSchemaTransform> {
+   @Override
+   public SchemaTransformProvider provider() {
+     return new BigQueryFileLoadsWriteSchemaTransformProvider();
+   }
+
+   @Override
+   public Row toConfigRow(BigQueryFileLoadsSchemaTransform transform) {
+     return transform.getConfigurationRow();
+   }
+ }

@AutoService(TransformPayloadTranslatorRegistrar.class)
public static class ReadWriteRegistrar implements TransformPayloadTranslatorRegistrar {
@Override
@@ -79,6 +92,9 @@ public static class ReadWriteRegistrar implements TransformPayloadTranslatorRegistrar {
.put(
BigQueryStorageWriteApiSchemaTransform.class,
new BigQueryStorageWriteSchemaTransformTranslator())
+ .put(
+     BigQueryFileLoadsSchemaTransform.class,
+     new BigQueryFileLoadsSchemaTransformTranslator())
.build();
}
}
@@ -28,7 +28,7 @@
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions;
- import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryWriteSchemaTransform;
+ import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryFileLoadsSchemaTransform;
import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices;
import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService;
import org.apache.beam.sdk.io.gcp.testing.FakeJobService;
@@ -106,8 +106,8 @@ public void testLoad() throws IOException, InterruptedException {
.setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name())
.setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name())
.build();
- BigQueryWriteSchemaTransform schemaTransform =
-     (BigQueryWriteSchemaTransform) provider.from(configuration);
+ BigQueryFileLoadsSchemaTransform schemaTransform =
+     (BigQueryFileLoadsSchemaTransform) provider.from(configuration);
schemaTransform.setTestBigQueryServices(fakeBigQueryServices);
String tag = provider.inputCollectionNames().get(0);
PCollectionRowTuple input =
@@ -15,13 +15,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
- package org.apache.beam.sdk.io.gcp.bigquery;
+ package org.apache.beam.sdk.io.gcp.bigquery.providers;

+ import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
+ import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
+ import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.util.List;
import java.util.Map;
+ import java.util.stream.Collectors;
import java.util.stream.LongStream;
+ import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+ import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.gcp.testing.BigqueryClient;
@@ -33,10 +39,12 @@
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PeriodicImpulse;
+ import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
+ import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -87,6 +95,27 @@ public static void cleanup() {
BQ_CLIENT.deleteDataset(PROJECT, BIG_QUERY_DATASET_ID);
}

+ private void assertPipelineContainsTransformIdentifier(
+     Pipeline p, String schemaTransformIdentifier) {
+   RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+   List<RunnerApi.PTransform> writeTransformProto =
+       pipelineProto.getComponents().getTransformsMap().values().stream()
+           .filter(
+               tr -> {
+                 RunnerApi.FunctionSpec spec = tr.getSpec();
+                 try {
+                   return spec.getUrn().equals(getUrn(SCHEMA_TRANSFORM))
+                       && ExternalTransforms.SchemaTransformPayload.parseFrom(spec.getPayload())
+                           .getIdentifier()
+                           .equals(schemaTransformIdentifier);
+                 } catch (InvalidProtocolBufferException e) {
+                   throw new RuntimeException(e);
+                 }
+               })
+           .collect(Collectors.toList());
+   assertEquals(1, writeTransformProto.size());
+ }
+
@Test
public void testBatchFileLoadsWriteRead() {
String table =
@@ -100,6 +129,8 @@ public void testBatchFileLoadsWriteRead() {
// batch write
PCollectionRowTuple.of("input", getInput(writePipeline, false))
.apply(Managed.write(Managed.BIGQUERY).withConfig(config));
+ assertPipelineContainsTransformIdentifier(
+     writePipeline, new BigQueryFileLoadsWriteSchemaTransformProvider().identifier());
writePipeline.run().waitUntilFinish();

// read and validate
@@ -108,7 +139,8 @@
.apply(Managed.read(Managed.BIGQUERY).withConfig(config))
.getSinglePCollection();
PAssert.that(outputRows).containsInAnyOrder(ROWS);

+ assertPipelineContainsTransformIdentifier(
+     readPipeline, new BigQueryDirectReadSchemaTransformProvider().identifier());
readPipeline.run().waitUntilFinish();
}

@@ -121,6 +153,8 @@ public void testStreamingStorageWriteRead() {
// streaming write
PCollectionRowTuple.of("input", getInput(writePipeline, true))
.apply(Managed.write(Managed.BIGQUERY).withConfig(config));
+ assertPipelineContainsTransformIdentifier(
+     writePipeline, new BigQueryStorageWriteApiSchemaTransformProvider().identifier());
writePipeline.run().waitUntilFinish();

// read and validate
@@ -129,7 +163,8 @@
.apply(Managed.read(Managed.BIGQUERY).withConfig(config))
.getSinglePCollection();
PAssert.that(outputRows).containsInAnyOrder(ROWS);

+ assertPipelineContainsTransformIdentifier(
+     readPipeline, new BigQueryDirectReadSchemaTransformProvider().identifier());
readPipeline.run().waitUntilFinish();
}

@@ -15,12 +15,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
- package org.apache.beam.sdk.io.gcp.bigquery;
+ package org.apache.beam.sdk.io.gcp.bigquery.providers;

import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
- import static org.apache.beam.sdk.io.gcp.bigquery.BigQuerySchemaTransformTranslation.BigQueryStorageReadSchemaTransformTranslator;
- import static org.apache.beam.sdk.io.gcp.bigquery.BigQuerySchemaTransformTranslation.BigQueryStorageWriteSchemaTransformTranslator;
import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform;
+ import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryFileLoadsSchemaTransform;
+ import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryFileLoadsSchemaTransformTranslator;
+ import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryStorageReadSchemaTransformTranslator;
+ import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryStorageWriteSchemaTransformTranslator;
import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform;
import static org.junit.Assert.assertEquals;

@@ -33,8 +35,6 @@
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.RowCoder;
- import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider;
- import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
@@ -51,12 +51,14 @@

@RunWith(JUnit4.class)
public class BigQuerySchemaTransformTranslationTest {
- static final BigQueryStorageWriteApiSchemaTransformProvider WRITE_PROVIDER =
+ static final BigQueryStorageWriteApiSchemaTransformProvider STORAGE_WRITE_PROVIDER =
new BigQueryStorageWriteApiSchemaTransformProvider();
+ static final BigQueryFileLoadsWriteSchemaTransformProvider FILE_LOADS_PROVIDER =
+     new BigQueryFileLoadsWriteSchemaTransformProvider();
static final BigQueryDirectReadSchemaTransformProvider READ_PROVIDER =
new BigQueryDirectReadSchemaTransformProvider();
static final Row WRITE_CONFIG_ROW =
- Row.withSchema(WRITE_PROVIDER.configurationSchema())
+ Row.withSchema(STORAGE_WRITE_PROVIDER.configurationSchema())
.withFieldValue("table", "project:dataset.table")
.withFieldValue("create_disposition", "create_never")
.withFieldValue("write_disposition", "write_append")
@@ -75,9 +77,9 @@ public class BigQuerySchemaTransformTranslationTest {
.build();

@Test
- public void testRecreateWriteTransformFromRow() {
+ public void testRecreateStorageWriteTransformFromRow() {
BigQueryStorageWriteApiSchemaTransform writeTransform =
- (BigQueryStorageWriteApiSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW);
+ (BigQueryStorageWriteApiSchemaTransform) STORAGE_WRITE_PROVIDER.from(WRITE_CONFIG_ROW);

BigQueryStorageWriteSchemaTransformTranslator translator =
new BigQueryStorageWriteSchemaTransformTranslator();
@@ -90,7 +92,22 @@ public void testRecreateWriteTransformFromRow() {
}

@Test
- public void testWriteTransformProtoTranslation()
+ public void testRecreateFileLoadsTransformFromRow() {
+   BigQueryFileLoadsSchemaTransform writeTransform =
+       (BigQueryFileLoadsSchemaTransform) FILE_LOADS_PROVIDER.from(WRITE_CONFIG_ROW);
+
+   BigQueryFileLoadsSchemaTransformTranslator translator =
+       new BigQueryFileLoadsSchemaTransformTranslator();
+   Row translatedRow = translator.toConfigRow(writeTransform);
+
+   BigQueryFileLoadsSchemaTransform writeTransformFromRow =
+       translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create());
+
+   assertEquals(WRITE_CONFIG_ROW, writeTransformFromRow.getConfigurationRow());
+ }
+
+ @Test
+ public void testStorageWriteTransformProtoTranslation()
throws InvalidProtocolBufferException, IOException {
// First build a pipeline
Pipeline p = Pipeline.create();
@@ -103,7 +120,7 @@ public void testWriteTransformProtoTranslation()
.setRowSchema(inputSchema);

BigQueryStorageWriteApiSchemaTransform writeTransform =
- (BigQueryStorageWriteApiSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW);
+ (BigQueryStorageWriteApiSchemaTransform) STORAGE_WRITE_PROVIDER.from(WRITE_CONFIG_ROW);
PCollectionRowTuple.of("input", input).apply(writeTransform);

// Then translate the pipeline to a proto and extract the BigQueryStorageWriteApiSchemaTransform proto
@@ -117,7 +134,7 @@ public void testWriteTransformProtoTranslation()
return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
&& SchemaTransformPayload.parseFrom(spec.getPayload())
.getIdentifier()
- .equals(WRITE_PROVIDER.identifier());
+ .equals(STORAGE_WRITE_PROVIDER.identifier());
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
@@ -129,7 +146,7 @@ public void testWriteTransformProtoTranslation()
// Check that the proto contains correct values
SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
- assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec);
+ assertEquals(STORAGE_WRITE_PROVIDER.configurationSchema(), schemaFromSpec);
Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());

assertEquals(WRITE_CONFIG_ROW, rowFromSpec);
@@ -143,6 +160,60 @@ public void testWriteTransformProtoTranslation()
assertEquals(WRITE_CONFIG_ROW, writeTransformFromSpec.getConfigurationRow());
}

+ @Test
+ public void testFileLoadsTransformProtoTranslation()
+     throws InvalidProtocolBufferException, IOException {
+   // First build a pipeline
+   Pipeline p = Pipeline.create();
+   Schema inputSchema = Schema.builder().addByteArrayField("b").build();
+   PCollection<Row> input =
+       p.apply(
+               Create.of(
+                   Collections.singletonList(
+                       Row.withSchema(inputSchema).addValue(new byte[] {1, 2, 3}).build())))
+           .setRowSchema(inputSchema);
+
+   BigQueryFileLoadsSchemaTransform writeTransform =
+       (BigQueryFileLoadsSchemaTransform) FILE_LOADS_PROVIDER.from(WRITE_CONFIG_ROW);
+   PCollectionRowTuple.of("input", input).apply(writeTransform);
+
+   // Then translate the pipeline to a proto and extract the BigQueryFileLoadsSchemaTransform proto
+   RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+   List<RunnerApi.PTransform> writeTransformProto =
+       pipelineProto.getComponents().getTransformsMap().values().stream()
+           .filter(
+               tr -> {
+                 RunnerApi.FunctionSpec spec = tr.getSpec();
+                 try {
+                   return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
+                       && SchemaTransformPayload.parseFrom(spec.getPayload())
+                           .getIdentifier()
+                           .equals(FILE_LOADS_PROVIDER.identifier());
+                 } catch (InvalidProtocolBufferException e) {
+                   throw new RuntimeException(e);
+                 }
+               })
+           .collect(Collectors.toList());
+   assertEquals(1, writeTransformProto.size());
+   RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec();
+
+   // Check that the proto contains correct values
+   SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+   Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
+   assertEquals(FILE_LOADS_PROVIDER.configurationSchema(), schemaFromSpec);
+   Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());
+
+   assertEquals(WRITE_CONFIG_ROW, rowFromSpec);
+
+   // Use the information in the proto to recreate the BigQueryFileLoadsSchemaTransform
+   BigQueryFileLoadsSchemaTransformTranslator translator =
+       new BigQueryFileLoadsSchemaTransformTranslator();
+   BigQueryFileLoadsSchemaTransform writeTransformFromSpec =
+       translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create());
+
+   assertEquals(WRITE_CONFIG_ROW, writeTransformFromSpec.getConfigurationRow());
+ }
@Test
public void testReCreateReadTransformFromRow() {
BigQueryDirectReadSchemaTransform readTransform =
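For context on what the new integration-test assertions verify: with the Managed API, a bounded input is expected to expand to the file-loads write, while an unbounded input expands to the Storage Write API transform. A minimal batch-write sketch under that assumption follows; the table name, schema, and the Managed import path are illustrative assumptions, not taken from this commit:

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

public class ManagedBigQueryBatchWrite {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();
    Schema schema = Schema.builder().addStringField("name").build();

    // A bounded source, so Managed should choose the file-loads write path;
    // an unbounded input would route to the Storage Write API transform instead.
    PCollection<Row> rows =
        p.apply(Create.of(Row.withSchema(schema).addValue("a").build()))
            .setRowSchema(schema);

    PCollectionRowTuple.of("input", rows)
        .apply(
            Managed.write(Managed.BIGQUERY)
                .withConfig(ImmutableMap.<String, Object>of("table", "project:dataset.table")));

    p.run().waitUntilFinish();
  }
}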
