diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java index d52329675d61..3785b07a3b45 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java @@ -25,11 +25,14 @@ import com.google.cloud.pubsublite.TopicPath; import com.google.cloud.pubsublite.proto.PubSubMessage; import com.google.protobuf.ByteString; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; @@ -38,16 +41,22 @@ import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; import org.apache.beam.sdk.schemas.utils.JsonUtils; -import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @AutoService(SchemaTransformProvider.class) public class PubsubLiteWriteSchemaTransformProvider @@ -57,6 +66,12 @@ public class PubsubLiteWriteSchemaTransformProvider public static final String SUPPORTED_FORMATS_STR = "JSON,AVRO"; public static final Set SUPPORTED_FORMATS = Sets.newHashSet(SUPPORTED_FORMATS_STR.split(",")); + public static final TupleTag OUTPUT_TAG = new TupleTag() {}; + public static final TupleTag ERROR_TAG = new TupleTag() {}; + public static final Schema ERROR_SCHEMA = + Schema.builder().addStringField("error").addNullableByteArrayField("row").build(); + private static final Logger LOG = + LoggerFactory.getLogger(PubsubLiteWriteSchemaTransformProvider.class); @Override protected @UnknownKeyFor @NonNull @Initialized Class @@ -64,6 +79,44 @@ public class PubsubLiteWriteSchemaTransformProvider return PubsubLiteWriteSchemaTransformConfiguration.class; } + public static class ErrorCounterFn extends DoFn { + private SerializableFunction toBytesFn; + private Counter errorCounter; + private long errorsInBundle = 0L; + + public ErrorCounterFn(String name, SerializableFunction toBytesFn) { + this.toBytesFn = toBytesFn; + errorCounter = Metrics.counter(PubsubLiteWriteSchemaTransformProvider.class, name); + } + + @ProcessElement + public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) { + try { + PubSubMessage message = + PubSubMessage.newBuilder() + .setData(ByteString.copyFrom(Objects.requireNonNull(toBytesFn.apply(row)))) + .build(); + + receiver.get(OUTPUT_TAG).output(message); + } catch (Exception e) { + errorsInBundle += 1; + LOG.warn("Error while parsing the element", e); + receiver + .get(ERROR_TAG) + .output( + Row.withSchema(ERROR_SCHEMA) + .addValues(e.toString(), row.toString().getBytes(StandardCharsets.UTF_8)) + .build()); + } + } + + @FinishBundle + public void finish() { + errorCounter.inc(errorsInBundle); + errorsInBundle = 0L; + } + } + @Override public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( PubsubLiteWriteSchemaTransformConfiguration configuration) { @@ -92,18 +145,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { configuration.getFormat().equals("JSON") ? JsonUtils.getRowToJsonBytesFunction(inputSchema) : AvroUtils.getRowToAvroBytesFunction(inputSchema); - input - .get("input") - .apply( - "Map Rows to PubSubMessages", - MapElements.into(TypeDescriptor.of(PubSubMessage.class)) - .via( - row -> - PubSubMessage.newBuilder() - .setData( - ByteString.copyFrom( - Objects.requireNonNull(toBytesFn.apply(row)))) - .build())) + + PCollectionTuple outputTuple = + input + .get("input") + .apply( + "Map Rows to PubSubMessages", + ParDo.of(new ErrorCounterFn("PubSubLite-write-error-counter", toBytesFn)) + .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); + + outputTuple + .get(OUTPUT_TAG) .apply("Add UUIDs", PubsubLiteIO.addUuids()) .apply( "Write to PS Lite", @@ -117,7 +169,9 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { CloudRegionOrZone.parse(configuration.getLocation())) .build()) .build())); - return PCollectionRowTuple.empty(input.getPipeline()); + + return PCollectionRowTuple.of( + "errors", outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA)); } }; } @@ -138,7 +192,7 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { @Override public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> outputCollectionNames() { - return Collections.emptyList(); + return Collections.singletonList("errors"); } @AutoValue diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/PubsubLiteWriteDlqTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/PubsubLiteWriteDlqTest.java new file mode 100644 index 000000000000..d42eb249b27e --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/PubsubLiteWriteDlqTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.pubsublite.internal; + +import com.google.cloud.pubsublite.proto.PubSubMessage; +import com.google.protobuf.ByteString; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.sdk.io.gcp.pubsublite.PubsubLiteWriteSchemaTransformProvider; +import org.apache.beam.sdk.io.gcp.pubsublite.PubsubLiteWriteSchemaTransformProvider.ErrorCounterFn; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.utils.JsonUtils; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class PubsubLiteWriteDlqTest { + + private static final TupleTag OUTPUT_TAG = + PubsubLiteWriteSchemaTransformProvider.OUTPUT_TAG; + private static final TupleTag ERROR_TAG = PubsubLiteWriteSchemaTransformProvider.ERROR_TAG; + + private static final Schema BEAMSCHEMA = + Schema.of(Schema.Field.of("name", Schema.FieldType.STRING)); + private static final Schema ERRORSCHEMA = PubsubLiteWriteSchemaTransformProvider.ERROR_SCHEMA; + + private static final List ROWS = + Arrays.asList( + Row.withSchema(BEAMSCHEMA).withFieldValue("name", "a").build(), + Row.withSchema(BEAMSCHEMA).withFieldValue("name", "b").build(), + Row.withSchema(BEAMSCHEMA).withFieldValue("name", "c").build()); + + private static final List MESSAGES = + Arrays.asList( + PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("{\"name\":\"a\"}")).build(), + PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("{\"name\":\"b\"}")).build(), + PubSubMessage.newBuilder().setData(ByteString.copyFromUtf8("{\"name\":\"c\"}")).build()); + + final SerializableFunction valueMapper = + JsonUtils.getRowToJsonBytesFunction(BEAMSCHEMA); + + @Rule public transient TestPipeline p = TestPipeline.create(); + + @Test + public void testPubsubLiteErrorFnSuccess() throws Exception { + PCollection input = p.apply(Create.of(ROWS)); + PCollectionTuple output = + input.apply( + ParDo.of(new ErrorCounterFn("ErrorCounter", valueMapper)) + .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG))); + + output.get(ERROR_TAG).setRowSchema(ERRORSCHEMA); + + PAssert.that(output.get(OUTPUT_TAG)).containsInAnyOrder(MESSAGES); + p.run().waitUntilFinish(); + } +}