From e770eacca77d13d710adad0d32e06e864d8f7d02 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud Date: Fri, 12 Apr 2024 01:05:20 -0400 Subject: [PATCH] tests --- .../schemas/transforms/SchemaTransform.java | 9 ++- ...chemaTransformProviderTranslationTest.java | 78 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProviderTranslationTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransform.java index ef6ebcd2a92d..bcd201a86b63 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransform.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaTransform.java @@ -24,6 +24,7 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PCollectionRowTuple; @@ -61,8 +62,14 @@ protected SchemaTransform(ConfigT configuration, String identifier) { Class typedClass = (Class) parameterizedType.getActualTypeArguments()[0]; try { + // Get initial row with values + Row row = SchemaRegistry.createDefault().getToRowFunction(typedClass).apply(configuration); + // Get sorted Schema and recreate the Row + Schema configurationSchema = SchemaRegistry.createDefault().getSchema(typedClass).sorted(); this.configurationRow = - SchemaRegistry.createDefault().getToRowFunction(typedClass).apply(configuration); + configurationSchema.getFields().stream() + .map(field -> row.getValue(field.getName())) + .collect(Row.toRow(configurationSchema)); } catch (NoSuchSchemaException e) { throw new RuntimeException("Unable to find schema for this SchemaTransform's config.", e); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProviderTranslationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProviderTranslationTest.java new file mode 100644 index 000000000000..06983de9e295 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SchemaTransformProviderTranslationTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas.transforms; + +import static org.apache.beam.sdk.schemas.transforms.SchemaTransformProviderTranslation.SchemaTransformTranslator; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.transforms.providers.FlattenTransformProvider; +import org.apache.beam.sdk.schemas.transforms.providers.JavaExplodeTransformProvider; +import org.apache.beam.sdk.schemas.transforms.providers.LoggingTransformProvider; +import org.apache.beam.sdk.values.Row; +import org.junit.Test; + +public class SchemaTransformProviderTranslationTest { + public void translateAndRunChecks(SchemaTransformProvider provider, Row originalRow) { + SchemaTransform transform = provider.from(originalRow); + + SchemaTransformTranslator translator = new SchemaTransformTranslator(provider.identifier()); + Row rowFromTransform = translator.toConfigRow(transform); + + SchemaTransform transformFromRow = + translator.fromConfigRow(rowFromTransform, PipelineOptionsFactory.create()); + + assertEquals(originalRow, rowFromTransform); + assertEquals(originalRow, transformFromRow.getConfigurationRow()); + assertEquals( + provider.configurationSchema(), transformFromRow.getConfigurationRow().getSchema()); + assertEquals(provider.identifier(), transformFromRow.getIdentifier()); + } + + @Test + public void testReCreateJavaExplodeTransform() { + JavaExplodeTransformProvider provider = new JavaExplodeTransformProvider(); + + Row originalRow = + Row.withSchema(provider.configurationSchema()) + .withFieldValue("crossProduct", true) + .withFieldValue("fields", Arrays.asList("a", "c")) + .build(); + + translateAndRunChecks(provider, originalRow); + } + + @Test + public void testReCreateFlattenTransform() { + FlattenTransformProvider provider = new FlattenTransformProvider(); + Row originalRow = Row.withSchema(provider.configurationSchema()).build(); + translateAndRunChecks(provider, originalRow); + } + + @Test + public void testReCreateLoggingTransform() { + LoggingTransformProvider provider = new LoggingTransformProvider(); + Row originalRow = + Row.withSchema(provider.configurationSchema()) + .withFieldValue("level", "INFO") + .withFieldValue("prefix", "some_prefix") + .build(); + translateAndRunChecks(provider, originalRow); + } +}