diff --git a/.asf.yaml b/.asf.yaml index 703aca276e6b..50886f2cea5a 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -49,6 +49,7 @@ github: protected_branches: master: {} + release-2.61.0: {} release-2.60.0: {} release-2.59.0: {} release-2.58.1: {} diff --git a/.github/REVIEWERS.yml b/.github/REVIEWERS.yml index 38adde6a7820..dba969180c45 100644 --- a/.github/REVIEWERS.yml +++ b/.github/REVIEWERS.yml @@ -61,6 +61,12 @@ labels: reviewers: - svetakvsundhar exclusionList: [] + - name: kafka + reviewers: + - johnjcasey + - fozzie15 + - Dippatel98 + - sjvanrossum - name: Build reviewers: - damccorm diff --git a/.github/actions/setup-default-test-properties/test-properties.json b/.github/actions/setup-default-test-properties/test-properties.json index 098e4ca1935c..6439492ba5a2 100644 --- a/.github/actions/setup-default-test-properties/test-properties.json +++ b/.github/actions/setup-default-test-properties/test-properties.json @@ -14,7 +14,7 @@ }, "JavaTestProperties": { "SUPPORTED_VERSIONS": ["8", "11", "17", "21"], - "FLINK_VERSIONS": ["1.15", "1.16", "1.17", "1.18"], + "FLINK_VERSIONS": ["1.17", "1.18", "1.19"], "SPARK_VERSIONS": ["2", "3"] }, "GoTestProperties": { diff --git a/.github/gh-actions-self-hosted-runners/arc/config/arc_autoscaler.tpl b/.github/gh-actions-self-hosted-runners/arc/config/arc_autoscaler.tpl index 4b04c5ad8eb1..2b28068adfee 100644 --- a/.github/gh-actions-self-hosted-runners/arc/config/arc_autoscaler.tpl +++ b/.github/gh-actions-self-hosted-runners/arc/config/arc_autoscaler.tpl @@ -39,4 +39,7 @@ spec: scaleDownThreshold: '0.25' scaleUpFactor: '2' scaleDownFactor: '0.5' + - type: TotalNumberOfQueuedAndInProgressWorkflowRuns + repositoryNames: + - beam %{~ endif ~} diff --git a/.github/gh-actions-self-hosted-runners/arc/environments/beam.env b/.github/gh-actions-self-hosted-runners/arc/environments/beam.env index b8d20d6e9f74..e97e9b575a02 100644 --- a/.github/gh-actions-self-hosted-runners/arc/environments/beam.env +++ 
b/.github/gh-actions-self-hosted-runners/arc/environments/beam.env @@ -81,5 +81,22 @@ additional_runner_pools = [{ labels = ["self-hosted", "ubuntu-20.04", "highmem"] enable_selector = true enable_taint = true +}, +{ + name = "highmem-runner-22" + machine_type = "c3-highmem-22" + runner_image = "us-central1-docker.pkg.dev/apache-beam-testing/beam-github-actions/beam-arc-runner:3063b55757509dad1c14751c9f2aa5905826d9a0" + min_node_count = "0" + max_node_count = "2" + min_replicas = "0" + max_replicas = "2" + webhook_scaling = false + requests = { + cpu = "7.5" + memory = "100Gi" + } + labels = ["self-hosted", "ubuntu-20.04", "highmem22"] + enable_selector = true + enable_taint = true }] #state_bucket_name = "beam-arc-state" diff --git a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json index e3d6056a5de9..b98aece75634 100644 --- a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 1, + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java.json b/.github/trigger_files/beam_PostCommit_Java.json new file mode 100644 index 000000000000..920c8d132e4a --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 +} \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Java_Avro_Versions.json b/.github/trigger_files/beam_PostCommit_Java_Avro_Versions.json new file mode 100644 index 000000000000..3f63c0c9975f --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Avro_Versions.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite 
to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json index a03c067d2c4e..3f63c0c9975f 100644 --- a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json +++ b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json new file mode 100644 index 000000000000..dd9afb90e638 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json index 08c2e40784a9..8784d0786c02 100644 --- a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json +++ b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 } \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Java_IO_Performance_Tests.json b/.github/trigger_files/beam_PostCommit_Java_IO_Performance_Tests.json new file mode 100644 index 000000000000..b26833333238 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_IO_Performance_Tests.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json 
b/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json new file mode 100644 index 000000000000..dd9afb90e638 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Java_Nexmark_Dataflow.json b/.github/trigger_files/beam_PostCommit_Java_Nexmark_Dataflow.json new file mode 100644 index 000000000000..0967ef424bce --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Nexmark_Dataflow.json @@ -0,0 +1 @@ +{} diff --git a/.github/trigger_files/beam_PostCommit_Java_Nexmark_Flink.json b/.github/trigger_files/beam_PostCommit_Java_Nexmark_Flink.json new file mode 100644 index 000000000000..531514a72738 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Nexmark_Flink.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "runFor": "#33146" +} diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json index b970762c8397..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json index b970762c8397..bdd2197e534a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - 
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "modification": "1" } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json index b60f5c4cc3c8..e3d6056a5de9 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 0 + "modification": 1 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json index b60f5c4cc3c8..e3d6056a5de9 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Samza.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 0 + "modification": 1 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json index b60f5c4cc3c8..c537844dc84a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 0 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json index b60f5c4cc3c8..f1ba03a243ee 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 0 + "modification": 5 } diff --git 
a/.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json new file mode 100644 index 000000000000..b26833333238 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json index b970762c8397..531514a72738 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "runFor": "#33146" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json index b970762c8397..9200c368abbe 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json new file mode 100644 index 000000000000..b07a3c47e196 --- 
/dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json @@ -0,0 +1,4 @@ +{ + + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 1eb60f6e4959..9c7a70ceed74 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 3 + "modification": 7 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Dependency.json b/.github/trigger_files/beam_PostCommit_Python_Dependency.json index e69de29bb2d1..a7fc54b3e4bb 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Dependency.json +++ b/.github/trigger_files/beam_PostCommit_Python_Dependency.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 + } \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json new file mode 100644 index 000000000000..3f63c0c9975f --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json index e69de29bb2d1..0b34d452d42c 100644 --- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing addition of 
Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index 6b3a9dc134ee..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,4 +1,4 @@ - { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index e3d6056a5de9..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json index e3d6056a5de9..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json new file mode 100644 index 000000000000..b26833333238 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git 
a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json index c4edaa85a89d..8ed972c9f579 100644 --- a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json +++ b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "revision": 3 } diff --git a/.github/trigger_files/beam_PostCommit_XVR_Direct.json b/.github/trigger_files/beam_PostCommit_XVR_Direct.json new file mode 100644 index 000000000000..236b7bee8af8 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Direct.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_XVR_Flink.json b/.github/trigger_files/beam_PostCommit_XVR_Flink.json new file mode 100644 index 000000000000..236b7bee8af8 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.json b/.github/trigger_files/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.json new file mode 100644 index 000000000000..9e26dfeeb6e6 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_XVR_Samza.json b/.github/trigger_files/beam_PostCommit_XVR_Samza.json new file mode 100644 index 000000000000..9e26dfeeb6e6 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Samza.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/.github/trigger_files/beam_PreCommit_Flink_Container.json 
b/.github/trigger_files/beam_PreCommit_Flink_Container.json new file mode 100644 index 000000000000..3f63c0c9975f --- /dev/null +++ b/.github/trigger_files/beam_PreCommit_Flink_Container.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PreCommit_Java_HBase_IO_Direct.json b/.github/trigger_files/beam_PreCommit_Java_HBase_IO_Direct.json new file mode 100644 index 000000000000..0967ef424bce --- /dev/null +++ b/.github/trigger_files/beam_PreCommit_Java_HBase_IO_Direct.json @@ -0,0 +1 @@ +{} diff --git a/.github/workflows/IO_Iceberg_Integration_Tests.yml b/.github/workflows/IO_Iceberg_Integration_Tests.yml index 22b2b4f9287d..68a72790006f 100644 --- a/.github/workflows/IO_Iceberg_Integration_Tests.yml +++ b/.github/workflows/IO_Iceberg_Integration_Tests.yml @@ -75,4 +75,4 @@ jobs: - name: Run IcebergIO Integration Test uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :sdks:java:io:iceberg:catalogTests \ No newline at end of file + gradle-command: :sdks:java:io:iceberg:catalogTests --info \ No newline at end of file diff --git a/.github/workflows/IO_Iceberg_Unit_Tests.yml b/.github/workflows/IO_Iceberg_Unit_Tests.yml index 0d72b0da8597..d063f6ac71db 100644 --- a/.github/workflows/IO_Iceberg_Unit_Tests.yml +++ b/.github/workflows/IO_Iceberg_Unit_Tests.yml @@ -111,6 +111,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/README.md b/.github/workflows/README.md index d386f4dc40f9..206364f416f7 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -285,6 +285,7 @@ Additional PreCommit jobs running basic SDK unit test on a matrices of operating | [Java 
Tests](https://github.com/apache/beam/actions/workflows/java_tests.yml) | [![.github/workflows/java_tests.yml](https://github.com/apache/beam/actions/workflows/java_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/java_tests.yml?query=event%3Aschedule) | | [Python Tests](https://github.com/apache/beam/actions/workflows/python_tests.yml) | [![.github/workflows/python_tests.yml](https://github.com/apache/beam/actions/workflows/python_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/python_tests.yml?query=event%3Aschedule) | | [TypeScript Tests](https://github.com/apache/beam/actions/workflows/typescript_tests.yml) | [![.github/workflows/typescript_tests.yml](https://github.com/apache/beam/actions/workflows/typescript_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/typescript_tests.yml?query=event%3Aschedule) | +| [Build Wheels](https://github.com/apache/beam/actions/workflows/build_wheels.yml) | [![.github/workflows/build_wheels.yml](https://github.com/apache/beam/actions/workflows/build_wheels.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/build_wheels.yml?query=event%3Aschedule) | ### PostCommit Jobs @@ -330,7 +331,6 @@ PostCommit Jobs run in a schedule against master branch and generally do not get | [ PostCommit Java SingleStoreIO IT ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml) | N/A |`beam_PostCommit_Java_SingleStoreIO_IT.json`| [![.github/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml?query=event%3Aschedule) | | [ PostCommit Java PVR Spark3 Streaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml) | N/A 
|`beam_PostCommit_Java_PVR_Spark3_Streaming.json`| [![.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml?query=event%3Aschedule) | | [ PostCommit Java PVR Spark Batch ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml) | N/A |`beam_PostCommit_Java_PVR_Spark_Batch.json`| [![.github/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml?query=event%3Aschedule) | -| [ PostCommit Java Sickbay ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Sickbay.yml) | N/A |`beam_PostCommit_Java_Sickbay.json`| [![.github/workflows/beam_PostCommit_Java_Sickbay.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Sickbay.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Sickbay.yml?query=event%3Aschedule) | | [ PostCommit Java Tpcds Dataflow ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml) | N/A |`beam_PostCommit_Java_Tpcds_Dataflow.json`| [![.github/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml?query=event%3Aschedule) | | [ PostCommit Java Tpcds Flink ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Flink.yml) | N/A |`beam_PostCommit_Java_Tpcds_Flink.json`| 
[![.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Flink.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Flink.yml?query=event%3Aschedule) | | [ PostCommit Java Tpcds Spark ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Spark.yml) | N/A |`beam_PostCommit_Java_Tpcds_Spark.json`| [![.github/workflows/beam_PostCommit_Java_Tpcds_Spark.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Spark.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Spark.yml?query=event%3Aschedule) | @@ -371,7 +371,6 @@ PostCommit Jobs run in a schedule against master branch and generally do not get | [ PostCommit Python Xlang Gcp Dataflow ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml) | N/A |`beam_PostCommit_Python_Xlang_Gcp_Dataflow.json`| [![.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml?query=event%3Aschedule) | | [ PostCommit Python Xlang Gcp Direct ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml) | N/A |`beam_PostCommit_Python_Xlang_Gcp_Direct.json`| [![.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml?query=event%3Aschedule) | | [ PostCommit Python Xlang IO Dataflow ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml) | N/A 
|`beam_PostCommit_Python_Xlang_IO_Dataflow.json`| [![.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml?query=event%3Aschedule) | -| [ PostCommit Sickbay Python ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Sickbay_Python.yml) | ['3.8','3.9','3.10','3.11'] |`beam_PostCommit_Sickbay_Python.json`| [![.github/workflows/beam_PostCommit_Sickbay_Python.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Sickbay_Python.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Sickbay_Python.yml?query=event%3Aschedule) | | [ PostCommit SQL ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_SQL.yml) | N/A |`beam_PostCommit_SQL.json`| [![.github/workflows/beam_PostCommit_SQL.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_SQL.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_SQL.yml?query=event%3Aschedule) | | [ PostCommit TransformService Direct ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_TransformService_Direct.yml) | N/A |`beam_PostCommit_TransformService_Direct.json`| [![.github/workflows/beam_PostCommit_TransformService_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_TransformService_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_TransformService_Direct.yml?query=event%3Aschedule) | [ PostCommit Website Test](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Website_Test.yml) | N/A |`beam_PostCommit_Website_Test.json`| 
[![.github/workflows/beam_PostCommit_Website_Test.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Website_Test.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Website_Test.yml?query=event%3Aschedule) | diff --git a/.github/workflows/beam_LoadTests_Go_CoGBK_Flink_batch.yml b/.github/workflows/beam_LoadTests_Go_CoGBK_Flink_batch.yml index fae86961ea27..a2c347ebddb6 100644 --- a/.github/workflows/beam_LoadTests_Go_CoGBK_Flink_batch.yml +++ b/.github/workflows/beam_LoadTests_Go_CoGBK_Flink_batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-go-cogbk-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-go-cogbk-flink-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Go_Combine_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Go_Combine_Flink_Batch.yml index e814cc809be2..cdb034edcd27 100644 --- a/.github/workflows/beam_LoadTests_Go_Combine_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Go_Combine_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-go-combine-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - 
FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-go-combine-flink-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Go_GBK_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Go_GBK_Flink_Batch.yml index 8c01bc1cf304..f95e1c831da7 100644 --- a/.github/workflows/beam_LoadTests_Go_GBK_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Go_GBK_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-go-gbk-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-go-gbk-flink-batch-${{ 
github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Go_ParDo_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Go_ParDo_Flink_Batch.yml index ba7323a8b63c..89b31e02261d 100644 --- a/.github/workflows/beam_LoadTests_Go_ParDo_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Go_ParDo_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-go-pardo-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-go-pardo-flink-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Go_SideInput_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Go_SideInput_Flink_Batch.yml index 5440ce968898..7ab3d837721b 100644 --- a/.github/workflows/beam_LoadTests_Go_SideInput_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Go_SideInput_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-go-sideinput-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: 
https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-go-sideinput-flink-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml b/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml index 2b631d2f7664..659c85b002df 100644 --- a/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml +++ b/.github/workflows/beam_LoadTests_Java_CoGBK_Dataflow_Streaming.yml @@ -124,4 +124,5 @@ jobs: uses: EnricoMi/publish-unit-test-result-action@v2 if: always() with: - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml b/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml index 6dea32d4f5a0..d3b6c38ce7ae 100644 --- a/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml +++ b/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml @@ -106,7 +106,7 @@ jobs: arguments: | --info \ -PloadTest.mainClass=org.apache.beam.sdk.loadtests.GroupByKeyLoadTest \ - -Prunner=:runners:flink:1.15 \ + -Prunner=:runners:flink:1.19 \ '-PloadTest.args=${{ env.beam_LoadTests_Java_GBK_Smoke_test_arguments_3 }}' \ - name: run GroupByKey load test Spark uses: ./.github/actions/gradle-command-self-hosted-action @@ -115,4 +115,4 @@ jobs: arguments: | -PloadTest.mainClass=org.apache.beam.sdk.loadtests.GroupByKeyLoadTest \ -Prunner=:runners:spark:3 \ - '-PloadTest.args=${{ env.beam_LoadTests_Java_GBK_Smoke_test_arguments_4 
}}' \ No newline at end of file + '-PloadTest.args=${{ env.beam_LoadTests_Java_GBK_Smoke_test_arguments_4 }}' diff --git a/.github/workflows/beam_LoadTests_Python_CoGBK_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Python_CoGBK_Flink_Batch.yml index e2afb2e2cfd7..9b0dec2249f6 100644 --- a/.github/workflows/beam_LoadTests_Python_CoGBK_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Python_CoGBK_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-py-cogbk-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-python-cogbk-flink-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Python_Combine_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Python_Combine_Flink_Batch.yml index 0f666a0b7db6..6363de044149 100644 --- a/.github/workflows/beam_LoadTests_Python_Combine_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Python_Combine_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-py-cmb-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: 
https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-py-cmb-flink-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml b/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml index 6f491e6b9fa9..243e9d32c066 100644 --- a/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml +++ b/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-py-cmb-flink-streaming-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-py-cmb-flink-streaming-${{ github.run_id }} jobs: @@ -65,7 +65,7 @@ jobs: (github.event_name == 
'schedule' && github.repository == 'apache/beam') || github.event.comment.body == 'Run Load Tests Python Combine Flink Streaming' runs-on: [self-hosted, ubuntu-20.04, main] - timeout-minutes: 720 + timeout-minutes: 80 name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) strategy: matrix: @@ -89,17 +89,22 @@ jobs: test-type: load test-language: python argument-file-paths: | - ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt - ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt + ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt + ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt + # large loads do not work now + # ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt + # ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt - name: Start Flink with parallelism 16 env: FLINK_NUM_WORKERS: 16 + HIGH_MEM_MACHINE: n1-highmem-16 + HIGH_MEM_FLINK_PROPS: flink:taskmanager.memory.process.size=16g,flink:taskmanager.memory.flink.size=12g,flink:taskmanager.memory.jvm-overhead.max=4g,flink:jobmanager.memory.process.size=6g,flink:jobmanager.memory.jvm-overhead.max= 2g,flink:jobmanager.memory.flink.size=4g run: | cd ${{ github.workspace }}/.test-infra/dataproc; ./flink_cluster.sh create - name: get current time run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV - # The env variables are created and populated in the test-arguments-action as "_test_arguments_" - - name: run Load test 2GB Fanout 4 + # The env variables are created and populated in the test-arguments-action as "_test_arguments_" + - name: run Load test small Fanout 1 uses: 
./.github/actions/gradle-command-self-hosted-action with: gradle-command: :sdks:python:apache_beam:testing:load_tests:run @@ -108,7 +113,7 @@ jobs: -PloadTest.mainClass=apache_beam.testing.load_tests.combine_test \ -Prunner=PortableRunner \ '-PloadTest.args=${{ env.beam_LoadTests_Python_Combine_Flink_Streaming_test_arguments_1 }} --job_name=load-tests-python-flink-streaming-combine-4-${{env.NOW_UTC}}' \ - - name: run Load test 2GB Fanout 8 + - name: run Load test small Fanout 2 uses: ./.github/actions/gradle-command-self-hosted-action with: gradle-command: :sdks:python:apache_beam:testing:load_tests:run @@ -123,4 +128,4 @@ jobs: ${{ github.workspace }}/.test-infra/dataproc/flink_cluster.sh delete # // TODO(https://github.com/apache/beam/issues/20402). Skipping some cases because they are too slow: - # load-tests-python-flink-streaming-combine-1' \ No newline at end of file + # load-tests-python-flink-streaming-combine-1' diff --git a/.github/workflows/beam_LoadTests_Python_GBK_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Python_GBK_Flink_Batch.yml index c938b284a866..e05885246090 100644 --- a/.github/workflows/beam_LoadTests_Python_GBK_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Python_GBK_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-py-gbk-flk-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + 
JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-py-gbk-flk-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Batch.yml b/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Batch.yml index b6c86e01c299..8d907cf643bf 100644 --- a/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Batch.yml +++ b/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Batch.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-py-pardo-flink-batch-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-python-pardo-flink-batch-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Streaming.yml b/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Streaming.yml index a6443c0df10b..142d1b5e2dc2 100644 --- a/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Streaming.yml +++ b/.github/workflows/beam_LoadTests_Python_ParDo_Flink_Streaming.yml @@ -50,12 +50,12 @@ env: GCLOUD_ZONE: us-central1-a CLUSTER_NAME: beam-loadtests-py-pardo-flink-stream-${{ github.run_id }} GCS_BUCKET: gs://beam-flink-cluster - FLINK_DOWNLOAD_URL: 
https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar FLINK_TASKMANAGER_SLOTS: 1 DETACHED_MODE: true HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest - JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.15_job_server:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest ARTIFACTS_DIR: gs://beam-flink-cluster/beam-loadtests-python-pardo-flink-stream-${{ github.run_id }} jobs: diff --git a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml index af0569f4784a..74932079fe4c 100644 --- a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml +++ b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Avro.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml index 9e3962e2576e..05e5369a6384 100644 --- a/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml +++ b/.github/workflows/beam_PerformanceTests_BigQueryIO_Batch_Java_Json.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline 
at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml b/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml index 7514bd5cacb3..32db2cff6cbc 100644 --- a/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml +++ b/.github/workflows/beam_PerformanceTests_BigQueryIO_Streaming_Java.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml b/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml index 6ac07a1bd76c..d04a6e63c800 100644 --- a/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml +++ b/.github/workflows/beam_PerformanceTests_SQLBigQueryIO_Batch_Java.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml b/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml index e9ef9cd1716a..756ecb5a58c2 100644 --- a/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml +++ b/.github/workflows/beam_PerformanceTests_WordCountIT_PythonVersions.yml @@ -115,4 +115,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: 
'**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/playground_backend_precommit.yml b/.github/workflows/beam_Playground_Precommit.yml similarity index 75% rename from .github/workflows/playground_backend_precommit.yml rename to .github/workflows/beam_Playground_Precommit.yml index 9ba6cf20534f..edb50661b1ee 100644 --- a/.github/workflows/playground_backend_precommit.yml +++ b/.github/workflows/beam_Playground_Precommit.yml @@ -17,10 +17,12 @@ name: Playground PreCommit on: workflow_dispatch: - pull_request: + pull_request_target: paths: - .github/workflows/playground_backend_precommit.yml - playground/backend/** + issue_comment: + types: [created] env: DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} @@ -28,17 +30,30 @@ env: GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - precommit_check: - name: precommit-check - runs-on: ubuntu-latest + beam_Playground_PreCommit: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + github.event.comment.body == 'Run Playground PreCommit' + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + runs-on: [self-hosted, ubuntu-20.04, main] + strategy: + fail-fast: false + matrix: + job_name: [beam_Playground_PreCommit] + job_phrase: [Run Playground PreCommit] env: DATASTORE_EMULATOR_VERSION: '423.0.0' PYTHON_VERSION: '3.9' JAVA_VERSION: '11' steps: - - name: Check out the repo - uses: actions/checkout@v4 - + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action with: @@ -58,7 +73,7 @@ jobs: sudo chmod 644 /etc/apt/trusted.gpg.d/scalasbt-release.gpg sudo apt-get update --yes sudo apt-get install sbt --yes - sudo wget 
https://codeload.github.com/spotify/scio.g8/zip/7c1ba7c1651dfd70976028842e721da4107c0d6d -O scio.g8.zip && unzip scio.g8.zip && mv scio.g8-7c1ba7c1651dfd70976028842e721da4107c0d6d /opt/scio.g8 + sudo wget https://codeload.github.com/spotify/scio.g8/zip/7c1ba7c1651dfd70976028842e721da4107c0d6d -O scio.g8.zip && unzip scio.g8.zip && sudo mv scio.g8-7c1ba7c1651dfd70976028842e721da4107c0d6d /opt/scio.g8 - name: Set up Cloud SDK and its components uses: google-github-actions/setup-gcloud@v2 with: diff --git a/.github/workflows/beam_PostCommit_Java.yml b/.github/workflows/beam_PostCommit_Java.yml index 3428551cb8f9..4fafa3b2a993 100644 --- a/.github/workflows/beam_PostCommit_Java.yml +++ b/.github/workflows/beam_PostCommit_Java.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml b/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml index e3a9db23ed67..8ffcc4a28a71 100644 --- a/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml +++ b/.github/workflows/beam_PostCommit_Java_Avro_Versions.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml b/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml index 1a6f7c14db50..8707b515e10b 100644 --- a/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml +++ b/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml @@ -110,3 +110,4 
@@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_DataflowV1.yml b/.github/workflows/beam_PostCommit_Java_DataflowV1.yml index e7c2aa6fe7e2..752b15936b5f 100644 --- a/.github/workflows/beam_PostCommit_Java_DataflowV1.yml +++ b/.github/workflows/beam_PostCommit_Java_DataflowV1.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_DataflowV2.yml b/.github/workflows/beam_PostCommit_Java_DataflowV2.yml index 3c0a46d6bb40..cb107572b621 100644 --- a/.github/workflows/beam_PostCommit_Java_DataflowV2.yml +++ b/.github/workflows/beam_PostCommit_Java_DataflowV2.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml index 469d7e31f173..81725c4005af 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow.yml @@ -89,4 +89,5 @@ jobs: uses: EnricoMi/publish-unit-test-result-action@v2 if: always() with: - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git 
a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml index 9fd84daef63b..eacdfe5a5c23 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_ARM.yml @@ -119,3 +119,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml index 13ab05f8f173..efb926681cbf 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_Java.yml @@ -97,4 +97,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml index 9be8a34f3732..1882cdf1d76b 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml index cd2486ae8e10..05b28ac93658 
100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Dataflow_V2_Java.yml @@ -104,4 +104,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml b/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml index ca06e72877c7..a746acb4333f 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Direct.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml b/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml index 4077b7be68fe..f72910bd15bc 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml @@ -80,7 +80,7 @@ jobs: - name: run examplesIntegrationTest script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:examplesIntegrationTest + gradle-command: :runners:flink:1.19:examplesIntegrationTest - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' 
+ large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml b/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml index 8008daf4584f..c3620e46fac9 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Spark.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml b/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml index 67a48b105955..1202ecc0e27f 100644 --- a/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml +++ b/.github/workflows/beam_PostCommit_Java_Hadoop_Versions.yml @@ -100,4 +100,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml b/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml index b8c79e2677ca..6023a895a458 100644 --- a/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml +++ b/.github/workflows/beam_PostCommit_Java_IO_Performance_Tests.yml @@ -88,17 +88,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: java-version: default - - name: Authenticate on GCP - uses: google-github-actions/auth@v2 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - token_format: 'access_token' - - name: Setup gcloud - uses: google-github-actions/setup-gcloud@v2 - with: - project_id: ${{ secrets.GCP_PROJECT_ID }} - 
skip_install: true - name: run scheduled javaPostcommitIOPerformanceTests script if: github.event_name == 'schedule' #This ensures only scheduled runs publish metrics publicly by changing which exportTable is configured uses: ./.github/actions/gradle-command-self-hosted-action @@ -128,3 +117,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml index 37f784770477..323f85b9851a 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java11.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml index 377602ad08dd..1ccb26f5aa1f 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Dataflow_Java17.yml @@ -96,4 +96,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml index 80406cf4eb0c..02ac93135957 100644 --- 
a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java11.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml index 3cbc317317c2..2cbf60a48d2e 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java17.yml @@ -96,4 +96,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml index 97fd1fb4913e..6a7058ef566d 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Direct_Java21.yml @@ -97,4 +97,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml index 1a7405836f69..1559061634d3 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml +++ 
b/.github/workflows/beam_PostCommit_Java_Jpms_Flink_Java11.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml b/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml index eec4867a997b..1b4f8c5bcce5 100644 --- a/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml +++ b/.github/workflows/beam_PostCommit_Java_Jpms_Spark_Java11.yml @@ -91,4 +91,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml b/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml index 58b12b4e4530..ef69a2918196 100644 --- a/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml @@ -102,7 +102,7 @@ jobs: with: gradle-command: :sdks:java:testing:nexmark:run arguments: | - -Pnexmark.runner=:runners:flink:1.15 \ + -Pnexmark.runner=:runners:flink:1.19 \ "${{ env.GRADLE_COMMAND_ARGUMENTS }} --streaming=${{ matrix.streaming }} --queryLanguage=${{ matrix.queryLanguage }}" \ - name: run PostCommit Java Nexmark Flink (${{ matrix.streaming }}) script if: matrix.queryLanguage == 'none' @@ -110,5 +110,5 @@ jobs: with: gradle-command: :sdks:java:testing:nexmark:run arguments: | - -Pnexmark.runner=:runners:flink:1.15 \ - "${{ env.GRADLE_COMMAND_ARGUMENTS }}--streaming=${{ matrix.streaming }}" \ No newline at end of file + -Pnexmark.runner=:runners:flink:1.19 \ + "${{ 
env.GRADLE_COMMAND_ARGUMENTS }}--streaming=${{ matrix.streaming }}" diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml b/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml index ff3cce441069..8c5fcb1acff4 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml @@ -77,7 +77,7 @@ jobs: - name: run PostCommit Java Flink PortableValidatesRunner Streaming script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: runners:flink:1.15:job-server:validatesPortableRunnerStreaming + gradle-command: runners:flink:1.19:job-server:validatesPortableRunnerStreaming - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml b/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml index 7cc48ebd4b0e..c1a22b9c871d 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml +++ b/.github/workflows/beam_PostCommit_Java_PVR_Samza.yml @@ -100,4 +100,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml b/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml index ad10bfc684d8..76ab560f15ec 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml +++ 
b/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml b/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml index 19329026c034..cf85d9563122 100644 --- a/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml @@ -101,5 +101,5 @@ jobs: with: gradle-command: :sdks:java:testing:tpcds:run arguments: | - -Ptpcds.runner=:runners:flink:1.18 \ + -Ptpcds.runner=:runners:flink:1.19 \ "-Ptpcds.args=${{env.tpcdsBigQueryArgs}} ${{env.tpcdsInfluxDBArgs}} ${{ env.GRADLE_COMMAND_ARGUMENTS }} --queries=${{env.tpcdsQueriesArg}}" \ diff --git a/.github/workflows/beam_PostCommit_Java_Sickbay.yml b/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml similarity index 63% rename from .github/workflows/beam_PostCommit_Java_Sickbay.yml rename to .github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml index 95c36fc863cf..73fd6f0b78fa 100644 --- a/.github/workflows/beam_PostCommit_Java_Sickbay.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml @@ -15,13 +15,17 @@ # specific language governing permissions and limitations # under the License. 
-name: PostCommit Java Sickbay +name: PostCommit Java ValidatesDistrolessContainer Dataflow on: schedule: - - cron: '30 4/6 * * *' + - cron: '30 6/8 * * *' pull_request_target: - paths: ['.github/trigger_files/beam_PostCommit_Java_Sickbay.json'] + paths: + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json' + - '.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json' + workflow_dispatch: # This allows a subsequently queued workflow run to interrupt previous runs @@ -51,19 +55,19 @@ env: GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - beam_PostCommit_Java_Sickbay: + beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow: name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) runs-on: [self-hosted, ubuntu-20.04, main] - timeout-minutes: 120 + timeout-minutes: 390 strategy: - matrix: - job_name: [beam_PostCommit_Java_Sickbay] - job_phrase: [Run Java Sickbay] + matrix: + job_name: [beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow] + job_phrase: [Run Java Dataflow ValidatesDistrolessContainer] if: | github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request_target' || (github.event_name == 'schedule' && github.repository == 'apache/beam') || - github.event.comment.body == 'Run Java Sickbay' + github.event.comment.body == 'Run Java Dataflow ValidatesDistrolessContainer' steps: - uses: actions/checkout@v4 - name: Setup repository @@ -74,10 +78,28 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: run PostCommit Java Sickbay script + with: + java-version: | + 17 + 21 + - name: Setup docker + run: | + gcloud auth configure-docker us-docker.pkg.dev --quiet + gcloud auth configure-docker us.gcr.io --quiet + gcloud auth configure-docker gcr.io --quiet + gcloud auth configure-docker us-central1-docker.pkg.dev --quiet + 
- name: run validatesDistrolessContainer script (Java 17) + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :runners:google-cloud-dataflow-java:examplesJavaRunnerV2IntegrationTestDistroless + arguments: '-PtestJavaVersion=java17 -PdockerTag=$(date +%s)' + max-workers: 12 + - name: run validatesDistrolessContainer script (Java 21) uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :javaPostCommitSickbay + gradle-command: :runners:google-cloud-dataflow-java:examplesJavaRunnerV2IntegrationTestDistroless + arguments: '-PtestJavaVersion=java21 -PdockerTag=$(date +%s)' + max-workers: 12 - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} @@ -90,4 +112,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml index d66381393725..c85c0b8468dc 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml index da2ba2f88465..5963a33007e0 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml +++ 
b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_JavaVersions.yml @@ -111,4 +111,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml index edb055321c87..2e8227fb84a6 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml index 8957ce7de053..2abc081e6ae5 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml index 2a98746a0b84..fde10e0898e9 100644 --- 
a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml index 3f48bb921805..f439be9ec58e 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml index 75ebbda93f80..eb70a654c93d 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Direct_JavaVersions.yml @@ -106,4 +106,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml index 
b6334d8e9858..1442f5ffafc0 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml @@ -78,7 +78,7 @@ jobs: - name: run validatesRunner script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.18:validatesRunner + gradle-command: :runners:flink:1.19:validatesRunner - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} @@ -93,3 +93,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml index 15c99d7bfb37..0f12ce6f90ef 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml @@ -80,11 +80,11 @@ jobs: 11 - name: run jar Java8 script run: | - ./gradlew :runners:flink:1.15:jar :runners:flink:1.15:testJar + ./gradlew :runners:flink:1.19:jar :runners:flink:1.19:testJar - name: run validatesRunner Java8 script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:validatesRunner + gradle-command: :runners:flink:1.19:validatesRunner arguments: | -x shadowJar \ -x shadowTestJar \ @@ -109,4 +109,5 @@ jobs: large_files: true commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml index 794308d3a85e..edcb45303fd4 100644 --- 
a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Samza.yml @@ -96,4 +96,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml index d1f264aaac01..d05963263931 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml index 15863d4c8c9b..da04582a7caa 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml index c05284186617..8d531c120dd6 
100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml @@ -108,4 +108,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml index 522cb300c687..8310e5ed8bb2 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml @@ -90,4 +90,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml index 36fc06aea421..3b130b6d290f 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml @@ -89,4 +89,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_PortableJar_Flink.yml b/.github/workflows/beam_PostCommit_PortableJar_Flink.yml index 37bfe68d9b20..318b5104c39c 100644 --- a/.github/workflows/beam_PostCommit_PortableJar_Flink.yml +++ 
b/.github/workflows/beam_PostCommit_PortableJar_Flink.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_PortableJar_Spark.yml b/.github/workflows/beam_PostCommit_PortableJar_Spark.yml index ce7be60133d7..0712dfb255b7 100644 --- a/.github/workflows/beam_PostCommit_PortableJar_Spark.yml +++ b/.github/workflows/beam_PostCommit_PortableJar_Spark.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python.yml b/.github/workflows/beam_PostCommit_Python.yml index 4770515c75fb..93b85a318487 100644 --- a/.github/workflows/beam_PostCommit_Python.yml +++ b/.github/workflows/beam_PostCommit_Python.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Arm.yml b/.github/workflows/beam_PostCommit_Python_Arm.yml index 48fb00b1bb9d..352b95e6747a 100644 --- a/.github/workflows/beam_PostCommit_Python_Arm.yml +++ b/.github/workflows/beam_PostCommit_Python_Arm.yml @@ -124,4 +124,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff 
--git a/.github/workflows/beam_PostCommit_Python_Dependency.yml b/.github/workflows/beam_PostCommit_Python_Dependency.yml index 6e7c4ddbd3eb..80e1bbc290c9 100644 --- a/.github/workflows/beam_PostCommit_Python_Dependency.yml +++ b/.github/workflows/beam_PostCommit_Python_Dependency.yml @@ -96,3 +96,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml index 4ce3b1893215..bf8330a2ae58 100644 --- a/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Dataflow.yml @@ -94,4 +94,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml b/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml index a6bb49f4e444..e271b7da9a7b 100644 --- a/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Direct.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml b/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml index bda807eb147b..28fd13c181b3 100644 --- a/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Flink.yml @@ -55,7 +55,7 @@ 
jobs: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request_target' || startsWith(github.event.comment.body, 'Run Python Examples_Flink') - runs-on: [self-hosted, ubuntu-20.04, main] + runs-on: [self-hosted, ubuntu-20.04, highmem22] timeout-minutes: 240 name: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) strategy: @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml b/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml index d866d412507b..5df6bcf8c01c 100644 --- a/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml +++ b/.github/workflows/beam_PostCommit_Python_Examples_Spark.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml b/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml index 578775a9d3ed..0d334b679dc5 100644 --- a/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml +++ b/.github/workflows/beam_PostCommit_Python_MongoDBIO_IT.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml index bcd936324124..6e16f43476b2 
100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow.yml @@ -108,3 +108,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml index f2eba045722c..3ab7257f8a9d 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesContainer_Dataflow_With_RC.yml @@ -106,4 +106,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml new file mode 100644 index 000000000000..c294dd3c9068 --- /dev/null +++ b/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: PostCommit Python ValidatesDistrolessContainer Dataflow + +on: + schedule: + - cron: '15 5/6 * * *' + pull_request_target: + paths: + - 'release/trigger_all_tests.json' + # Since distroless is based on original sdk container images, we want to also trigger distroless checks here. + - '.github/trigger_files/beam_PostCommit_Python_ValidatesContainer_Dataflow.json' + - '.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json' + workflow_dispatch: + issue_comment: + types: [created] + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PostCommit_Python_ValidatesContainer_Dataflow: + if: | + (github.event_name == 'schedule' && github.repository == 
'apache/beam') || + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + startsWith(github.event.comment.body, 'Run Python Dataflow ValidatesDistrolessContainer') + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 100 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + strategy: + fail-fast: false + matrix: + job_name: ["beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow"] + job_phrase: ["Run Python Dataflow ValidatesDistrolessContainer"] + python_version: ['3.9','3.10','3.11','3.12'] + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} ${{ matrix.python_version }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: | + 11 + 8 + python-version: ${{ matrix.python_version }} + - name: Setup docker + run: | + gcloud auth configure-docker us-docker.pkg.dev --quiet + gcloud auth configure-docker us.gcr.io --quiet + gcloud auth configure-docker gcr.io --quiet + gcloud auth configure-docker us-central1-docker.pkg.dev --quiet + - name: Set PY_VER_CLEAN + id: set_py_ver_clean + run: | + PY_VER=${{ matrix.python_version }} + PY_VER_CLEAN=${PY_VER//.} + echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT + - name: Run validatesDistrolessContainer script + env: + USER: github-actions + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:test-suites:dataflow:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:validatesDistrolessContainer + arguments: | + -PpythonVersion=${{ matrix.python_version }} \ + - name: Archive Python Test Results + uses: actions/upload-artifact@v4 + if: failure() + with: + name: Python Test Results + path: '**/pytest*.xml' + 
- name: Publish Python Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + commit: '${{ env.prsha || env.GITHUB_SHA }}' + comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} + files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml index 1876950c7a93..f8daa1a96634 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Dataflow.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml index f837c7476e12..9277bd68fc01 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Flink.yml @@ -103,4 +103,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml index 91c249adf338..e058724cd2ac 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Samza.yml @@ -102,4 +102,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 
'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml index 7e87aaff22cc..a47f758ed410 100644 --- a/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesRunner_Spark.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml index b3f37c6b39f0..bd266cf6fdab 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml @@ -93,4 +93,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml index 137d7bc13d2f..6d26d1c46012 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git 
a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml index 8fc0db189078..08e99fa0fe0f 100644 --- a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml @@ -95,4 +95,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml new file mode 100644 index 000000000000..a7643c795af4 --- /dev/null +++ b/.github/workflows/beam_PostCommit_Python_Xlang_IO_Direct.yml @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: PostCommit Python Xlang IO Direct + +on: + schedule: + - cron: '30 5/6 * * *' + pull_request_target: + paths: ['release/trigger_all_tests.json', '.github/trigger_files/beam_PostCommit_Python_Xlang_IO_Direct.json'] + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PostCommit_Python_Xlang_IO_Direct: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event.comment.body == 'Run Python_Xlang_IO_Direct PostCommit' + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 100 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_PostCommit_Python_Xlang_IO_Direct"] + job_phrase: ["Run Python_Xlang_IO_Direct PostCommit"] + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - 
name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + python-version: | + 3.9 + 3.12 + - name: run PostCommit Python Xlang IO Direct script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:test-suites:direct:ioCrossLanguagePostCommit + arguments: -PuseWheelDistribution + - name: Archive Python Test Results + uses: actions/upload-artifact@v4 + if: failure() + with: + name: Python Test Results + path: '**/pytest*.xml' + - name: Publish Python Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + commit: '${{ env.prsha || env.GITHUB_SHA }}' + comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_SQL.yml b/.github/workflows/beam_PostCommit_SQL.yml index c7d0b6dc98b9..aebea2b0564b 100644 --- a/.github/workflows/beam_PostCommit_SQL.yml +++ b/.github/workflows/beam_PostCommit_SQL.yml @@ -91,3 +91,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_TransformService_Direct.yml b/.github/workflows/beam_PostCommit_TransformService_Direct.yml index cb339eb9fb40..d0d72f3df13c 100644 --- a/.github/workflows/beam_PostCommit_TransformService_Direct.yml +++ b/.github/workflows/beam_PostCommit_TransformService_Direct.yml @@ -98,4 +98,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Direct.yml b/.github/workflows/beam_PostCommit_XVR_Direct.yml index 
023ae4f8cd31..af8b7fb1bf54 100644 --- a/.github/workflows/beam_PostCommit_XVR_Direct.yml +++ b/.github/workflows/beam_PostCommit_XVR_Direct.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Flink.yml b/.github/workflows/beam_PostCommit_XVR_Flink.yml index 5cde38d24244..fe4404247448 100644 --- a/.github/workflows/beam_PostCommit_XVR_Flink.yml +++ b/.github/workflows/beam_PostCommit_XVR_Flink.yml @@ -47,7 +47,7 @@ env: DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} - FlinkVersion: 1.15 + FlinkVersion: 1.19 jobs: beam_PostCommit_XVR_Flink: @@ -110,4 +110,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml index 228f10b90cd0..0620023ce7d2 100644 --- a/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_GoUsingJava_Dataflow.yml @@ -102,3 +102,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml index 
66770c9a1683..11a8a5c5f4f7 100644 --- a/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_JavaUsingPython_Dataflow.yml @@ -95,4 +95,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml index bfb602f89daf..c393a4113589 100644 --- a/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.yml @@ -92,4 +92,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml b/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml index f1269a0ddd09..082aeb3f2ab2 100644 --- a/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml +++ b/.github/workflows/beam_PostCommit_XVR_PythonUsingJava_Dataflow.yml @@ -95,4 +95,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Samza.yml b/.github/workflows/beam_PostCommit_XVR_Samza.yml index 2d26c9131839..7e2dca61d41d 100644 --- a/.github/workflows/beam_PostCommit_XVR_Samza.yml +++ b/.github/workflows/beam_PostCommit_XVR_Samza.yml 
@@ -111,4 +111,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_XVR_Spark3.yml b/.github/workflows/beam_PostCommit_XVR_Spark3.yml index c1880e01292b..17fb58d9dd73 100644 --- a/.github/workflows/beam_PostCommit_XVR_Spark3.yml +++ b/.github/workflows/beam_PostCommit_XVR_Spark3.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PostRelease_NightlySnapshot.yml b/.github/workflows/beam_PostRelease_NightlySnapshot.yml index 9b7fb2af2f2d..0c4144af1a4f 100644 --- a/.github/workflows/beam_PostRelease_NightlySnapshot.yml +++ b/.github/workflows/beam_PostRelease_NightlySnapshot.yml @@ -59,6 +59,9 @@ jobs: uses: ./.github/actions/setup-environment-action with: java-version: default + - name: Setup temp local maven + id: setup_local_maven + run: echo "NEW_TEMP_DIR=$(mktemp -d)" >> $GITHUB_OUTPUT - name: run PostRelease validation script uses: ./.github/actions/gradle-command-self-hosted-action with: @@ -66,3 +69,7 @@ jobs: arguments: | -Pver='${{ github.event.inputs.RELEASE }}' \ -Prepourl='${{ github.event.inputs.SNAPSHOT_URL }}' \ + -PmavenLocalPath='${{ steps.setup_local_maven.outputs.NEW_TEMP_DIR }}' + - name: Clean up local maven + if: steps.setup_local_maven.outcome == 'success' + run: rm -rf '${{ steps.setup_local_maven.outputs.NEW_TEMP_DIR }}' diff --git a/.github/workflows/beam_PreCommit_Flink_Container.yml b/.github/workflows/beam_PreCommit_Flink_Container.yml new file mode 
100644 index 000000000000..519b0273420a --- /dev/null +++ b/.github/workflows/beam_PreCommit_Flink_Container.yml @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: PreCommit Flink Container + +on: + pull_request_target: + paths: + - 'model/**' + - 'sdks/python/**' + - 'release/**' + - 'sdks/java/io/kafka/**' + - 'runners/core-construction-java/**' + - 'runners/core-java/**' + - 'runners/extensions-java/**' + - 'runners/flink/**' + - 'runners/java-fn-execution/**' + - 'runners/reference/**' + - '.github/trigger_files/beam_PreCommit_Flink_Container.json' + - 'release/trigger_all_tests.json' + push: + branches: ['master', 'release-*'] + tags: 'v*' + schedule: + - cron: '0 */6 * * *' + workflow_dispatch: + +# Setting explicit permissions for the action to avoid the default permissions which are `write-all` +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || 
github.sha || github.head_ref || github.ref }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + INFLUXDB_USER: ${{ secrets.INFLUXDB_USER }} + INFLUXDB_USER_PASSWORD: ${{ secrets.INFLUXDB_USER_PASSWORD }} + GCLOUD_ZONE: us-central1-a + CLUSTER_NAME: beam-precommit-flink-container-${{ github.run_id }} + GCS_BUCKET: gs://beam-flink-cluster + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz + HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar + FLINK_TASKMANAGER_SLOTS: 1 + DETACHED_MODE: true + HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest + JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest + ARTIFACTS_DIR: gs://beam-flink-cluster/beam-precommit-flink-container-${{ github.run_id }} + +jobs: + beam_PreCommit_Flink_Container: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'push' || + github.event_name == 'schedule' || + github.event_name == 'pull_request_target' || + github.event.comment.body == 'Run Flink Container PreCommit' + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 45 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_PreCommit_Flink_Container"] + job_phrase: ["Run Flink Container PreCommit"] + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + python-version: 
default + - name: Prepare test arguments + uses: ./.github/actions/test-arguments-action + with: + test-type: precommit + test-language: go,python,java + argument-file-paths: | + ${{ github.workspace }}/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt + ${{ github.workspace }}/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt + ${{ github.workspace }}/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt + - name: get current time + run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV + - name: Start Flink with 2 workers + env: + FLINK_NUM_WORKERS: 2 + run: | + cd ${{ github.workspace }}/.test-infra/dataproc; ./flink_cluster.sh create + # Run a simple Go Combine load test to verify the Flink container + - name: Run Flink Container Test with Go Combine + timeout-minutes: 10 + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:go:test:load:run + arguments: | + -PloadTest.mainClass=combine \ + -Prunner=PortableRunner \ + '-PloadTest.args=${{ env.beam_PreCommit_Flink_Container_test_arguments_1 }} --job_name=flink-tests-go-${{env.NOW_UTC}}' + + # Run a Python Combine load test to verify the Flink container + - name: Run Flink Container Test with Python Combine + timeout-minutes: 20 + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.load_tests.combine_test \ + -Prunner=FlinkRunner \ + '-PloadTest.args=${{ env.beam_PreCommit_Flink_Container_test_arguments_2 }} --job_name=flink-tests-python-${{env.NOW_UTC}}' + + # Run a Java Combine load test to verify the Flink container + - name: Run Flink Container Test with Java Combine + timeout-minutes: 10 + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:java:testing:load-tests:run + arguments: | + 
-PloadTest.mainClass=org.apache.beam.sdk.loadtests.CombineLoadTest \ + -Prunner=:runners:flink:1.17 \ + '-PloadTest.args=${{ env.beam_PreCommit_Flink_Container_test_arguments_3 }} --jobName=flink-tests-java11-${{env.NOW_UTC}}' + + - name: Teardown Flink + if: always() + run: | + ${{ github.workspace }}/.test-infra/dataproc/flink_cluster.sh delete diff --git a/.github/workflows/beam_PreCommit_ItFramework.yml b/.github/workflows/beam_PreCommit_ItFramework.yml index e078d4645757..e803fc023c67 100644 --- a/.github/workflows/beam_PreCommit_ItFramework.yml +++ b/.github/workflows/beam_PreCommit_ItFramework.yml @@ -101,4 +101,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java.yml b/.github/workflows/beam_PreCommit_Java.yml index 772eab98c343..bc25fb94f8f0 100644 --- a/.github/workflows/beam_PreCommit_Java.yml +++ b/.github/workflows/beam_PreCommit_Java.yml @@ -19,6 +19,7 @@ on: tags: ['v*'] branches: ['master', 'release-*'] paths: + - "buildSrc/**" - 'model/**' - 'sdks/java/**' - 'runners/**' @@ -197,6 +198,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml index ecbc85ca1b1d..cf0d0b660782 100644 --- a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services2_IO_Direct.yml @@ -130,6 +130,7 @@ jobs: commit: '${{ env.prsha 
|| env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml index 55935251e6d9..9053bb730371 100644 --- a/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Amazon-Web-Services_IO_Direct.yml @@ -130,6 +130,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml index 4fbacecde4a4..8c0bb07e1acb 100644 --- a/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Azure_IO_Direct.yml @@ -123,6 +123,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml index e37bc5c56e2e..317b2e1f2ec1 100644 --- a/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Cassandra_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - 
name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml index 68ebe3c28fb3..3e0208b758cc 100644 --- a/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Cdap_IO_Direct.yml @@ -109,6 +109,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml index 5c0b169b0ba1..2be7607b5bc7 100644 --- a/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Clickhouse_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml index ce91551c1121..6901e56c0bbb 100644 --- a/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Csv_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml index 
b6a0e6b999bd..6f32c3844b1a 100644 --- a/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Debezium_IO_Direct.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml index 78ab882d4774..11a95cf476c7 100644 --- a/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_ElasticSearch_IO_Direct.yml @@ -117,6 +117,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml index 4bfb20a28e7c..8e22318bdcb9 100644 --- a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml +++ b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow.yml @@ -117,4 +117,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml index 72fc945018f6..763de153b137 100644 --- a/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml +++ 
b/.github/workflows/beam_PreCommit_Java_Examples_Dataflow_Java21.yml @@ -133,6 +133,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml index e96dc7c883bf..e121fe1e53a2 100644 --- a/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_File-schema-transform_IO_Direct.yml @@ -106,6 +106,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml b/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml index 19b0d56a8051..09bf906e5a38 100644 --- a/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml +++ b/.github/workflows/beam_PreCommit_Java_Flink_Versions.yml @@ -104,4 +104,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml index 2256d0a91cb8..ee5bea3d3ab3 100644 --- a/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml @@ -127,6 +127,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ 
github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml index c481251bef03..0e6bd11e7f1e 100644 --- a/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Google-ads_IO_Direct.yml @@ -103,6 +103,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml index 3b99e30bfbac..c334edd7f32d 100644 --- a/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_HBase_IO_Direct.yml @@ -107,6 +107,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml index 6d45ba82aa49..ed079c1e9dd1 100644 --- a/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_HCatalog_IO_Direct.yml @@ -122,6 +122,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() 
diff --git a/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml index c2beaa3c1099..442085586a3c 100644 --- a/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Hadoop_IO_Direct.yml @@ -145,6 +145,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml b/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml index 4e19a56dde0c..cd73d402c7ea 100644 --- a/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_IOs_Direct.yml @@ -122,6 +122,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml index 903a7cd73526..977781de506f 100644 --- a/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_InfluxDb_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml index 071cdb3bda3e..4759d48d979f 100644 --- a/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml +++ 
b/.github/workflows/beam_PreCommit_Java_JDBC_IO_Direct.yml @@ -112,6 +112,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml index 650036345274..935315463358 100644 --- a/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Jms_IO_Direct.yml @@ -112,6 +112,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml index 0ede01376ce7..f177ec85fada 100644 --- a/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Kafka_IO_Direct.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml index 494a738abf45..785748e793e9 100644 --- a/.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Kinesis_IO_Direct.yml @@ -137,6 +137,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: 
'**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml index e38c9a761dee..853e52db14db 100644 --- a/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Kudu_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml index 11be57c05759..b3292ac5f29b 100644 --- a/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_MongoDb_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml index ac8800f55cdf..ed0189d8006b 100644 --- a/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Mqtt_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml 
b/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml index 553300f1889c..62429a611f2a 100644 --- a/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Neo4j_IO_Direct.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml index 7019e4799ace..b459c4625547 100644 --- a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml +++ b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml @@ -94,7 +94,7 @@ jobs: - name: run validatesPortableRunnerBatch script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:job-server:validatesPortableRunnerBatch + gradle-command: :runners:flink:1.19:job-server:validatesPortableRunnerBatch env: CLOUDSDK_CONFIG: ${{ env.KUBELET_GCLOUD_CONFIG_PATH }} - name: Archive JUnit Test Results @@ -108,4 +108,4 @@ jobs: with: name: java-code-coverage-report path: "**/build/test-results/**/*.xml" -# TODO: Investigate 'Max retries exceeded' issue with EnricoMi/publish-unit-test-result-action@v2. \ No newline at end of file +# TODO: Investigate 'Max retries exceeded' issue with EnricoMi/publish-unit-test-result-action@v2. 
diff --git a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml index 3c16b57c9bce..48f165f4e59f 100644 --- a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml +++ b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml @@ -99,7 +99,7 @@ jobs: - name: run PreCommit Java PVR Flink Docker script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:job-server:validatesPortableRunnerDocker + gradle-command: :runners:flink:1.19:job-server:validatesPortableRunnerDocker env: CLOUDSDK_CONFIG: ${{ env.KUBELET_GCLOUD_CONFIG_PATH}} - name: Archive JUnit Test Results @@ -114,4 +114,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/.github/workflows/beam_PostCommit_Sickbay_Python.yml b/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml similarity index 53% rename from .github/workflows/beam_PostCommit_Sickbay_Python.yml rename to .github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml index 6d253e03723d..ea5cf9b5578e 100644 --- a/.github/workflows/beam_PostCommit_Sickbay_Python.yml +++ b/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml @@ -15,16 +15,40 @@ # specific language governing permissions and limitations # under the License. 
-name: PostCommit Sickbay Python +name: PreCommit Java PVR Prism Loopback on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/go/cmd/prism/**' + - 'runners/prism/**' + - 'runners/java-fn-execution/**' + - 'sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/**' + - '.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml' pull_request_target: - paths: ['.github/trigger_files/beam_PostCommit_Sickbay_Python.json'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/go/cmd/prism/**' + - 'runners/prism/**' + - 'runners/java-fn-execution/**' + - 'sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Java_PVR_Prism_Loopback.json' + issue_comment: + types: [created] + schedule: + - cron: '22 2/6 * * *' workflow_dispatch: # This allows a subsequently queued workflow run to interrupt previous runs concurrency: - group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' cancel-in-progress: true #Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event @@ -49,57 +73,42 @@ env: GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - beam_PostCommit_Sickbay_Python: - name: ${{ matrix.job_name }} (${{ matrix.job_phrase_1 }} ${{ matrix.python_version }} ${{ matrix.job_phrase_2 }}) - runs-on: [self-hosted, ubuntu-20.04, main] - timeout-minutes: 180 + 
beam_PreCommit_Java_PVR_Prism_Loopback: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) strategy: - fail-fast: false matrix: - job_name: [beam_PostCommit_Sickbay_Python] - job_phrase_1: [Run Python] - job_phrase_2: [PostCommit Sickbay] - python_version: ['3.9', '3.10', '3.11', '3.12'] + job_name: ["beam_PreCommit_Java_PVR_Prism_Loopback"] + job_phrase: ["Run Java_PVR_Prism_Loopback PreCommit"] + timeout-minutes: 240 + runs-on: [self-hosted, ubuntu-20.04] if: | - github.event_name == 'workflow_dispatch' || + github.event_name == 'push' || github.event_name == 'pull_request_target' || (github.event_name == 'schedule' && github.repository == 'apache/beam') || - (startswith(github.event.comment.body, 'Run Python') && - endswith(github.event.comment.body, 'PostCommit Sickbay')) + github.event_name == 'workflow_dispatch' || + github.event.comment.body == 'Run Java_PVR_Prism_Loopback PreCommit' steps: - uses: actions/checkout@v4 - name: Setup repository uses: ./.github/actions/setup-action with: - comment_phrase: ${{ matrix.job_phrase_1 }} ${{ matrix.python_version }} ${{ matrix.job_phrase_2 }} + comment_phrase: ${{ matrix.job_phrase }} github_token: ${{ secrets.GITHUB_TOKEN }} - github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase_1 }} ${{ matrix.python_version }} ${{ matrix.job_phrase_2 }}) + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - with: - python-version: ${{ matrix.python_version }} - - name: Set PY_VER_CLEAN - id: set_py_ver_clean - run: | - PY_VER=${{ matrix.python_version }} - PY_VER_CLEAN=${PY_VER//.} - echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT - - name: run PostCommit Python ${{ matrix.python_version }} script + - name: run prismLoopbackValidatesRunnerTests script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: 
:sdks:python:test-suites:dataflow:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:postCommitSickbay - arguments: | - -PpythonVersion=${{ matrix.python_version }} \ - - name: Archive Python Test Results + gradle-command: :runners:prism:java:prismLoopbackValidatesRunnerTests + - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 - if: failure() + if: ${{ !success() }} with: - name: Python Test Results - path: '**/pytest*.xml' - - name: Publish Python Test Results - uses: EnricoMi/publish-unit-test-result-action@v2 - if: always() + name: JUnit Test Results + path: "**/build/reports/tests/" + - name: Upload test report + uses: actions/upload-artifact@v4 with: - commit: '${{ env.prsha || env.GITHUB_SHA }}' - comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' + name: java-code-coverage-report + path: "**/build/test-results/**/*.xml" diff --git a/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml index 0bec073fc37b..d217f0e88c39 100644 --- a/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Parquet_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml index e25b4ff6fa94..3a9d62fb64c6 100644 --- a/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Pulsar_IO_Direct.yml @@ -123,6 +123,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' 
+ large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml index eb343f193395..c72b04bc108d 100644 --- a/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_RabbitMq_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml index 13b9c26b4b81..cd4ddc387ffc 100644 --- a/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Redis_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml index f1e8c3699aa6..1037ab972447 100644 --- a/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_RequestResponse_IO_Direct.yml @@ -103,6 +103,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml 
b/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml index 4d289882353e..478dad9989b9 100644 --- a/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_SingleStore_IO_Direct.yml @@ -107,6 +107,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml index 03577eff2860..403c26ac0ab0 100644 --- a/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml @@ -116,6 +116,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml index 5aeaaec11dec..ca05b44875cb 100644 --- a/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Solace_IO_Direct.yml @@ -112,6 +112,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml index e6138a0c10d9..80cd5e492992 100644 --- a/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml +++ 
b/.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml b/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml index 18f5a6c0c86e..c6b2d7e57128 100644 --- a/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml +++ b/.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml @@ -112,4 +112,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml index 73a1a0b5cdb2..53f3c4327739 100644 --- a/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml index 4cddfa728cc1..b5336537c556 100644 --- a/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 
'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml index e08b5048b359..195e9aa1f168 100644 --- a/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml @@ -105,6 +105,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Prism_Python.yml b/.github/workflows/beam_PreCommit_Prism_Python.yml new file mode 100644 index 000000000000..5eb26d139ef5 --- /dev/null +++ b/.github/workflows/beam_PreCommit_Prism_Python.yml @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: PreCommit Prism Python + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - '.github/workflows/beam_PreCommit_Prism_Python.yml' + pull_request_target: + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Prism_Python.json' + issue_comment: + types: [created] + schedule: + - cron: '30 2/6 * * *' + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PreCommit_Prism_Python: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + timeout-minutes: 120 + runs-on: ['self-hosted', ubuntu-20.04, main] + strategy: + fail-fast: false + matrix: + job_name: ['beam_PreCommit_Prism_Python'] + job_phrase: ['Run Prism_Python PreCommit'] + python_version: ['3.9', '3.12'] + if: | + github.event_name == 'push' || github.event_name == 'workflow_dispatch' || + github.event_name ==
'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + startsWith(github.event.comment.body, 'Run Prism_Python PreCommit') + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} ${{ matrix.python_version }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: default + python-version: | + ${{ matrix.python_version }} + 3.9 + - name: Set PY_VER_CLEAN + id: set_py_ver_clean + run: | + PY_VER=${{ matrix.python_version }} + PY_VER_CLEAN=${PY_VER//.} + echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT + - name: run Prism Python Validates Runner script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:test-suites:portable:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:prismValidatesRunner \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python.yml b/.github/workflows/beam_PreCommit_Python.yml index fb1c6c80873a..68c69ae953a4 100644 --- a/.github/workflows/beam_PreCommit_Python.yml +++ b/.github/workflows/beam_PreCommit_Python.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Coverage.yml b/.github/workflows/beam_PreCommit_Python_Coverage.yml index 0e295250817d..3c7c3b05d8bc 100644 --- a/.github/workflows/beam_PreCommit_Python_Coverage.yml +++ b/.github/workflows/beam_PreCommit_Python_Coverage.yml @@ -104,4 +104,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' 
comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Dataframes.yml b/.github/workflows/beam_PreCommit_Python_Dataframes.yml index f045842e061d..ecbb1a30e5f7 100644 --- a/.github/workflows/beam_PreCommit_Python_Dataframes.yml +++ b/.github/workflows/beam_PreCommit_Python_Dataframes.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Examples.yml b/.github/workflows/beam_PreCommit_Python_Examples.yml index 09d46217d6d6..44329f63014d 100644 --- a/.github/workflows/beam_PreCommit_Python_Examples.yml +++ b/.github/workflows/beam_PreCommit_Python_Examples.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Integration.yml b/.github/workflows/beam_PreCommit_Python_Integration.yml index 20aade431f6d..3a709c70f077 100644 --- a/.github/workflows/beam_PreCommit_Python_Integration.yml +++ b/.github/workflows/beam_PreCommit_Python_Integration.yml @@ -116,4 +116,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_ML.yml 
b/.github/workflows/beam_PreCommit_Python_ML.yml index 714eceef5f6b..3b3a2150ac28 100644 --- a/.github/workflows/beam_PreCommit_Python_ML.yml +++ b/.github/workflows/beam_PreCommit_Python_ML.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml b/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml index dbc1264fcc04..5dd12d49ccd9 100644 --- a/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml +++ b/.github/workflows/beam_PreCommit_Python_PVR_Flink.yml @@ -125,4 +125,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Runners.yml b/.github/workflows/beam_PreCommit_Python_Runners.yml index 5db6e94be781..689d9b2c3c3f 100644 --- a/.github/workflows/beam_PreCommit_Python_Runners.yml +++ b/.github/workflows/beam_PreCommit_Python_Runners.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_Python_Transforms.yml b/.github/workflows/beam_PreCommit_Python_Transforms.yml index 820ca3e26df6..431b82c02fb7 100644 --- a/.github/workflows/beam_PreCommit_Python_Transforms.yml +++ b/.github/workflows/beam_PreCommit_Python_Transforms.yml @@ -109,4 +109,5 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ 
github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/pytest*.xml' \ No newline at end of file + files: '**/pytest*.xml' + large_files: true \ No newline at end of file diff --git a/.github/workflows/beam_PreCommit_SQL.yml b/.github/workflows/beam_PreCommit_SQL.yml index b4002fcc2a79..5bc8bb581955 100644 --- a/.github/workflows/beam_PreCommit_SQL.yml +++ b/.github/workflows/beam_PreCommit_SQL.yml @@ -103,6 +103,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_SQL_Java17.yml b/.github/workflows/beam_PreCommit_SQL_Java17.yml index 0e5dcc87d16f..1cfd7502389d 100644 --- a/.github/workflows/beam_PreCommit_SQL_Java17.yml +++ b/.github/workflows/beam_PreCommit_SQL_Java17.yml @@ -110,6 +110,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_SQL_Java8.yml b/.github/workflows/beam_PreCommit_SQL_Java8.yml index 23938821b2e8..6b59739dd72d 100644 --- a/.github/workflows/beam_PreCommit_SQL_Java8.yml +++ b/.github/workflows/beam_PreCommit_SQL_Java8.yml @@ -114,6 +114,7 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/build/test-results/**/*.xml' + large_files: true - name: Archive SpotBugs Results uses: actions/upload-artifact@v4 if: always() diff --git a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml index b17913946a7e..b9e310a7a133 100644 --- 
a/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml +++ b/.github/workflows/beam_PreCommit_Yaml_Xlang_Direct.yml @@ -105,3 +105,4 @@ jobs: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} files: '**/pytest*.xml' + large_files: true diff --git a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml index 61ef31a00239..e3791119be90 100644 --- a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml +++ b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml @@ -66,7 +66,6 @@ jobs: - "java:container:java11" - "java:container:java17" - "java:container:java21" - - "python:container:py38" - "python:container:py39" - "python:container:py310" - "python:container:py311" @@ -80,6 +79,13 @@ jobs: comment_phrase: ${{ matrix.job_phrase }} github_token: ${{ secrets.GITHUB_TOKEN }} github_job: ${{ matrix.job_name }} (${{ matrix.container_task }}) + - name: Find Beam Version + # We extract the Beam version here and tag the containers with it. Version will be in the form "2.xx.y.dev". + # This is needed to run pipelines that use the default environment at HEAD, for example, when a + # pipeline uses an expansion service built from HEAD. 
+ run: | + BEAM_VERSION_LINE=$(cat gradle.properties | grep "sdk_version") + echo "BEAM_VERSION=${BEAM_VERSION_LINE#*sdk_version=}" >> $GITHUB_ENV - name: Set up Docker Buildx uses: docker/setup-buildx-action@v1 - name: GCloud Docker credential helper @@ -102,6 +108,6 @@ jobs: arguments: | -Pjava11Home=$JAVA_HOME_11_X64 \ -Pdocker-repository-root=gcr.io/apache-beam-testing/beam-sdk \ - -Pdocker-tag-list=${{ github.sha }},latest \ + -Pdocker-tag-list=${{ github.sha }},${BEAM_VERSION},latest \ -Pcontainer-architecture-list=arm64,amd64 \ -Ppush-containers \ diff --git a/.github/workflows/beam_Publish_Docker_Snapshots.yml b/.github/workflows/beam_Publish_Docker_Snapshots.yml index 334fa537be56..e37a202267c4 100644 --- a/.github/workflows/beam_Publish_Docker_Snapshots.yml +++ b/.github/workflows/beam_Publish_Docker_Snapshots.yml @@ -83,7 +83,7 @@ jobs: - name: run Publish Docker Snapshots script for Flink uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:job-server-container:dockerPush + gradle-command: :runners:flink:1.17:job-server-container:dockerPush arguments: | -Pdocker-repository-root=gcr.io/apache-beam-testing/beam_portability \ -Pdocker-tag-list=latest \ No newline at end of file diff --git a/.github/workflows/beam_Wordcount_Python_Cost_Benchmark_Dataflow.yml b/.github/workflows/beam_Wordcount_Python_Cost_Benchmark_Dataflow.yml new file mode 100644 index 000000000000..51d1005affbc --- /dev/null +++ b/.github/workflows/beam_Wordcount_Python_Cost_Benchmark_Dataflow.yml @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Wordcount Python Cost Benchmarks Dataflow + +on: + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + INFLUXDB_USER: ${{ secrets.INFLUXDB_USER }} + INFLUXDB_USER_PASSWORD: ${{ secrets.INFLUXDB_USER_PASSWORD }} + +jobs: + beam_Inference_Python_Benchmarks_Dataflow: + if: | + github.event_name == 'workflow_dispatch' + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 900 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_Wordcount_Python_Cost_Benchmarks_Dataflow"] + job_phrase: ["Run Wordcount Cost Benchmark"] + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + 
comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup Python environment + uses: ./.github/actions/setup-environment-action + with: + python-version: '3.10' + - name: Prepare test arguments + uses: ./.github/actions/test-arguments-action + with: + test-type: load + test-language: python + argument-file-paths: | + ${{ github.workspace }}/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt + # The env variables are created and populated in the test-arguments-action as "_test_arguments_" + - name: get current time + run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV + - name: run wordcount on Dataflow Python + uses: ./.github/actions/gradle-command-self-hosted-action + timeout-minutes: 30 + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.benchmarks.wordcount.wordcount \ + -Prunner=DataflowRunner \ + -PpythonVersion=3.10 \ + '-PloadTest.args=${{ env.beam_Inference_Python_Benchmarks_Dataflow_test_arguments_1 }} --job_name=benchmark-tests-wordcount-python-${{env.NOW_UTC}} --output=gs://temp-storage-for-end-to-end-tests/wordcount/result_wordcount-${{env.NOW_UTC}}.txt' \ \ No newline at end of file diff --git a/.github/workflows/build_release_candidate.yml b/.github/workflows/build_release_candidate.yml index fdbae21336e5..54235a71c910 100644 --- a/.github/workflows/build_release_candidate.yml +++ b/.github/workflows/build_release_candidate.yml @@ -97,7 +97,7 @@ jobs: stage_java_source: if: ${{ fromJson(github.event.inputs.STAGE).java_source == 'yes'}} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Mask Apache Password run: | @@ -161,7 +161,7 @@ jobs: stage_python_artifacts: if: ${{ fromJson(github.event.inputs.STAGE).python_artifacts == 'yes'}} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: 
actions/checkout@v4 @@ -245,8 +245,7 @@ jobs: stage_docker: if: ${{ fromJson(github.event.inputs.STAGE).docker_artifacts == 'yes'}} - # Note: if this ever changes to self-hosted, remove the "Remove default github maven configuration" step - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, highmem] steps: - name: Checkout uses: actions/checkout@v4 @@ -282,7 +281,7 @@ jobs: beam_site_pr: if: ${{ fromJson(github.event.inputs.STAGE).beam_site_pr == 'yes'}} # Note: if this ever changes to self-hosted, remove the "Remove default github maven configuration" step - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: RC_TAG: "v${{ github.event.inputs.RELEASE }}-RC${{ github.event.inputs.RC }}" BRANCH_NAME: updates_release_${{ github.event.inputs.RELEASE }} @@ -310,11 +309,12 @@ jobs: uses: actions/setup-node@v4 with: node-version: '16' - - name: Install Java 11 + # TODO(https://github.com/apache/beam/issues/32726) switch to Java11 + - name: Install Java 8 uses: actions/setup-java@v4 with: distribution: 'temurin' - java-version: '11' + java-version: '8' - name: Remove default github maven configuration # This step is a workaround to avoid a decryption issue of Beam's # net.linguica.gradle.maven.settings plugin and github's provided maven @@ -401,7 +401,7 @@ jobs: build_and_stage_prism: if: ${{ fromJson(github.event.inputs.STAGE).prism == 'yes'}} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index d1e99f2bd579..828a6328c0cd 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -49,7 +49,7 @@ jobs: env: EVENT_NAME: ${{ github.event_name }} # Keep in sync with py_version matrix value below - if changed, change that as well. 
- PY_VERSIONS_FULL: "cp38-* cp39-* cp310-* cp311-* cp312-*" + PY_VERSIONS_FULL: "cp39-* cp310-* cp311-* cp312-*" outputs: gcp-variables-set: ${{ steps.check_gcp_variables.outputs.gcp-variables-set }} py-versions-full: ${{ steps.set-py-versions.outputs.py-versions-full }} @@ -202,7 +202,6 @@ jobs: if: needs.check_env_variables.outputs.gcp-variables-set == 'true' steps: - name: Download compressed sources from artifacts - # Pinned to v3 because of https://github.com/actions/download-artifact/issues/249 uses: actions/download-artifact@v4.1.8 with: name: source_zip @@ -229,18 +228,16 @@ jobs: {"os": "windows-latest", "runner": "windows-latest", "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "auto" }, {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "aarch64" } ] - # Keep in sync with PY_VERSIONS_FULL env var abvove - if changed, change that as well. - py_version: ["cp38-*", "cp39-*", "cp310-*", "cp311-*", "cp312-*"] + # Keep in sync (remove asterisks) with PY_VERSIONS_FULL env var above - if changed, change that as well. 
+ py_version: ["cp39-", "cp310-", "cp311-", "cp312-"] steps: - name: Download python source distribution from artifacts - # Pinned to v3 because of https://github.com/actions/download-artifact/issues/249 uses: actions/download-artifact@v4.1.8 with: name: source path: apache-beam-source - name: Download Python SDK RC source distribution from artifacts if: ${{ needs.build_source.outputs.is_rc == 1 }} - # Pinned to v3 because of https://github.com/actions/download-artifact/issues/249 uses: actions/download-artifact@v4.1.8 with: name: source_rc${{ needs.build_source.outputs.rc_num }} @@ -260,7 +257,7 @@ jobs: if: ${{ contains(matrix.os_python.python, matrix.py_version) }} working-directory: apache-beam-source env: - CIBW_BUILD: ${{ matrix.py_version }} + CIBW_BUILD: ${{ matrix.py_version }}* # TODO: https://github.com/apache/beam/issues/23048 CIBW_SKIP: "*-musllinux_*" CIBW_BEFORE_BUILD: pip install cython==0.29.36 numpy --config-settings=setup-args="-Dallow-noblas=true" && pip install --upgrade setuptools @@ -279,17 +276,16 @@ jobs: shell: bash - name: Upload wheels as artifacts if: ${{ contains(matrix.os_python.python, matrix.py_version) }} - # Pinned to v3 because of https://github.com/actions/upload-artifact?tab=readme-ov-file#breaking-changes - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: - name: wheelhouse-${{ matrix.os_python.os }}${{ (matrix.os_python.arch == 'aarch64' && '-aarch64') || '' }} + name: wheelhouse-${{ matrix.py_version }}${{ matrix.os_python.os }}${{ (matrix.os_python.arch == 'aarch64' && '-aarch64') || '' }} path: apache-beam-source/wheelhouse/ - name: Build RC wheels # Only build wheel if it is one of the target versions for this platform, otherwise no-op if: ${{ needs.build_source.outputs.is_rc == 1 && contains(matrix.os_python.python, matrix.py_version) }} working-directory: apache-beam-source-rc env: - CIBW_BUILD: ${{ matrix.py_version }} + CIBW_BUILD: ${{ matrix.py_version }}* # TODO: 
https://github.com/apache/beam/issues/23048 CIBW_SKIP: "*-musllinux_*" CIBW_BEFORE_BUILD: pip install cython==0.29.36 numpy --config-settings=setup-args="-Dallow-noblas=true" && pip install --upgrade setuptools @@ -305,10 +301,9 @@ jobs: shell: bash - name: Upload RC wheels as artifacts if: ${{ needs.build_source.outputs.is_rc == 1 }} - # Pinned to v3 because of https://github.com/actions/download-artifact/issues/249 uses: actions/upload-artifact@v4 with: - name: wheelhouse-rc${{ needs.build_source.outputs.rc_num }}-${{ matrix.os_python.os }}${{ (matrix.arch == 'aarch64' && '-aarch64') || '' }} + name: wheelhouse-rc${{ needs.build_source.outputs.rc_num }}-${{ matrix.py_version }}${{ matrix.os_python.os }}${{ (matrix.os_python.arch == 'aarch64' && '-aarch64') || '' }} path: apache-beam-source-rc/wheelhouse/ upload_wheels_to_gcs: @@ -318,21 +313,12 @@ jobs: - check_env_variables runs-on: [self-hosted, ubuntu-20.04, main] if: needs.check_env_variables.outputs.gcp-variables-set == 'true' && github.event_name != 'pull_request' - strategy: - matrix: - # Temporarily pin to macos-13 because macos-latest breaks this build - # TODO(https://github.com/apache/beam/issues/31114) - os : [ubuntu-20.04, macos-13, windows-latest] - arch: [auto] - include: - - os: ubuntu-20.04 - arch: aarch64 steps: - name: Download wheels from artifacts - # Pinned to v3 because of https://github.com/actions/upload-artifact?tab=readme-ov-file#breaking-changes - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: - name: wheelhouse-${{ matrix.os }}${{ (matrix.arch == 'aarch64' && '-aarch64') || '' }} + pattern: wheelhouse-* + merge-multiple: true path: wheelhouse/ - name: Copy wheels to GCS bucket run: gsutil cp -r -a public-read wheelhouse/* ${{ env.GCP_PATH }} diff --git a/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt b/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt new file mode 100644 index 000000000000..424936ddad97 --- 
/dev/null +++ b/.github/workflows/cost-benchmarks-pipeline-options/python_wordcount.txt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--region=us-central1 +--machine_type=n1-standard-2 +--num_workers=1 +--disk_size_gb=50 +--autoscaling_algorithm=NONE +--input_options={} +--staging_location=gs://temp-storage-for-perf-tests/loadtests +--temp_location=gs://temp-storage-for-perf-tests/loadtests +--publish_to_big_query=true +--metrics_dataset=beam_run_inference +--metrics_table=python_wordcount +--runner=DataflowRunner \ No newline at end of file diff --git a/.github/workflows/dask_runner_tests.yml b/.github/workflows/dask_runner_tests.yml index f87c70d8b720..8faea77acc9b 100644 --- a/.github/workflows/dask_runner_tests.yml +++ b/.github/workflows/dask_runner_tests.yml @@ -26,6 +26,7 @@ on: branches: ['master', 'release-*'] tags: 'v*' paths: ['sdks/python/apache_beam/runners/dask/**'] + workflow_dispatch: # This allows a subsequently queued workflow run to interrupt previous runs concurrency: @@ -78,7 +79,7 @@ jobs: run: pip install tox - name: Install SDK with dask working-directory: ./sdks/python - run: pip install setuptools --upgrade && pip install -e .[gcp,dask,test] + run: pip install setuptools --upgrade && pip 
install -e .[dask,test,dataframes] - name: Run tests basic unix if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos') working-directory: ./sdks/python diff --git a/.github/workflows/finalize_release.yml b/.github/workflows/finalize_release.yml index 17ef17ed7841..126e1024908d 100644 --- a/.github/workflows/finalize_release.yml +++ b/.github/workflows/finalize_release.yml @@ -55,7 +55,7 @@ jobs: echo "Publish SDK docker images to Docker Hub." echo "================Pull RC Containers from DockerHub===========" - IMAGES=$(docker search apache/beam_ --format "{{.Name}}" --limit 100) + IMAGES=$(docker search apache/beam --format "{{.Name}}" --limit 100) KNOWN_IMAGES=() echo "We are using ${RC_VERSION} to push docker images for ${RELEASE}." while read IMAGE; do diff --git a/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt b/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt new file mode 100644 index 000000000000..6b44f53886b2 --- /dev/null +++ b/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +--input_options=''{\"num_records\":200,\"key_size\":1,\"value_size\":9}'' +--fanout=1 +--top_count=10 +--parallelism=2 +--endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest +--runner=FlinkRunner \ No newline at end of file diff --git a/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt b/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt new file mode 100644 index 000000000000..e792682bfbc4 --- /dev/null +++ b/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +--sourceOptions={"numRecords":200,"keySizeBytes":1,"valueSizeBytes":9} +--fanout=1 +--iterations=1 +--topCount=10 +--parallelism=2 +--jobEndpoint=localhost:8099 +--defaultEnvironmentType=DOCKER +--defaultEnvironmentConfig=gcr.io/apache-beam-testing/beam-sdk/beam_java11_sdk:latest +--runner=FlinkRunner \ No newline at end of file diff --git a/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt b/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt new file mode 100644 index 000000000000..5522a8f9b823 --- /dev/null +++ b/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +--input_options=''{\\"num_records\\":200,\\"key_size\\":1,\\"value_size\\":9,\\"algorithm\\":\\"lcg\\"}'' +--parallelism=2 +--job_endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_python3.9_sdk:latest +--top_count=10 +--runner=PortableRunner \ No newline at end of file diff --git a/.github/workflows/go_tests.yml b/.github/workflows/go_tests.yml index e85c4eba866b..5ae3609ed997 100644 --- a/.github/workflows/go_tests.yml +++ b/.github/workflows/go_tests.yml @@ -50,7 +50,7 @@ jobs: - name: Delete old coverage run: "cd sdks && rm -rf .coverage.txt || :" - name: Run coverage - run: cd sdks && go test -coverprofile=coverage.txt -covermode=atomic ./go/pkg/... ./go/container/... ./java/container/... ./python/container/... ./typescript/container/... + run: cd sdks && go test -timeout=25m -coverprofile=coverage.txt -covermode=atomic ./go/pkg/... ./go/container/... ./java/container/... ./python/container/... ./typescript/container/... 
- uses: codecov/codecov-action@v3 with: flags: go diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt index 650236a9c500..6280e01dccdb 100644 --- a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt @@ -27,4 +27,5 @@ --top_count=20 --streaming --use_stateful_load_generator ---runner=PortableRunner \ No newline at end of file +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt index 4208571fef62..e1b77d15b95b 100644 --- a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt @@ -27,4 +27,5 @@ --top_count=20 --streaming --use_stateful_load_generator ---runner=PortableRunner \ No newline at end of file +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt new file mode 100644 index 000000000000..f16e9e4b06ef --- /dev/null +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--publish_to_big_query=true +--metrics_dataset=load_test +--metrics_table=python_flink_streaming_combine_4 +--influx_measurement=python_streaming_combine_4 +--input_options=''{\\"num_records\\":200000,\\"key_size\\":10,\\"value_size\\":90,\\"algorithm\\":\\"lcg\\"}'' +--parallelism=16 +--job_endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_python3.9_sdk:latest +--fanout=1 +--top_count=20 +--streaming +--use_stateful_load_generator +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt new file mode 100644 index 000000000000..5f66e519c31a --- /dev/null +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--publish_to_big_query=true +--metrics_dataset=load_test +--metrics_table=python_flink_streaming_combine_5 +--influx_measurement=python_streaming_combine_5 +--input_options=''{\\"num_records\\":200000,\\"key_size\\":10,\\"value_size\\":90,\\"algorithm\\":\\"lcg\\"}'' +--parallelism=16 +--job_endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_python3.9_sdk:latest +--fanout=2 +--top_count=20 +--streaming +--use_stateful_load_generator +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/republish_released_docker_containers.yml b/.github/workflows/republish_released_docker_containers.yml new file mode 100644 index 000000000000..ed6e74ecf13d --- /dev/null +++ b/.github/workflows/republish_released_docker_containers.yml @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Workflow that enables republishing released docker images to avoid vulnerabilities + +name: Republish Released Docker Images + +on: + workflow_dispatch: + inputs: + RELEASE: + description: Beam version of current release (e.g. 2.XX.0) + required: true + default: '' + RC: + description: Integer RC version for the release (e.g. 3 for RC3) + required: true + default: '' + schedule: + - cron: "0 6 * * 1" +env: + docker_registry: gcr.io + release: ${{ github.event.inputs.RELEASE || '2.61.0' }} + rc: ${{ github.event.inputs.RC || '3' }} + +jobs: + + build: + runs-on: [self-hosted, ubuntu-20.04, highmem] + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: "v${{ env.release }}-RC${{ env.rc }}" + repository: apache/beam + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@v1.3.0 + - name: Install Java 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + - name: Install Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Remove default github maven configuration + # This step is a workaround to avoid a decryption issue of Beam's + # net.linguica.gradle.maven.settings plugin and github's provided maven + # settings.xml file + run: rm ~/.m2/settings.xml || true + - name: GCloud Docker credential helper + run: | + gcloud auth configure-docker ${{ env.docker_registry }} + - name: Push docker images + run: ./gradlew :pushAllDockerImages -PisRelease -Pdocker-pull-licenses -Pprune-images -Pdocker-repository-root=gcr.io/apache-beam-testing/updated_released_container_images -Pdocker-tag=${{ env.release }}rc${{ env.rc }} --no-daemon --no-parallel + diff --git a/.github/workflows/self-assign.yml b/.github/workflows/self-assign.yml index 084581db7340..6c2f2219b4e3 100644 ---
a/.github/workflows/self-assign.yml +++ b/.github/workflows/self-assign.yml @@ -40,12 +40,16 @@ jobs: repo: context.repo.repo, assignees: [context.payload.comment.user.login] }); - github.rest.issues.removeLabel({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - name: 'awaiting triage' - }); + try { + await github.rest.issues.removeLabel({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + name: 'awaiting triage' + }); + } catch (error) { + console.log(`Failed to remove awaiting triage label. It may not exist on this issue. Error ${error}`); + } } else if (bodyString == '.close-issue') { console.log('Closing issue'); if (i + 1 < body.length && body[i+1].toLowerCase() == 'not_planned') { diff --git a/.github/workflows/update_python_dependencies.yml b/.github/workflows/update_python_dependencies.yml index a91aff39f29a..0ab52e97b9f0 100644 --- a/.github/workflows/update_python_dependencies.yml +++ b/.github/workflows/update_python_dependencies.yml @@ -56,7 +56,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: python-version: | - 3.8 3.9 3.10 3.11 diff --git a/.test-infra/dataproc/flink_cluster.sh b/.test-infra/dataproc/flink_cluster.sh index b623e890d08f..4a97850f5ac1 100755 --- a/.test-infra/dataproc/flink_cluster.sh +++ b/.test-infra/dataproc/flink_cluster.sh @@ -17,7 +17,7 @@ # Provide the following environment to run this script: # # GCLOUD_ZONE: Google cloud zone. Optional. Default: "us-central1-a" -# DATAPROC_VERSION: Dataproc version. Optional. Default: 2.1 +# DATAPROC_VERSION: Dataproc version. Optional.
Default: 2.2 # CLUSTER_NAME: Cluster name # GCS_BUCKET: GCS bucket url for Dataproc resources (init actions) # HARNESS_IMAGES_TO_PULL: Urls to SDK Harness' images to pull on dataproc workers (optional: 0, 1 or multiple urls for every harness image) @@ -35,8 +35,8 @@ # HARNESS_IMAGES_TO_PULL='gcr.io//python:latest gcr.io//java:latest' \ # JOB_SERVER_IMAGE=gcr.io//job-server-flink:latest \ # ARTIFACTS_DIR=gs:// \ -# FLINK_DOWNLOAD_URL=https://archive.apache.org/dist/flink/flink-1.12.3/flink-1.12.3-bin-scala_2.11.tgz \ -# HADOOP_DOWNLOAD_URL=https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-9.0/flink-shaded-hadoop-2-uber-2.8.3-9.0.jar \ +# FLINK_DOWNLOAD_URL=https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz \ +# HADOOP_DOWNLOAD_URL=https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar \ # FLINK_NUM_WORKERS=2 \ # FLINK_TASKMANAGER_SLOTS=1 \ # DETACHED_MODE=false \ @@ -46,7 +46,7 @@ set -Eeuxo pipefail # GCloud properties GCLOUD_ZONE="${GCLOUD_ZONE:=us-central1-a}" -DATAPROC_VERSION="${DATAPROC_VERSION:=2.1-debian}" +DATAPROC_VERSION="${DATAPROC_VERSION:=2.2-debian}" GCLOUD_REGION=`echo $GCLOUD_ZONE | sed -E "s/(-[a-z])?$//"` MASTER_NAME="$CLUSTER_NAME-m" @@ -129,13 +129,26 @@ function create_cluster() { local image_version=$DATAPROC_VERSION echo "Starting dataproc cluster. Dataproc version: $image_version" - # Docker init action restarts yarn so we need to start yarn session after this restart happens. - # This is why flink init action is invoked last.
- # TODO(11/11/2022) remove --worker-machine-type and --master-machine-type once N2 CPUs quota relaxed - # Dataproc 2.1 uses n2-standard-2 by default but there is N2 CPUs=24 quota limit - gcloud dataproc clusters create $CLUSTER_NAME --region=$GCLOUD_REGION --num-workers=$FLINK_NUM_WORKERS \ - --master-machine-type=n1-standard-2 --worker-machine-type=n1-standard-2 --metadata "${metadata}", \ - --image-version=$image_version --zone=$GCLOUD_ZONE --optional-components=FLINK,DOCKER --quiet + local worker_machine_type="n1-standard-2" # Default worker type + local master_machine_type="n1-standard-2" # Default master type + + if [[ -n "${HIGH_MEM_MACHINE:=}" ]]; then + worker_machine_type="${HIGH_MEM_MACHINE}" + master_machine_type="${HIGH_MEM_MACHINE}" + + gcloud dataproc clusters create $CLUSTER_NAME --enable-component-gateway --region=$GCLOUD_REGION --num-workers=$FLINK_NUM_WORKERS --public-ip-address \ + --master-machine-type=${master_machine_type} --worker-machine-type=${worker_machine_type} --metadata "${metadata}", \ + --image-version=$image_version --zone=$GCLOUD_ZONE --optional-components=FLINK,DOCKER \ + --properties="${HIGH_MEM_FLINK_PROPS}" + else + # Docker init action restarts yarn so we need to start yarn session after this restart happens. + # This is why flink init action is invoked last. 
+ # TODO(11/22/2024) remove --worker-machine-type and --master-machine-type once N2 CPUs quota relaxed + # Dataproc 2.1 uses n2-standard-2 by default but there is N2 CPUs=24 quota limit for this project + gcloud dataproc clusters create $CLUSTER_NAME --enable-component-gateway --region=$GCLOUD_REGION --num-workers=$FLINK_NUM_WORKERS --public-ip-address \ + --master-machine-type=${master_machine_type} --worker-machine-type=${worker_machine_type} --metadata "${metadata}", \ + --image-version=$image_version --zone=$GCLOUD_ZONE --optional-components=FLINK,DOCKER --quiet + fi } # Runs init actions for Docker, Portability framework (Beam) and Flink cluster diff --git a/.test-infra/jenkins/CommonTestProperties.groovy b/.test-infra/jenkins/CommonTestProperties.groovy index c6870dea59a1..0670b96ef47c 100644 --- a/.test-infra/jenkins/CommonTestProperties.groovy +++ b/.test-infra/jenkins/CommonTestProperties.groovy @@ -26,7 +26,7 @@ class CommonTestProperties { } static String getFlinkVersion() { - return "1.15" + return "1.17" } static String getSparkVersion() { diff --git a/.test-infra/jenkins/Flink.groovy b/.test-infra/jenkins/Flink.groovy deleted file mode 100644 index 34f3b60709c0..000000000000 --- a/.test-infra/jenkins/Flink.groovy +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -class Flink { - private static final String flinkDownloadUrl = 'https://archive.apache.org/dist/flink/flink-1.15.0/flink-1.15.0-bin-scala_2.12.tgz' - private static final String hadoopDownloadUrl = 'https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar' - private static final String FLINK_DIR = '"$WORKSPACE/src/.test-infra/dataproc"' - private static final String FLINK_SCRIPT = 'flink_cluster.sh' - private def job - private String jobName - - Flink(job, String jobName) { - this.job = job - this.jobName = jobName - } - - /** - * Creates Flink cluster and specifies cleanup steps. - * - * @param sdkHarnessImages - the list of published SDK Harness images tags - * @param workerCount - the initial number of worker nodes - * @param jobServerImage - the Flink job server image tag. If left empty, cluster will be set up without the job server. 
- * @param slotsPerTaskmanager - the number of slots per Flink task manager - */ - void setUp(List sdkHarnessImages, Integer workerCount, String jobServerImage = '', Integer slotsPerTaskmanager = 1) { - setupFlinkCluster(sdkHarnessImages, workerCount, jobServerImage, slotsPerTaskmanager) - addTeardownFlinkStep() - } - - private void setupFlinkCluster(List sdkHarnessImages, Integer workerCount, String jobServerImage, Integer slotsPerTaskmanager) { - String gcsBucket = 'gs://beam-flink-cluster' - String clusterName = getClusterName() - String artifactsDir = "${gcsBucket}/${clusterName}" - String imagesToPull = sdkHarnessImages.join(' ') - - job.steps { - environmentVariables { - env("GCLOUD_ZONE", "us-central1-a") - env("CLUSTER_NAME", clusterName) - env("GCS_BUCKET", gcsBucket) - env("FLINK_DOWNLOAD_URL", flinkDownloadUrl) - env("HADOOP_DOWNLOAD_URL", hadoopDownloadUrl) - env("FLINK_NUM_WORKERS", workerCount) - env("FLINK_TASKMANAGER_SLOTS", slotsPerTaskmanager) - env("DETACHED_MODE", 'true') - - if(imagesToPull) { - env("HARNESS_IMAGES_TO_PULL", imagesToPull) - } - - if(jobServerImage) { - env("JOB_SERVER_IMAGE", jobServerImage) - env("ARTIFACTS_DIR", artifactsDir) - } - } - - shell('echo Setting up flink cluster') - shell("cd ${FLINK_DIR}; ./${FLINK_SCRIPT} create") - } - } - - /** - * Updates the number of worker nodes in a cluster. 
- * - * @param workerCount - the new number of worker nodes in the cluster - */ - void scaleCluster(Integer workerCount) { - job.steps { - shell("echo Changing number of workers to ${workerCount}") - environmentVariables { - env("FLINK_NUM_WORKERS", workerCount) - } - shell("cd ${FLINK_DIR}; ./${FLINK_SCRIPT} restart") - } - } - - private GString getClusterName() { - return "${jobName.toLowerCase().replace("_", "-")}-\$BUILD_ID" - } - - private void addTeardownFlinkStep() { - job.publishers { - postBuildScript { - buildSteps { - postBuildStep { - stopOnFailure(false) - results([ - 'SUCCESS', - 'UNSTABLE', - 'FAILURE', - 'NOT_BUILT', - 'ABORTED' - ]) - buildSteps { - shell { - command("cd ${FLINK_DIR}; ./${FLINK_SCRIPT} delete") - } - } - } - } - markBuildUnstable(false) - } - } - } -} diff --git a/.test-infra/jenkins/metrics_report/tox.ini b/.test-infra/jenkins/metrics_report/tox.ini index 026db5dc4860..d143a0dcf59c 100644 --- a/.test-infra/jenkins/metrics_report/tox.ini +++ b/.test-infra/jenkins/metrics_report/tox.ini @@ -17,7 +17,7 @@ ; TODO(https://github.com/apache/beam/issues/20209): Don't hardcode Py3.8 version. [tox] skipsdist = True -envlist = py38-test,py38-generate-report +envlist = py39-test,py39-generate-report [testenv] commands_pre = diff --git a/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml b/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml index 8d711c7fcb86..eccaaa5f3b17 100644 --- a/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml +++ b/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml @@ -43,6 +43,7 @@ categories: - "Cancel Stale Dataflow Jobs" - "pr-bot-pr-updates" - "pr-bot-new-prs" + - "Republish Released Docker Images" # Tests we want monitored more closely. Only add suites if you are willing to help keep them green :) # Usually will be postcommits since PreCommits are more likely to be noticed. 
- name: important_signals @@ -123,7 +124,6 @@ categories: - "PostCommit Java PVR Samza" - "PreCommit Java Tika IO Direct" - "PostCommit Java SingleStoreIO IT" - - "PostCommit Java Sickbay" - "PostCommit Java ValidatesRunner Direct" - "PreCommit Java SingleStore IO Direct" - "PreCommit Java InfluxDb IO Direct" @@ -226,7 +226,6 @@ categories: - "PreCommit Python Transforms" - "Build python source distribution and wheels" - "Python tests" - - "PostCommit Sickbay Python" - "PreCommit Portable Python" - "PreCommit Python Coverage" - "PreCommit Python Docker" @@ -316,7 +315,6 @@ categories: - "PostCommit PortableJar Spark" - "PreCommit Integration and Load Test Framework" - "pr-bot-update-reviewers" - - "Cut Release Branch" - "Generate issue report" - "Dask Runner Tests" - "PreCommit Typescript" @@ -327,6 +325,12 @@ categories: - "Assign Milestone on issue close" - "Local environment tests" - "PreCommit SQL" - - "LabelPrs" + - "LabelPrs" + - name: safe_to_ignore + groupThreshold: 0 + tests: - "build_release_candidate" + - "Cut Release Branch" + - "PostCommit Java Sickbay" + - "PostCommit Sickbay Python" diff --git a/.test-infra/mock-apis/pyproject.toml b/.test-infra/mock-apis/pyproject.toml index 680bf489ba13..c98d9152cfb9 100644 --- a/.test-infra/mock-apis/pyproject.toml +++ b/.test-infra/mock-apis/pyproject.toml @@ -27,7 +27,7 @@ packages = [ ] [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" google = "^3.0.0" grpcio = "^1.53.0" grpcio-tools = "^1.53.0" diff --git a/.test-infra/tools/python_installer.sh b/.test-infra/tools/python_installer.sh index b1b05e597cb3..04e10555243a 100644 --- a/.test-infra/tools/python_installer.sh +++ b/.test-infra/tools/python_installer.sh @@ -20,7 +20,7 @@ set -euo pipefail # Variable containing the python versions to install -python_versions_arr=("3.8.16" "3.9.16" "3.10.10" "3.11.4") +python_versions_arr=("3.9.16" "3.10.10" "3.11.4" "3.12.6") # Install pyenv dependencies.
pyenv_dep(){ diff --git a/CHANGES.md b/CHANGES.md index be4e0ba4d0f6..dbadd588ae3f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -53,7 +53,7 @@ * ([#X](https://github.com/apache/beam/issues/X)). --> -# [2.61.0] - Unreleased +# [2.62.0] - Unreleased ## Highlights @@ -62,7 +62,9 @@ ## I/Os +* gcs-connector config options can be set via GcsOptions (Java) ([#32769](https://github.com/apache/beam/pull/32769)). * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Upgraded the default version of Hadoop dependencies to 3.4.1. Hadoop 2.10.2 is still supported (Java) ([#33011](https://github.com/apache/beam/issues/33011)). ## New Features / Improvements @@ -70,6 +72,7 @@ ## Breaking Changes +* Upgraded ZetaSQL to 2024.11.1 ([#32902](https://github.com/apache/beam/pull/32902)). Java11+ is now needed if Beam's ZetaSQL component is used. * X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). ## Deprecations @@ -79,15 +82,54 @@ ## Bugfixes * Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Fixed EventTimeTimer ordering in Prism. ([#32222](https://github.com/apache/beam/issues/32222)). ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). +* Fixed [CVE-2024-47561](https://www.cve.org/CVERecord?id=CVE-2024-47561) (Java) by upgrading Avro version to 1.11.4 ## Known Issues * ([#X](https://github.com/apache/beam/issues/X)).
-# [2.60.0] - Unreleased +# [2.61.0] - 2024-11-25 + +## Highlights + +* [Python] Introduce Managed Transforms API ([#31495](https://github.com/apache/beam/pull/31495)) +* Flink 1.19 support added ([#32648](https://github.com/apache/beam/pull/32648)) + +## I/Os + +* [Managed Iceberg] Support creating tables if needed ([#32686](https://github.com/apache/beam/pull/32686)) +* [Managed Iceberg] Now available in Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) +* [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) +* BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) +* [Managed Iceberg] Allow updating table partition specs during pipeline runtime ([#32879](https://github.com/apache/beam/pull/32879)) +* Added BigQueryIO as a Managed IO ([#31486](https://github.com/apache/beam/pull/31486)) +* Support for writing to [Solace message queues](https://solace.com/) (`SolaceIO.Write`) added (Java) ([#31905](https://github.com/apache/beam/issues/31905)). + +## New Features / Improvements + +* Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195)) +* Added support for processing events which use a global sequence to "ordered" extension (Java) ([#32540](https://github.com/apache/beam/pull/32540)) +* Added new meta-transforms FlattenWith and Tee that allow one to introduce branching + without breaking the linear/chaining style of pipeline construction.
+* Use Prism as a fallback to the Python Portable runner when running a pipeline with the Python Direct runner ([#32876](https://github.com/apache/beam/pull/32876)) + +## Deprecations + +* Removed support for Flink 1.15 and 1.16 +* Removed support for Python 3.8 + +## Bugfixes + +* (Java) Fixed tearDown not invoked when DoFn throws on Portable Runners ([#18592](https://github.com/apache/beam/issues/18592), [#31381](https://github.com/apache/beam/issues/31381)). +* (Java) Fixed protobuf error with MapState.remove() in Dataflow Streaming Java Legacy Runner without Streaming Engine ([#32892](https://github.com/apache/beam/issues/32892)). +* Adding flag to support conditionally disabling auto-commit in JdbcIO ReadFn ([#31111](https://github.com/apache/beam/issues/31111)) +* (Python) Fixed BigQuery Enrichment bug that can lead to multiple conditions returning duplicate rows, batching returning incorrect results and conditions not scoped by row during batching ([#32780](https://github.com/apache/beam/pull/32780)). + +# [2.60.0] - 2024-10-17 ## Highlights @@ -96,6 +138,10 @@ * [Managed Iceberg] Added auto-sharding for streaming writes ([#32612](https://github.com/apache/beam/pull/32612)) * [Managed Iceberg] Added support for writing to dynamic destinations ([#32565](https://github.com/apache/beam/pull/32565)) +## I/Os + +* PubsubIO can validate that the Pub/Sub topic exists before running the Read/Write pipeline (Java) ([#32465](https://github.com/apache/beam/pull/32465)) + ## New Features / Improvements * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). @@ -107,6 +153,7 @@ * Significantly improved performance of Kafka IO reads that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) by removing the data reshuffle from SDF implementation. 
([#31682](https://github.com/apache/beam/pull/31682)). * Added support for dynamic writing in MqttIO (Java) ([#19376](https://github.com/apache/beam/issues/19376)) * Optimized Spark Runner parDo transform evaluator (Java) ([#32537](https://github.com/apache/beam/issues/32537)) +* [Managed Iceberg] More efficient manifest file writes/commits ([#32666](https://github.com/apache/beam/issues/32666)) ## Breaking Changes @@ -128,6 +175,14 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) * (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)). * (Java, Python, Go) Fixed PeriodicSequence backlog bytes reporting, which was preventing Dataflow Runner autoscaling from functioning properly ([#32506](https://github.com/apache/beam/issues/32506)). * (Java) Fix improper decoding of rows with schemas containing nullable fields when encoded with a schema with equal encoding positions but modified field order. ([#32388](https://github.com/apache/beam/issues/32388)). +* (Java) Skip close on bundles in BigtableIO.Read ([#32661](https://github.com/apache/beam/pull/32661), [#32759](https://github.com/apache/beam/pull/32759)). + +## Known Issues + +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. # [2.59.0] - 2024-09-11 @@ -172,6 +227,10 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) * If your pipeline is having difficulty with the Python or Java direct runners, but runs well on Prism, please let us know. 
* Java file-based IOs read or write lots (100k+) files could experience slowness and/or broken metrics visualization on Dataflow UI [#32649](https://github.com/apache/beam/issues/32649). +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. # [2.58.1] - 2024-08-15 @@ -183,6 +242,10 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) * Large Dataflow graphs using runner v2, or pipelines explicitly enabling the `upload_graph` experiment, will fail at construction time ([#32159](https://github.com/apache/beam/issues/32159)). * Python pipelines that run with 2.53.0-2.58.0 SDKs and read data from GCS might be affected by a data corruption issue ([#32169](https://github.com/apache/beam/issues/32169)). The issue will be fixed in 2.59.0 ([#32135](https://github.com/apache/beam/pull/32135)). To work around this, update the google-cloud-storage package to version 2.18.2 or newer. +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. # [2.58.0] - 2024-08-06 @@ -214,6 +277,10 @@ when running on 3.8. 
([#31192](https://github.com/apache/beam/issues/31192)) * Large Dataflow graphs using runner v2, or pipelines explicitly enabling the `upload_graph` experiment, will fail at construction time ([#32159](https://github.com/apache/beam/issues/32159)). * Python pipelines that run with 2.53.0-2.58.0 SDKs and read data from GCS might be affected by a data corruption issue ([#32169](https://github.com/apache/beam/issues/32169)). The issue will be fixed in 2.59.0 ([#32135](https://github.com/apache/beam/pull/32135)). To work around this, update the google-cloud-storage package to version 2.18.2 or newer. * [KafkaIO] Records read with `ReadFromKafkaViaSDF` are redistributed and may contain duplicates regardless of the configuration. This affects Java pipelines with Dataflow v2 runner and xlang pipelines reading from Kafka, ([#32196](https://github.com/apache/beam/issues/32196)) +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. # [2.57.0] - 2024-06-26 @@ -270,6 +337,10 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) * Large Dataflow graphs using runner v2, or pipelines explicitly enabling the `upload_graph` experiment, will fail at construction time ([#32159](https://github.com/apache/beam/issues/32159)). * Python pipelines that run with 2.53.0-2.58.0 SDKs and read data from GCS might be affected by a data corruption issue ([#32169](https://github.com/apache/beam/issues/32169)). The issue will be fixed in 2.59.0 ([#32135](https://github.com/apache/beam/pull/32135)). 
To work around this, update the google-cloud-storage package to version 2.18.2 or newer. +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. # [2.56.0] - 2024-05-01 @@ -479,6 +550,7 @@ classes finally moved to `extensions/avro`. In case if it's still required to us as a workaround, a copy of "old" `CountingSource` class should be placed into a project code and used directly ([#25252](https://github.com/apache/beam/issues/25252)). * Renamed `host` to `firestoreHost` in `FirestoreOptions` to avoid potential conflict of command line arguments (Java) ([#29201](https://github.com/apache/beam/pull/29201)). +* Transforms which use `SnappyCoder` are update incompatible with previous versions of the same transform (Java) on some runners. This includes PubSubIO's read ([#28655](https://github.com/apache/beam/pull/28655#issuecomment-2407839769)). ## Bugfixes @@ -496,6 +568,7 @@ as a workaround, a copy of "old" `CountingSource` class should be placed into a * MLTransform drops the identical elements in the output PCollection. For any duplicate elements, a single element will be emitted downstream. ([#29600](https://github.com/apache/beam/issues/29600)). * Some Python pipelines that run with 2.52.0-2.54.0 SDKs and use large materialized side inputs might be affected by a performance regression. To restore the prior behavior on these SDK versions, supply the `--max_cache_memory_usage_mb=0` pipeline option. (Python) ([#30360](https://github.com/apache/beam/issues/30360)). 
* Users who lauch Python pipelines in an environment without internet access and use the `--setup_file` pipeline option might experience an increase in pipeline submission time. This has been fixed in 2.56.0 ([#31070](https://github.com/apache/beam/pull/31070)). +* Transforms which use `SnappyCoder` are update incompatible with previous versions of the same transform (Java) on some runners. This includes PubSubIO's read ([#28655](https://github.com/apache/beam/pull/28655#issuecomment-2407839769)). # [2.51.0] - 2023-10-03 diff --git a/build.gradle.kts b/build.gradle.kts index 38b58b6979ee..0adb29058479 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -501,23 +501,11 @@ tasks.register("pythonFormatterPreCommit") { dependsOn("sdks:python:test-suites:tox:pycommon:formatter") } -tasks.register("python38PostCommit") { - dependsOn(":sdks:python:test-suites:dataflow:py38:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:hdfsIntegrationTest") - dependsOn(":sdks:python:test-suites:direct:py38:azureIntegrationTest") - dependsOn(":sdks:python:test-suites:portable:py38:postCommitPy38") - // TODO: https://github.com/apache/beam/issues/22651 - // The default container uses Python 3.8. The goal here is to - // duild Docker images for TensorRT tests during run time for python versions - // other than 3.8 and add these tests in other python postcommit suites. 
- dependsOn(":sdks:python:test-suites:dataflow:py38:inferencePostCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:inferencePostCommitIT") -} - tasks.register("python39PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py39:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py39:postCommitIT") + dependsOn(":sdks:python:test-suites:direct:py39:hdfsIntegrationTest") + dependsOn(":sdks:python:test-suites:direct:py39:azureIntegrationTest") dependsOn(":sdks:python:test-suites:portable:py39:postCommitPy39") // TODO (https://github.com/apache/beam/issues/23966) // Move this to Python 3.10 test suite once tfx-bsl has python 3.10 wheel. @@ -528,6 +516,11 @@ tasks.register("python310PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py310:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py310:postCommitIT") dependsOn(":sdks:python:test-suites:portable:py310:postCommitPy310") + // TODO: https://github.com/apache/beam/issues/22651 + // The default container uses Python 3.10. The goal here is to + // build Docker images for TensorRT tests during run time for python versions + // other than 3.10 and add these tests in other python postcommit suites.
+ dependsOn(":sdks:python:test-suites:dataflow:py310:inferencePostCommitIT") } tasks.register("python311PostCommit") { @@ -654,6 +647,22 @@ tasks.register("checkSetup") { dependsOn(":examples:java:wordCount") } +// if not disabled make spotlessApply dependency of compileJava and compileTestJava +val disableSpotlessCheck: String by project +val isSpotlessDisabled = project.hasProperty("disableSpotlessCheck") && + disableSpotlessCheck == "true" +if (!isSpotlessDisabled) { + subprojects { + afterEvaluate { + tasks.findByName("spotlessApply")?.let { + listOf("compileJava", "compileTestJava").forEach { + t -> tasks.findByName(t)?.let { f -> f.dependsOn("spotlessApply") } + } + } + } + } +} + // Generates external transform config project.tasks.register("generateExternalTransformsConfig") { dependsOn(":sdks:python:generateExternalTransformsConfig") diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy index cd46c1270f83..b3949223f074 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy @@ -59,6 +59,7 @@ class BeamDockerPlugin implements Plugin { boolean load = false boolean push = false String builder = null + String target = null File resolvedDockerfile = null File resolvedDockerComposeTemplate = null @@ -130,6 +131,7 @@ class BeamDockerPlugin implements Plugin { group = 'Docker' description = 'Builds Docker image.' 
dependsOn prepare + environment 'DOCKER_BUILDKIT', '1' }) Task tag = project.tasks.create('dockerTag', { @@ -288,6 +290,9 @@ class BeamDockerPlugin implements Plugin { } else { buildCommandLine.addAll(['-t', "${-> ext.name}", '.']) } + if (ext.target != null && ext.target != "") { + buildCommandLine.addAll '--target', ext.target + } logger.debug("${buildCommandLine}" as String) return buildCommandLine } diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index a7e129211757..a59c1d7630b0 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -603,18 +603,18 @@ class BeamModulePlugin implements Plugin { def dbcp2_version = "2.9.0" def errorprone_version = "2.10.0" // [bomupgrader] determined by: com.google.api:gax, consistent with: google_cloud_platform_libraries_bom - def gax_version = "2.52.0" + def gax_version = "2.55.0" def google_ads_version = "33.0.0" def google_clients_version = "2.0.0" def google_cloud_bigdataoss_version = "2.2.16" // [bomupgrader] determined by: com.google.cloud:google-cloud-spanner, consistent with: google_cloud_platform_libraries_bom - def google_cloud_spanner_version = "6.74.0" + def google_cloud_spanner_version = "6.79.0" def google_code_gson_version = "2.10.1" def google_oauth_clients_version = "1.34.1" // [bomupgrader] determined by: io.grpc:grpc-netty, consistent with: google_cloud_platform_libraries_bom - def grpc_version = "1.66.0" + def grpc_version = "1.67.1" def guava_version = "33.1.0-jre" - def hadoop_version = "2.10.2" + def hadoop_version = "3.4.1" def hamcrest_version = "2.1" def influxdb_version = "2.19" def httpclient_version = "4.5.13" @@ -627,16 +627,17 @@ class BeamModulePlugin implements Plugin { def log4j2_version = "2.20.0" def nemo_version = "0.1" // [bomupgrader] determined by: io.grpc:grpc-netty, 
consistent with: google_cloud_platform_libraries_bom - def netty_version = "4.1.100.Final" + def netty_version = "4.1.110.Final" def postgres_version = "42.2.16" def powermock_version = "2.0.9" // [bomupgrader] determined by: com.google.protobuf:protobuf-java, consistent with: google_cloud_platform_libraries_bom - def protobuf_version = "3.25.4" + def protobuf_version = "3.25.5" def qpid_jms_client_version = "0.61.0" def quickcheck_version = "1.0" def sbe_tool_version = "1.25.1" def singlestore_jdbc_version = "1.1.4" def slf4j_version = "1.7.30" + def snakeyaml_engine_version = "2.6" def snakeyaml_version = "2.2" def solace_version = "10.21.0" def spark2_version = "2.4.8" @@ -668,7 +669,7 @@ class BeamModulePlugin implements Plugin { antlr_runtime : "org.antlr:antlr4-runtime:4.7", args4j : "args4j:args4j:2.33", auto_value_annotations : "com.google.auto.value:auto-value-annotations:$autovalue_version", - avro : "org.apache.avro:avro:1.11.3", + avro : "org.apache.avro:avro:1.11.4", avro_tests : "org.apache.avro:avro:1.11.3:tests", aws_java_sdk_cloudwatch : "com.amazonaws:aws-java-sdk-cloudwatch:$aws_java_sdk_version", aws_java_sdk_core : "com.amazonaws:aws-java-sdk-core:$aws_java_sdk_version", @@ -714,7 +715,7 @@ class BeamModulePlugin implements Plugin { cdap_plugin_zendesk : "io.cdap.plugin:zendesk-plugins:1.0.0", checker_qual : "org.checkerframework:checker-qual:$checkerframework_version", classgraph : "io.github.classgraph:classgraph:$classgraph_version", - commons_codec : "commons-codec:commons-codec:1.17.0", + commons_codec : "commons-codec:commons-codec:1.17.1", commons_collections : "commons-collections:commons-collections:3.2.2", commons_compress : "org.apache.commons:commons-compress:1.26.2", commons_csv : "org.apache.commons:commons-csv:1.8", @@ -736,12 +737,12 @@ class BeamModulePlugin implements Plugin { google_api_client_gson : "com.google.api-client:google-api-client-gson:$google_clients_version", google_api_client_java6 : 
"com.google.api-client:google-api-client-java6:$google_clients_version", google_api_common : "com.google.api:api-common", // google_cloud_platform_libraries_bom sets version - google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20240815-2.0.0", // [bomupgrader] sets version + google_api_services_bigquery : "com.google.apis:google-api-services-bigquery:v2-rev20240919-2.0.0", // [bomupgrader] sets version google_api_services_cloudresourcemanager : "com.google.apis:google-api-services-cloudresourcemanager:v1-rev20240310-2.0.0", // [bomupgrader] sets version google_api_services_dataflow : "com.google.apis:google-api-services-dataflow:v1b3-rev20240817-$google_clients_version", google_api_services_healthcare : "com.google.apis:google-api-services-healthcare:v1-rev20240130-$google_clients_version", google_api_services_pubsub : "com.google.apis:google-api-services-pubsub:v1-rev20220904-$google_clients_version", - google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20240706-2.0.0", // [bomupgrader] sets version + google_api_services_storage : "com.google.apis:google-api-services-storage:v1-rev20240924-2.0.0", // [bomupgrader] sets version google_auth_library_credentials : "com.google.auth:google-auth-library-credentials", // google_cloud_platform_libraries_bom sets version google_auth_library_oauth2_http : "com.google.auth:google-auth-library-oauth2-http", // google_cloud_platform_libraries_bom sets version google_cloud_bigquery : "com.google.cloud:google-cloud-bigquery", // google_cloud_platform_libraries_bom sets version @@ -753,13 +754,13 @@ class BeamModulePlugin implements Plugin { google_cloud_core_grpc : "com.google.cloud:google-cloud-core-grpc", // google_cloud_platform_libraries_bom sets version google_cloud_datacatalog_v1beta1 : "com.google.cloud:google-cloud-datacatalog", // google_cloud_platform_libraries_bom sets version google_cloud_dataflow_java_proto_library_all: 
"com.google.cloud.dataflow:google-cloud-dataflow-java-proto-library-all:0.5.160304", - google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:2.21.2", // [bomupgrader] sets version + google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:2.23.0", // [bomupgrader] sets version google_cloud_firestore : "com.google.cloud:google-cloud-firestore", // google_cloud_platform_libraries_bom sets version google_cloud_pubsub : "com.google.cloud:google-cloud-pubsub", // google_cloud_platform_libraries_bom sets version google_cloud_pubsublite : "com.google.cloud:google-cloud-pubsublite", // google_cloud_platform_libraries_bom sets version // [bomupgrader] the BOM version is set by scripts/tools/bomupgrader.py. If update manually, also update // libraries-bom version on sdks/java/container/license_scripts/dep_urls_java.yaml - google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:26.45.0", + google_cloud_platform_libraries_bom : "com.google.cloud:libraries-bom:26.49.0", google_cloud_secret_manager : "com.google.cloud:google-cloud-secretmanager", // google_cloud_platform_libraries_bom sets version google_cloud_spanner : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version google_cloud_spanner_test : "com.google.cloud:google-cloud-spanner:$google_cloud_spanner_version:tests", @@ -794,6 +795,7 @@ class BeamModulePlugin implements Plugin { grpc_xds : "io.grpc:grpc-xds", // google_cloud_platform_libraries_bom sets version guava : "com.google.guava:guava:$guava_version", guava_testlib : "com.google.guava:guava-testlib:$guava_version", + hadoop_auth : "org.apache.hadoop:hadoop-auth:$hadoop_version", hadoop_client : "org.apache.hadoop:hadoop-client:$hadoop_version", hadoop_common : "org.apache.hadoop:hadoop-common:$hadoop_version", hadoop_mapreduce_client_core : "org.apache.hadoop:hadoop-mapreduce-client-core:$hadoop_version", @@ -870,6 +872,7 @@ class 
BeamModulePlugin implements Plugin { singlestore_jdbc : "com.singlestore:singlestore-jdbc-client:$singlestore_jdbc_version", slf4j_api : "org.slf4j:slf4j-api:$slf4j_version", snake_yaml : "org.yaml:snakeyaml:$snakeyaml_version", + snakeyaml_engine : "org.snakeyaml:snakeyaml-engine:$snakeyaml_engine_version", slf4j_android : "org.slf4j:slf4j-android:$slf4j_version", slf4j_ext : "org.slf4j:slf4j-ext:$slf4j_version", slf4j_jdk14 : "org.slf4j:slf4j-jdk14:$slf4j_version", @@ -907,7 +910,7 @@ class BeamModulePlugin implements Plugin { testcontainers_solace : "org.testcontainers:solace:$testcontainers_version", truth : "com.google.truth:truth:1.1.5", threetenbp : "org.threeten:threetenbp:1.6.8", - vendored_grpc_1_60_1 : "org.apache.beam:beam-vendor-grpc-1_60_1:0.2", + vendored_grpc_1_60_1 : "org.apache.beam:beam-vendor-grpc-1_60_1:0.3", vendored_guava_32_1_2_jre : "org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1", vendored_calcite_1_28_0 : "org.apache.beam:beam-vendor-calcite-1_28_0:0.2", woodstox_core_asl : "org.codehaus.woodstox:woodstox-core-asl:4.4.1", @@ -2511,6 +2514,8 @@ class BeamModulePlugin implements Plugin { def taskName = "run${config.type}Java${config.runner}" def releaseVersion = project.findProperty('ver') ?: project.version def releaseRepo = project.findProperty('repourl') ?: 'https://repository.apache.org/content/repositories/snapshots' + // shared maven local path for maven archetype projects + def sharedMavenLocal = project.findProperty('mavenLocalPath') ?: '' def argsNeeded = [ "--ver=${releaseVersion}", "--repourl=${releaseRepo}" @@ -2530,6 +2535,9 @@ class BeamModulePlugin implements Plugin { if (config.pubsubTopic) { argsNeeded.add("--pubsubTopic=${config.pubsubTopic}") } + if (sharedMavenLocal) { + argsNeeded.add("--mavenLocalPath=${sharedMavenLocal}") + } project.evaluationDependsOn(':release') project.task(taskName, dependsOn: ':release:classes', type: JavaExec) { group = "Verification" @@ -3145,7 +3153,6 @@ class BeamModulePlugin implements 
Plugin { mustRunAfter = [ ":runners:flink:${project.ext.latestFlinkVersion}:job-server:shadowJar", ':runners:spark:3:job-server:shadowJar', - ':sdks:python:container:py38:docker', ':sdks:python:container:py39:docker', ':sdks:python:container:py310:docker', ':sdks:python:container:py311:docker', diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy index b2c7053dfb60..62733efb507c 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy @@ -189,6 +189,9 @@ class GrpcVendoring_1_60_1 { "org/junit/**", "org/mockito/**", "org/objenesis/**", + // proto source files + "google/**/*.proto", + "grpc/**/*.proto", ] } diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy index bb08e79edd3c..a3ae6833d579 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy @@ -28,11 +28,13 @@ class KafkaTestUtilities { @Inject KafkaBatchIT(String delimited, String undelimited, Boolean sdfCompatible, ConfigurationContainer configurations, Project runningProject){ + def kafkaioProject = runningProject.findProject(":sdks:java:io:kafka") group = "Verification" description = "Runs KafkaIO IT tests with Kafka clients API $delimited" outputs.upToDateWhen { false } testClassesDirs = runningProject.findProject(":sdks:java:io:kafka").sourceSets.test.output.classesDirs - classpath = configurations."kafkaVersion$undelimited" + runningProject.sourceSets.test.runtimeClasspath + runningProject.findProject(":sdks:java:io:kafka").sourceSets.test.runtimeClasspath + classpath = runningProject.sourceSets.test.runtimeClasspath + 
kafkaioProject.configurations."kafkaVersion$undelimited" + kafkaioProject.sourceSets.test.runtimeClasspath + systemProperty "beam.target.kafka.version", delimited def pipelineOptions = [ '--sourceOptions={' + diff --git a/contributor-docs/code-change-guide.md b/contributor-docs/code-change-guide.md index f0785d3509d0..b4300103454c 100644 --- a/contributor-docs/code-change-guide.md +++ b/contributor-docs/code-change-guide.md @@ -145,7 +145,7 @@ in the Google Cloud documentation. Depending on the languages involved, your `PATH` file needs to have the following elements configured. -* A Java environment that uses a supported Java version, preferably Java 8. +* A Java environment that uses a supported Java version, preferably Java 11. * This environment is needed for all development, because Beam is a Gradle project that uses JVM. * Recommended: To manage Java versions, use [sdkman](https://sdkman.io/install). @@ -624,6 +624,11 @@ Tips for using the Dataflow runner: ## Appendix +### Common Issues + +* If you run into strange errors such as `java.lang.NoClassDefFoundError`, run `./gradlew clean` first. +* To run a single Java test with Gradle, use `--tests` to filter; for example, `./gradlew :it:google-cloud-platform:WordCountIntegrationTest --tests "org.apache.beam.it.gcp.WordCountIT.testWordCountDataflow"` + ### Directories of snapshot builds * https://repository.apache.org/content/groups/snapshots/org/apache/beam/ Java SDK build (nightly) diff --git a/contributor-docs/discussion-docs/2024.md b/contributor-docs/discussion-docs/2024.md index baea7c9fc462..124fe8ef9bb7 100644 --- a/contributor-docs/discussion-docs/2024.md +++ b/contributor-docs/discussion-docs/2024.md @@ -35,11 +35,12 @@ limitations under the License.
| 18 | Danny McCormick | [GSoC Proposal : Implement RAG Pipelines using Beam](https://docs.google.com/document/d/1M_8fvqKVBi68hQo_x1AMQ8iEkzeXTcSl0CwTH00cr80) | 2024-05-01 16:12:23 | | 19 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - June 2024](https://s.apache.org/beam-draft-report-2024-06) | 2024-05-23 14:57:16 | | 20 | Jack McCluskey | [Embeddings in MLTransform](https://docs.google.com/document/d/1En4bfbTu4rvu7LWJIKV3G33jO-xJfTdbaSFSURmQw_s) | 2024-05-29 10:26:47 | -| 21 | Bartosz Zab��ocki | [[External] Solace IO - Read Connector](https://docs.google.com/document/d/1Gvq67VrcHCnlO8f_NzMM1Y4c7wCNSdvo6qqLWg8upfw) | 2024-05-29 12:00:23 | +| 21 | Bartosz Zabłocki | [[External] Solace IO - Read Connector](https://docs.google.com/document/d/1Gvq67VrcHCnlO8f_NzMM1Y4c7wCNSdvo6qqLWg8upfw) | 2024-05-29 12:00:23 | | 22 | Danny McCormick | [RunInference Timeouts](https://docs.google.com/document/d/19ves6iv-m_6DFmePJZqYpLm-bCooPu6wQ-Ti6kAl2Jo) | 2024-08-07 07:11:38 | | 23 | Jack McCluskey | [BatchElements in Beam Python](https://docs.google.com/document/d/1fOjIjIUH5dxllOGp5Z4ZmpM7BJhAJc2-hNjTnyChvgc) | 2024-08-15 14:56:26 | | 24 | XQ Hu | [[Public] Beam 3.0: a discussion doc](https://docs.google.com/document/d/13r4NvuvFdysqjCTzMHLuUUXjKTIEY3d7oDNIHT6guww) | 2024-08-19 17:17:26 | | 25 | Danny McCormick | [Beam Patch Release Process](https://docs.google.com/document/d/1o4UK444hCm1t5KZ9ufEu33e_o400ONAehXUR9A34qc8) | 2024-08-23 04:51:48 | | 26 | Jack McCluskey | [Beam Python Type Hinting](https://s.apache.org/beam-python-type-hinting-overview) | 2024-08-26 14:16:42 | | 27 | Ahmed Abualsaud | [Python Multi-language with SchemaTransforms](https://docs.google.com/document/d/1_embA3pGwoYG7sbHaYzAkg3hNxjTughhFCY8ThcoK_Q) | 2024-08-26 19:53:10 | -| 28 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | \ No newline at end of file +| 28 | Kenneth Knowles | [DRAFT - Apache Beam Board 
Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | +| 29 | Jeff Kinard | [Beam YA(ML)^2](https://docs.google.com/document/d/1z9lNlSBfqDVdOP1frJNv_NJoMR1F1VBI29wn788x6IE/) | 2024-09-11 15:01:55 | diff --git a/contributor-docs/release-guide.md b/contributor-docs/release-guide.md index 4fe35aa4aac2..00f11c32aa7d 100644 --- a/contributor-docs/release-guide.md +++ b/contributor-docs/release-guide.md @@ -507,7 +507,7 @@ with tags: `${RELEASE_VERSION}rc${RC_NUM}` Verify that third party licenses are included in Docker. You can do this with a simple script: RC_TAG=${RELEASE_VERSION}rc${RC_NUM} - for pyver in 3.8 3.9 3.10 3.11; do + for pyver in 3.9 3.10 3.11 3.12; do docker run --rm --entrypoint sh \ apache/beam_python${pyver}_sdk:${RC_TAG} \ -c 'ls -al /opt/apache/beam/third_party_licenses/ | wc -l' @@ -554,10 +554,10 @@ to PyPI with an `rc` suffix. __Attention:__ Verify that: - [ ] The File names version include ``rc-#`` suffix -- [ ] [Download Files](https://pypi.org/project/apache-beam/#files) have: - - [ ] All wheels uploaded as artifacts - - [ ] Release source's zip published - - [ ] Signatures and hashes do not need to be uploaded +- [Download Files](https://pypi.org/project/apache-beam/#files) have: +- [ ] All wheels uploaded as artifacts +- [ ] Release source's zip published +- [ ] Signatures and hashes do not need to be uploaded ### Propose pull requests for website updates @@ -756,7 +756,7 @@ template; please adjust as you see fit. Reviewers are encouraged to test their own use cases with the release candidate, and vote +1 if no issues are found. Only PMC member votes will count towards the final vote, but votes from all - community members is encouraged and helpful for finding regressions; you can either test your own + community members are encouraged and helpful for finding regressions; you can either test your own use cases [13] or use cases from the validation sheet [10]. 
The complete staging area is available for your review, which includes: @@ -765,7 +765,7 @@ template; please adjust as you see fit. * all artifacts to be deployed to the Maven Central Repository [4], * source code tag "v1.2.3-RC3" [5], * website pull request listing the release [6], the blog post [6], and publishing the API reference manual [7]. - * Python artifacts are deployed along with the source release to the dist.apache.org [2] and PyPI[8]. + * Python artifacts are deployed along with the source release to dist.apache.org [2] and PyPI[8]. * Go artifacts and documentation are available at pkg.go.dev [9] * Validation sheet with a tab for 1.2.3 release to help with validation [10]. * Docker images published to Docker Hub [11]. @@ -887,7 +887,7 @@ write to BigQuery, and create a cluster of machines for running containers (for ``` **Flink Local Runner** ``` - ./gradlew :runners:flink:1.18:runQuickstartJavaFlinkLocal \ + ./gradlew :runners:flink:1.19:runQuickstartJavaFlinkLocal \ -Prepourl=https://repository.apache.org/content/repositories/orgapachebeam-${KEY} \ -Pver=${RELEASE_VERSION} ``` @@ -1148,7 +1148,7 @@ All wheels should be published, in addition to the zip of the release source. ### Merge Website pull requests Merge all of the website pull requests -- [listing the release](/get-started/downloads/) +- [listing the release](https://beam.apache.org/get-started/downloads/) - publishing the [Python API reference manual](https://beam.apache.org/releases/pydoc/) and the [Java API reference manual](https://beam.apache.org/releases/javadoc/), and - adding the release blog post. @@ -1309,6 +1309,12 @@ You can also update the versions in https://github.com/apache/beam-starter-pytho https://github.com/apache/beam-starter-go if you would like. This is optional because dependabot will automatically open a PR to do this if you don't. 
+### Update the container republishing workflow + +After the Beam release is published, update the default versions in https://github.com/apache/beam/blob/master/.github/workflows/republish_released_docker_containers.yml#L37 +to point to the most recent release and its accepted RC version. This workflow will then regularly +republish containers using the same underlying source (but updated base images) to allow users to stay ahead of vulnerabilities. + ### Update Beam Playground After new Beam Release is published, Beam Playground can be updated following the steps below. If any steps fail, make diff --git a/examples/java/build.gradle b/examples/java/build.gradle index af91fa83fe91..4f1902cf1679 100644 --- a/examples/java/build.gradle +++ b/examples/java/build.gradle @@ -66,6 +66,8 @@ dependencies { implementation project(":sdks:java:extensions:python") implementation project(":sdks:java:io:google-cloud-platform") implementation project(":sdks:java:io:kafka") + runtimeOnly project(":sdks:java:io:iceberg") + implementation project(":sdks:java:managed") implementation project(":sdks:java:extensions:ml") implementation library.java.avro implementation library.java.bigdataoss_util @@ -100,6 +102,8 @@ dependencies { implementation "org.apache.httpcomponents:httpcore:4.4.13" implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.1" implementation "com.fasterxml.jackson.core:jackson-core:2.14.1" + runtimeOnly library.java.hadoop_client + runtimeOnly library.java.bigdataoss_gcs_connector testImplementation project(path: ":runners:direct-java", configuration: "shadow") testImplementation project(":sdks:java:io:google-cloud-platform") testImplementation project(":sdks:java:extensions:ml") diff --git a/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java b/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java new file mode 100644 index 000000000000..446d11d03be4 --- /dev/null +++
b/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.examples.cookbook; + +import java.util.Arrays; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.transforms.Filter; +import org.apache.beam.sdk.transforms.JsonToRow; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +/** + * Reads real-time NYC taxi ride information from {@code + * projects/pubsub-public-data/topics/taxirides-realtime} and writes to Iceberg tables using Beam's + * {@link Managed} IcebergIO sink. + * + *

This is a streaming pipeline that writes records to Iceberg tables dynamically, depending on + * each record's passenger count. New tables are created as needed. We set a triggering frequency of + * 10s; at around this interval, the sink will accumulate records and write them to the appropriate + * table, creating a new snapshot each time. + */ +public class IcebergTaxiExamples { + private static final String TAXI_RIDES_TOPIC = + "projects/pubsub-public-data/topics/taxirides-realtime"; + private static final Schema TAXI_RIDE_INFO_SCHEMA = + Schema.builder() + .addStringField("ride_id") + .addInt32Field("point_idx") + .addDoubleField("latitude") + .addDoubleField("longitude") + .addStringField("timestamp") + .addDoubleField("meter_reading") + .addDoubleField("meter_increment") + .addStringField("ride_status") + .addInt32Field("passenger_count") + .build(); + + public static void main(String[] args) { + IcebergPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).as(IcebergPipelineOptions.class); + options.setProject("apache-beam-testing"); + + // each record's 'passenger_count' value will be substituted in to determine + // its final table destination + // e.g. 
an event with 3 passengers will be written to 'iceberg_taxi.3_passengers' + String tableIdentifierTemplate = "iceberg_taxi.{passenger_count}_passengers"; + + Map catalogProps = + ImmutableMap.builder() + .put("catalog-impl", options.getCatalogImpl()) + .put("warehouse", options.getWarehouse()) + .build(); + Map icebergWriteConfig = + ImmutableMap.builder() + .put("table", tableIdentifierTemplate) + .put("catalog_name", options.getCatalogName()) + .put("catalog_properties", catalogProps) + .put("triggering_frequency_seconds", 10) + // perform a final filter to only write these two columns + .put("keep", Arrays.asList("ride_id", "meter_reading")) + .build(); + + Pipeline p = Pipeline.create(options); + p + // Read taxi ride data + .apply(PubsubIO.readStrings().fromTopic(TAXI_RIDES_TOPIC)) + // Convert JSON strings to Beam Rows + .apply(JsonToRow.withSchema(TAXI_RIDE_INFO_SCHEMA)) + // Filter to only include drop-offs + .apply(Filter.create().whereFieldName("ride_status", "dropoff"::equals)) + // Write to Iceberg tables + .apply(Managed.write(Managed.ICEBERG).withConfig(icebergWriteConfig)); + p.run(); + } + + public interface IcebergPipelineOptions extends GcpOptions { + @Description("Warehouse location where the table's data will be written to.") + @Default.String("gs://apache-beam-samples/iceberg-examples") + String getWarehouse(); + + void setWarehouse(String warehouse); + + @Description("Fully-qualified name of the catalog class to use.") + @Default.String("org.apache.iceberg.hadoop.HadoopCatalog") + String getCatalogImpl(); + + void setCatalogImpl(String catalogName); + + @Validation.Required + @Default.String("example-catalog") + String getCatalogName(); + + void setCatalogName(String catalogName); + } +} diff --git a/examples/java/src/main/java/org/apache/beam/examples/subprocess/ExampleEchoPipeline.java b/examples/java/src/main/java/org/apache/beam/examples/subprocess/ExampleEchoPipeline.java index 289ff66cc46b..4aa20fc10dfb 100644 --- 
a/examples/java/src/main/java/org/apache/beam/examples/subprocess/ExampleEchoPipeline.java +++ b/examples/java/src/main/java/org/apache/beam/examples/subprocess/ExampleEchoPipeline.java @@ -33,16 +33,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * In this example batch pipeline we will invoke a simple Echo C++ library within a DoFn The sample - * makes use of a ExternalLibraryDoFn class which abstracts the setup and processing of the - * executable, logs and results. For this example we are using commands passed to the library based - * on ordinal position but for a production system you should use a mechanism like ProtoBuffers with - * Base64 encoding to pass the parameters to the library To test this example you will need to build - * the files Echo.cc and EchoAgain.cc in a linux env matching the runner that you are using (using - * g++ with static option). Once built copy them to the SourcePath defined in {@link - * SubProcessPipelineOptions} - */ +/** Please see the Readme.MD file for instructions to execute this pipeline. */ public class ExampleEchoPipeline { public static void main(String[] args) throws Exception { diff --git a/examples/java/src/main/java/org/apache/beam/examples/subprocess/Readme.MD b/examples/java/src/main/java/org/apache/beam/examples/subprocess/Readme.MD new file mode 100644 index 000000000000..50f9a2864de8 --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/subprocess/Readme.MD @@ -0,0 +1,86 @@ + +# Apache Beam Subprocess Example + +This example demonstrates how to execute external C++ binaries as subprocesses within an Apache Beam pipeline using the `SubProcessKernel`. + +## Prerequisites + +* **Google Cloud Project:** A Google Cloud project with billing enabled. +* **Dataflow API:** Enable the Dataflow API for your project. +* **C++ compiler:** You'll need a C++ compiler (like g++) to compile the C++ binaries. + +## Steps + +1. 
**Create a [Maven Example project](https://beam.apache.org/get-started/quickstart-java/) that builds against the latest Beam release:** + + ```bash + mvn archetype:generate \ + -DarchetypeGroupId=org.apache.beam \ + -DarchetypeArtifactId=beam-sdks-java-maven-archetypes-examples \ + -DarchetypeVersion=2.60.0 \ + -DgroupId=org.example \ + -DartifactId=word-count-beam \ + -Dversion="0.1" \ + -Dpackage=org.apache.beam.examples \ + -DinteractiveMode=false + ``` + +2. **Build the project:** + + * Navigate to the root of the repository (`word-count-beam/`): + + ```bash + cd word-count-beam/ + ``` + + * Build the project using Maven: + + ```bash + mvn clean install + ``` + +3. **Run the pipeline on Dataflow:** + + ```bash + mvn compile exec:java \ + -Dexec.mainClass=org.apache.beam.examples.subprocess.ExampleEchoPipeline \ + -Dexec.args="--sourcePath=/absolute/path/to/your/subprocess/directory \ + --workerPath=/absolute/path/to/your/subprocess/directory \ + --concurrency=5 \ + --filesToStage=/absolute/path/to/your/subprocess/directory/echo,/absolute/path/to/your/subprocess/directory/echoagain \ + --runner=DataflowRunner \ + --project=your-project-id \ + --region=your-gcp-region \ + --tempLocation=gs://your-gcs-bucket/temp" + ``` + + * Replace the placeholders with your actual paths, project ID, region, and Cloud Storage bucket. + +## Important notes + +* **Dependencies:** Ensure your `pom.xml` includes the Dataflow runner dependency (`beam-runners-google-cloud-dataflow-java`). +* **Authentication:** Authenticate your environment to your Google Cloud project. +* **DirectRunner:** On `DirectRunner`, you will see the error ` Process succeded but no result file was found`, which indicates that the subprocess itself ran successfully. + +## Code overview + +* **`ExampleEchoPipeline.java`:** This Java file defines the Beam pipeline that executes the `Echo` and `EchoAgain` binaries as subprocesses. +* **`Echo.cc` and `EchoAgain.cc`:** These C++ files contain the code for the external binaries.
These won't be visible when running the example with the created example project. You will need to compile these (using `g++ Echo.cc -o Echo` and `g++ EchoAgain.cc -o EchoAgain`), and then provide their path via the `sourcePath` and `workerPath` flags as listed above. +* **`SubProcessKernel.java`:** This class in the Beam Java SDK handles the execution of external binaries and captures their output. diff --git a/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java b/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java index 6b9916e87271..43b06347a39a 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java @@ -25,7 +25,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.MapElements; @@ -33,7 +32,6 @@ import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -66,7 +64,6 @@ public void testExtractWordsFn() throws Exception { /** Example test that tests a PTransform by using an in-memory input and inspecting the output. 
*/ @Test - @Category(ValidatesRunner.class) public void testCountWords() throws Exception { PCollection input = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java index 96d10a1f72ed..614d289d2d60 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java @@ -21,12 +21,10 @@ import java.util.Arrays; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,7 +35,6 @@ public class TopWikipediaSessionsTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testComputeTopUsers() { PCollection output = diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java index 33d3c5699477..9c99c3aafdcc 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java @@ -24,13 +24,11 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.KV; import 
org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -69,7 +67,6 @@ public class GameStatsTest implements Serializable { /** Test the calculation of 'spammy users'. */ @Test - @Category(ValidatesRunner.class) public void testCalculateSpammyUsers() throws Exception { PCollection> input = p.apply(Create.of(USER_SCORES)); PCollection> output = input.apply(new CalculateSpammyUsers()); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java index 1d89351adcf8..46d7b41746ab 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java @@ -26,7 +26,6 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.MapElements; @@ -37,7 +36,6 @@ import org.joda.time.Instant; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -90,7 +88,6 @@ public class HourlyTeamScoreTest implements Serializable { /** Test the filtering. 
*/ @Test - @Category(ValidatesRunner.class) public void testUserScoresFilter() throws Exception { final Instant startMinTimestamp = new Instant(1447965680000L); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java index 04aa122054bd..22fe98a50304 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java @@ -26,7 +26,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; @@ -36,7 +35,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -114,7 +112,6 @@ public void testParseEventFn() throws Exception { /** Tests ExtractAndSumScore("user"). */ @Test - @Category(ValidatesRunner.class) public void testUserScoreSums() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS)); @@ -133,7 +130,6 @@ public void testUserScoreSums() throws Exception { /** Tests ExtractAndSumScore("team"). */ @Test - @Category(ValidatesRunner.class) public void testTeamScoreSums() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS)); @@ -152,7 +148,6 @@ public void testTeamScoreSums() throws Exception { /** Test that bad input data is dropped appropriately. 
*/ @Test - @Category(ValidatesRunner.class) public void testUserScoresBadInput() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS2).withCoder(StringUtf8Coder.of())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java index 2bd37a3caa52..110349c353e8 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java @@ -22,7 +22,6 @@ import org.apache.beam.examples.cookbook.BigQueryTornadoes.FormatCountsFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,7 +39,6 @@ public class BigQueryTornadoesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTornadoes() { TableRow row = new TableRow().set("month", "6").set("tornado", true); PCollection input = p.apply(Create.of(ImmutableList.of(row))); @@ -51,7 +48,6 @@ public void testExtractTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testNoTornadoes() { TableRow row = new TableRow().set("month", 6).set("tornado", false); PCollection inputs = p.apply(Create.of(ImmutableList.of(row))); @@ -61,7 +57,6 @@ public void testNoTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testEmpty() { PCollection> inputs = p.apply(Create.empty(new 
TypeDescriptor>() {})); @@ -71,7 +66,6 @@ public void testEmpty() { } @Test - @Category(ValidatesRunner.class) public void testFormatCounts() { PCollection> inputs = p.apply(Create.of(KV.of(3, 0L), KV.of(4, Long.MAX_VALUE), KV.of(5, Long.MIN_VALUE))); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java index 988a492ad4a9..7ec889f0d2b4 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java @@ -22,13 +22,11 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Distinct; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -39,7 +37,6 @@ public class DistinctExampleTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testDistinct() { List strings = Arrays.asList("k1", "k5", "k5", "k2", "k1", "k2", "k3"); @@ -52,7 +49,6 @@ public void testDistinct() { } @Test - @Category(ValidatesRunner.class) public void testDistinctEmpty() { List strings = Arrays.asList(); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java index dedc0e313350..4eeb5c4b7dd0 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java +++ 
b/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java @@ -22,13 +22,11 @@ import org.apache.beam.examples.cookbook.FilterExamples.ProjectionFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -68,7 +66,6 @@ public class FilterExamplesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testProjectionFn() { PCollection input = p.apply(Create.of(row1, row2, row3)); @@ -79,7 +76,6 @@ public void testProjectionFn() { } @Test - @Category(ValidatesRunner.class) public void testFilterSingleMonthDataFn() { PCollection input = p.apply(Create.of(outRow1, outRow2, outRow3)); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java index b91ca985ddcd..d27572667752 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java @@ -24,14 +24,12 @@ import org.apache.beam.examples.cookbook.JoinExamples.ExtractEventDataFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import 
org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -107,7 +105,6 @@ public void testExtractCountryInfoFn() throws Exception { } @Test - @Category(ValidatesRunner.class) public void testJoin() throws java.lang.Exception { PCollection input1 = p.apply("CreateEvent", Create.of(EVENT_ARRAY)); PCollection input2 = p.apply("CreateCC", Create.of(CC_ARRAY)); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java index 5c32f36660d6..410b151ed32f 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java @@ -23,7 +23,6 @@ import org.apache.beam.examples.cookbook.MaxPerKeyExamples.FormatMaxesFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -78,7 +76,6 @@ public class MaxPerKeyExamplesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTempFn() { PCollection> results = p.apply(Create.of(TEST_ROWS)).apply(ParDo.of(new ExtractTempFn())); @@ -87,7 +84,6 @@ public void testExtractTempFn() { } @Test - @Category(ValidatesRunner.class) public void testFormatMaxesFn() { PCollection results = p.apply(Create.of(TEST_KVS)).apply(ParDo.of(new FormatMaxesFn())); diff --git 
a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java index 7e922bc87965..fb08730a9f54 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java @@ -22,7 +22,6 @@ import org.apache.beam.examples.cookbook.MinimalBigQueryTornadoes.FormatCountsFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,7 +39,6 @@ public class MinimalBigQueryTornadoesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTornadoes() { TableRow row = new TableRow().set("month", "6").set("tornado", true); PCollection input = p.apply(Create.of(ImmutableList.of(row))); @@ -51,7 +48,6 @@ public void testExtractTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testNoTornadoes() { TableRow row = new TableRow().set("month", 6).set("tornado", false); PCollection inputs = p.apply(Create.of(ImmutableList.of(row))); @@ -61,7 +57,6 @@ public void testNoTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testEmpty() { PCollection> inputs = p.apply(Create.empty(new TypeDescriptor>() {})); @@ -71,7 +66,6 @@ public void testEmpty() { } @Test - @Category(ValidatesRunner.class) public void testFormatCounts() { PCollection> 
inputs = p.apply(Create.of(KV.of(3, 0L), KV.of(4, Long.MAX_VALUE), KV.of(5, Long.MIN_VALUE))); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java index 8f076a9d8d89..19c83c6eb73c 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java @@ -27,7 +27,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; @@ -42,7 +41,6 @@ import org.joda.time.Instant; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -118,7 +116,6 @@ public void testExtractTotalFlow() { } @Test - @Category(ValidatesRunner.class) public void testTotalFlow() { PCollection> flow = pipeline diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt index 56702a3a1746..514727878a44 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt @@ -39,7 +39,6 @@ class DistinctExampleTest { fun pipeline(): TestPipeline = pipeline @Test - @Category(ValidatesRunner::class) fun testDistinct() { val strings = listOf("k1", "k5", "k5", "k2", "k1", "k2", "k3") val input = pipeline.apply(Create.of(strings).withCoder(StringUtf8Coder.of())) @@ -49,7 +48,6 @@ class 
DistinctExampleTest { } @Test - @Category(ValidatesRunner::class) fun testDistinctEmpty() { val strings = listOf() val input = pipeline.apply(Create.of(strings).withCoder(StringUtf8Coder.of())) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt index d5cb544a7606..8cfffe15fc65 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt @@ -64,7 +64,6 @@ class FilterExamplesTest { fun pipeline(): TestPipeline = pipeline @Test - @Category(ValidatesRunner::class) fun testProjectionFn() { val input = pipeline.apply(Create.of(row1, row2, row3)) val results = input.apply(ParDo.of(ProjectionFn())) @@ -73,7 +72,6 @@ class FilterExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testFilterSingleMonthDataFn() { val input = pipeline.apply(Create.of(outRow1, outRow2, outRow3)) val results = input.apply(ParDo.of(FilterSingleMonthDataFn(7))) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt index 8728a827229b..6bb818f5efa9 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt @@ -93,7 +93,6 @@ class JoinExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testJoin() { val input1 = pipeline.apply("CreateEvent", Create.of(EVENT_ARRAY)) val input2 = pipeline.apply("CreateCC", Create.of(CC_ARRAY)) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt 
b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt index 7995d9c1c795..409434e02686 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt @@ -69,7 +69,6 @@ class MaxPerKeyExamplesTest { fun pipeline(): TestPipeline = pipeline @Test - @Category(ValidatesRunner::class) fun testExtractTempFn() { val results = pipeline.apply(Create.of(testRows)).apply(ParDo.of>(MaxPerKeyExamples.ExtractTempFn())) PAssert.that(results).containsInAnyOrder(ImmutableList.of(kv1, kv2, kv3)) @@ -77,7 +76,6 @@ class MaxPerKeyExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testFormatMaxesFn() { val results = pipeline.apply(Create.of(testKvs)).apply(ParDo.of, TableRow>(MaxPerKeyExamples.FormatMaxesFn())) PAssert.that(results).containsInAnyOrder(resultRow1, resultRow2, resultRow3) diff --git a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb index 5b5d2ed484c0..2f80846f313b 100644 --- a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb +++ b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb @@ -1,22 +1,21 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "OsFaZscKSPvo" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -36,22 +35,13 @@ "# KIND, either express or implied. 
See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "OsFaZscKSPvo" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "ZUSiAR62SgO8" + }, "source": [ "# Update ML models in running pipelines\n", "\n", @@ -63,20 +53,13 @@ " View source on GitHub\n", " \n", "\n" - ], - "metadata": { - "id": "ZUSiAR62SgO8" - }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "tBtqF5UpKJNZ" + }, "source": [ "This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline.\n", "You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a `ModelHandler` configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the `WatchFilePattern`, or by configuring a custom side input `PCollection` that defines the logic for the model update.\n", @@ -85,36 +68,19 @@ "For more information about side inputs, see the [Side inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs) section in the Apache Beam Programming Guide.\n", "\n", "This example uses `WatchFilePattern` as a side input. `WatchFilePattern` is used to watch for file updates that match the `file_pattern` based on timestamps. 
It emits the latest `ModelMetadata`, which is used in the RunInference `PTransform` to automatically update the ML model without stopping the Apache Beam pipeline.\n" - ], - "metadata": { - "id": "tBtqF5UpKJNZ" - }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "SPuXFowiTpWx" + }, "source": [ "## Before you begin\n", "Install the dependencies required to run this notebook.\n", "\n", "To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later." - ], - "metadata": { - "id": "SPuXFowiTpWx" - }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", @@ -122,25 +88,39 @@ "metadata": { "id": "1RyTYsFEIOlA" }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ - "!pip install apache_beam[gcp]>=2.46.0 --quiet\n", - "!pip install tensorflow==2.15.0 --quiet\n", - "!pip install tensorflow_hub --quiet" + "!pip install apache_beam[gcp]>=2.46.0 tensorflow==2.15.0 tensorflow_hub==0.16.1 keras==2.15.0 Pillow==11.0.0 --quiet" ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Rs4cwwNrIV9H" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# Imports required for the notebook.\n", "import logging\n", "import time\n", + "import os\n", "from typing import Iterable\n", "from typing import Tuple\n", "\n", @@ -158,21 +138,23 @@ "import numpy\n", "from PIL import Image\n", "import tensorflow as tf" - ], - "metadata": { - "id": "Rs4cwwNrIV9H" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + 
"id": "jAKpPcmmGm03" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# Authenticate to your Google Cloud account.\n", "def auth_to_colab():\n", @@ -180,21 +162,13 @@ " auth.authenticate_user()\n", "\n", "auth_to_colab()" - ], - "metadata": { - "id": "jAKpPcmmGm03" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "ORYNKhH3WQyP" + }, "source": [ "## Configure the runner\n", "\n", @@ -204,24 +178,37 @@ "* Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.\n", "\n", "In the following code, replace `BUCKET_NAME` with the the name of your Cloud Storage bucket." - ], - "metadata": { - "id": "ORYNKhH3WQyP" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wWjbnq6X-4uE" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "options = PipelineOptions()\n", "options.view_as(StandardOptions).streaming = True\n", "\n", - "BUCKET_NAME = '' # Replace with your bucket name.\n", + "# Replace with your bucket name.\n", + "BUCKET_NAME = '' # @param {type:'string'} \n", + "os.environ['BUCKET_NAME'] = BUCKET_NAME\n", "\n", "# Provide required pipeline options for the Dataflow Runner.\n", "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n", "\n", "# Set the project to the default project in your current Google Cloud environment.\n", - "options.view_as(GoogleCloudOptions).project = ''\n", + "PROJECT_NAME = '' # @param {type:'string'}\n", + "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n", "\n", "# Set the Google Cloud region that you want to run Dataflow in.\n", "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", @@ -246,113 +233,120 @@ "# To expedite the model update process, it's recommended 
to set num_workers>1.\n", "# https://github.com/apache/beam/issues/28776\n", "options.view_as(WorkerOptions).num_workers = 5" - ], - "metadata": { - "id": "wWjbnq6X-4uE" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", - "source": [ - "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies." - ], "metadata": { "id": "HTJV8pO2Wcw4" - } + }, + "source": [ + "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lEy4PkluWbdm" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# In a requirements file, define the dependencies required for the pipeline.\n", - "!printf 'tensorflow==2.15.0\\ntensorflow_hub>=0.10.0\\nPillow>=9.0.0' > ./requirements.txt\n", + "!printf 'tensorflow==2.15.0\\ntensorflow_hub==0.16.1\\nkeras==2.15.0\\nPillow==11.0.0' > ./requirements.txt\n", "# Install the pipeline dependencies on Dataflow.\n", "options.view_as(SetupOptions).requirements_file = './requirements.txt'" - ], - "metadata": { - "id": "lEy4PkluWbdm" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "_AUNH_GJk_NE" + }, "source": [ "## Use the TensorFlow model handler\n", " This example uses `TFModelHandlerTensor` as the model handler and the `resnet_101` model trained on [ImageNet](https://www.image-net.org/).\n", "\n", "\n", "For the Dataflow runner, you need to store the model in a remote location that the Apache Beam pipeline can access. 
For this example, download the `ResNet101` model, and upload it to the Google Cloud Storage bucket.\n" - ], - "metadata": { - "id": "_AUNH_GJk_NE" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ibkWiwVNvyrn" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "model = tf.keras.applications.resnet.ResNet101()\n", "model.save('resnet101_weights_tf_dim_ordering_tf_kernels.keras')\n", "# After saving the model locally, upload the model to GCS bucket and provide that gcs bucket `URI` as `model_uri` to the `TFModelHandler`\n", - "# Replace `BUCKET_NAME` value with actual bucket name.\n", - "!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs:///dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras" - ], - "metadata": { - "id": "ibkWiwVNvyrn" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras" + ] }, { "cell_type": "code", - "source": [ - "model_handler = TFModelHandlerTensor(\n", - " model_uri=dataflow_gcs_location + \"/resnet101_weights_tf_dim_ordering_tf_kernels.keras\")" - ], + "execution_count": null, "metadata": { "id": "kkSnsxwUk-Sp" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model_handler = TFModelHandlerTensor(\n", + " model_uri=dataflow_gcs_location + \"/resnet101_weights_tf_dim_ordering_tf_kernels.keras\")" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "tZH0r0sL-if5" + }, "source": [ "## Preprocess images\n", "\n", "Use `preprocess_image` to run the inference, read the image, and convert the image to a TensorFlow tensor." 
- ], - "metadata": { - "id": "tZH0r0sL-if5" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dU5imgTt-8Ne" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "def preprocess_image(image_name, image_dir):\n", " img = tf.keras.utils.get_file(image_name, image_dir + image_name)\n", @@ -360,21 +354,23 @@ " img = numpy.array(img) / 255.0\n", " img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)\n", " return img_tensor" - ], - "metadata": { - "id": "dU5imgTt-8Ne" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6V5tJxO6-gyt" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "class PostProcessor(beam.DoFn):\n", " \"\"\"Process the PredictionResult to get the predicted label.\n", @@ -389,62 +385,66 @@ " imagenet_labels = numpy.array(open(labels_path).read().splitlines())\n", " predicted_class_name = imagenet_labels[predicted_class]\n", " yield predicted_class_name.title(), element.model_id" - ], - "metadata": { - "id": "6V5tJxO6-gyt" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", - "source": [ - "# Define the pipeline object.\n", - "pipeline = beam.Pipeline(options=options)" - ], + "execution_count": null, "metadata": { "id": "GpdKk72O_NXT" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Define the pipeline object.\n", + "pipeline = beam.Pipeline(options=options)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "elZ53uxc_9Hv" + }, 
"source": [ "Next, review the pipeline steps and examine the code.\n", "\n", "### Pipeline steps\n" - ], - "metadata": { - "id": "elZ53uxc_9Hv" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "305tkV2sAD-S" + }, "source": [ "1. Create a `PeriodicImpulse` transform, which emits output every `n` seconds. The `PeriodicImpulse` transform generates an infinite sequence of elements with a given runtime interval.\n", "\n", " In this example, `PeriodicImpulse` mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use `PeriodicImpulse` to output elements at `m` intervals.\n", "To learn more about `PeriodicImpulse`, see the [`PeriodicImpulse` code](https://github.com/apache/beam/blob/9c52e0594d6f0e59cd17ee005acfb41da508e0d5/sdks/python/apache_beam/transforms/periodicsequence.py#L150)." - ], - "metadata": { - "id": "305tkV2sAD-S" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vUFStz66_Tbb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "start_timestamp = time.time() # start timestamp of the periodic impulse\n", "end_timestamp = start_timestamp + 60 * 20 # end timestamp of the periodic impulse (will run for 20 minutes).\n", @@ -457,72 +457,76 @@ " start_timestamp=start_timestamp,\n", " stop_timestamp=end_timestamp,\n", " fire_interval=main_input_fire_interval))" - ], - "metadata": { - "id": "vUFStz66_Tbb" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "8-sal2rFAxP2" + }, "source": [ "2. To read and preprocess the images, use the `preprocess_image` function. This example uses `Cat-with-beanie.jpg` for all inferences.\n", "\n", " **Note**: The image used for prediction is licensed in CC-BY. 
The creator is listed in the [LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt) file." - ], - "metadata": { - "id": "8-sal2rFAxP2" - } + ] }, { "cell_type": "markdown", - "source": [ - "![download.png]()" - ], "metadata": { "id": "gW4cE8bhXS-d" - } + }, + "source": [ + "![download.png]()" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dGg11TpV_aV6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "image_data = (periodic_impulse | beam.Map(lambda x: \"Cat-with-beanie.jpg\")\n", " | \"ReadImage\" >> beam.Map(lambda image_name: preprocess_image(\n", " image_name=image_name, image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))" - ], - "metadata": { - "id": "dGg11TpV_aV6" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "eB0-ewd-BCKE" + }, "source": [ "3. Pass the images to the RunInference `PTransform`. RunInference takes `model_handler` and `model_metadata_pcoll` as input parameters.\n", " * `model_metadata_pcoll` is a side input `PCollection` to the RunInference `PTransform`. This side input updates the `model_uri` in the `model_handler` while the Apache Beam pipeline runs.\n", " * Use `WatchFilePattern` as side input to watch a `file_pattern` matching `.keras` files. 
In this case, the `file_pattern` is `'gs://BUCKET_NAME/dataflow/*keras'`.\n", "\n" - ], - "metadata": { - "id": "eB0-ewd-BCKE" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_AjvvexJ_hUq" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ " # The side input used to watch for the .keras file and update the model_uri of the TFModelHandlerTensor.\n", "file_pattern = dataflow_gcs_location + '/*.keras'\n", @@ -536,108 +540,117 @@ " | \"ApplyWindowing\" >> beam.WindowInto(beam.window.FixedWindows(10))\n", " | \"RunInference\" >> RunInference(model_handler=model_handler,\n", " model_metadata_pcoll=side_input_pcoll))" - ], - "metadata": { - "id": "_AjvvexJ_hUq" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "lTA4wRWNDVis" + }, "source": [ "4. Post-process the `PredictionResult` object.\n", "When the inference is complete, RunInference outputs a `PredictionResult` object that contains the fields `example`, `inference`, and `model_id`. The `model_id` field identifies the model used to run the inference. The `PostProcessor` returns the predicted label and the model ID used to run the inference on the predicted label." 
- ], - "metadata": { - "id": "lTA4wRWNDVis" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9TB76fo-_vZJ" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "post_processor = (\n", " inferences\n", " | \"PostProcessResults\" >> beam.ParDo(PostProcessor())\n", " | \"LogResults\" >> beam.Map(logging.info))" - ], - "metadata": { - "id": "9TB76fo-_vZJ" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "wYp-mBHHjOjA" + }, "source": [ "### Watch for the model update\n", "\n", "After the pipeline starts processing data, when you see output emitted from the RunInference `PTransform`, upload a `resnet152` model saved in the `.keras` format to a Google Cloud Storage bucket location that matches the `file_pattern` you defined earlier.\n" - ], - "metadata": { - "id": "wYp-mBHHjOjA" - } + ] }, { "cell_type": "code", - "source": [ - "model = tf.keras.applications.resnet.ResNet152()\n", - "model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')\n", - "# Replace the `BUCKET_NAME` with the actual bucket name.\n", - "!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs:///resnet152_weights_tf_dim_ordering_tf_kernels.keras" - ], + "execution_count": null, "metadata": { "id": "FpUfNBSWH9Xy" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model = tf.keras.applications.resnet.ResNet152()\n", + "model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')\n", + "!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/resnet152_weights_tf_dim_ordering_tf_kernels.keras" + ] }, { "cell_type": "markdown", + "metadata": { + "id": 
"_ty03jDnKdKR" + }, "source": [ "## Run the pipeline\n", "\n", "Use the following code to run the pipeline." - ], - "metadata": { - "id": "_ty03jDnKdKR" - } + ] }, { "cell_type": "code", - "source": [ - "# Run the pipeline.\n", - "result = pipeline.run().wait_until_finish()" - ], + "execution_count": null, "metadata": { "id": "wd0VJLeLEWBU" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Run the pipeline.\n", + "result = pipeline.run().wait_until_finish()" + ] } - ] + ], + "metadata": { + "colab": { + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb new file mode 100644 index 000000000000..dedaa6b65a5e --- /dev/null +++ b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb @@ -0,0 +1,771 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "55h6JBJeJGqg" + }, + "outputs": [], + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. 
You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YrOuxMeKJZxC" + }, + "source": [ + "# Use Apache Beam and BigQuery to enrich data\n", + "\n", + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pf2bL-PmJScZ" + }, + "source": [ + "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [BigQuery](https://cloud.google.com/bigquery/docs/overview). The enrichment transform is an Apache Beam turnkey transform that lets you enrich data by using a key-value lookup. This transform has the following features:\n", + "\n", + "- The transform has a built-in Apache Beam handler that interacts with BigQuery data during enrichment.\n", + "- The enrichment transform uses client-side throttling to rate limit the requests. The default retry strategy uses exponential backoff. You can configure rate limiting to suit your use case.\n", + "\n", + "This notebook demonstrates the following telecommunications company use case:\n", + "\n", + "A telecom company wants to predict which customers are likely to cancel their subscriptions so that the company can proactively offer these customers incentives to stay. The example uses customer demographic data and usage data stored in BigQuery to enrich a stream of customer IDs. The enriched data is then used to predict the likelihood of customer churn.\n", + "\n", + "## Before you begin\n", + "Set up your environment and download dependencies.\n", + "\n", + "### Install Apache Beam\n", + "To use the enrichment transform with the built-in BigQuery handler, install the Apache Beam SDK version 2.57.0 or later." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oVbWf73FJSzf" + }, + "outputs": [], + "source": [ + "!pip install torch\n", + "!pip install apache_beam[interactive,gcp]==2.57.0 --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "siSUsfR5tKX9" + }, + "source": [ + "Import the following modules:\n", + "- Pub/Sub for streaming data\n", + "- BigQuery for enrichment\n", + "- Apache Beam for running the streaming pipeline\n", + "- PyTorch to predict customer churn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "p6bruDqFJkXE" + }, + "outputs": [], + "source": [ + "import datetime\n", + "import json\n", + "import math\n", + "\n", + "from typing import Any\n", + "from typing import Dict\n", + "\n", + "import torch\n", + "from google.cloud import pubsub_v1\n", + "from google.cloud import bigquery\n", + "from google.api_core.exceptions import Conflict\n", + "\n", + "import apache_beam as beam\n", + "import apache_beam.runners.interactive.interactive_beam as ib\n", + "from apache_beam.ml.inference.base import KeyedModelHandler\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n", + "from apache_beam.options import pipeline_options\n", + "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n", + "from apache_beam.transforms.enrichment import Enrichment\n", + "from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler\n", + "\n", + "import pandas as pd\n", + "\n", + "from sklearn.preprocessing import LabelEncoder" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t0QfhuUlJozO" + }, + "source": [ + "### Authenticate with Google Cloud\n", + "This notebook reads data from Pub/Sub and BigQuery. 
To use your Google Cloud account, authenticate this notebook.\n", + "To prepare for this step, replace `` with your Google Cloud project ID." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RwoBZjD1JwnD" + }, + "outputs": [], + "source": [ + "PROJECT_ID = \"\" # @param {type:'string'}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rVAyQxoeKflB" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user(project_id=PROJECT_ID)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1vDwknoHKoa-" + }, + "source": [ + "### Set up the BigQuery tables\n", + "\n", + "Create sample BigQuery tables for this notebook.\n", + "\n", + "- Replace `` with the name of your BigQuery dataset. Only letters (uppercase or lowercase), numbers, and underscores are allowed.\n", + "- If the dataset does not exist, a new dataset with this ID is created." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UxeGFqSJu-G6" + }, + "outputs": [], + "source": [ + "DATASET_ID = \"\" # @param {type:'string'}\n", + "\n", + "CUSTOMERS_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.customers'\n", + "USAGE_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.usage'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Gw4RfZavyfpo" + }, + "source": [ + "Create customer and usage tables, and insert fake data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-QRZC4v0KipK" + }, + "outputs": [], + "source": [ + "client = bigquery.Client(project=PROJECT_ID)\n", + "\n", + "# Create dataset if it does not exist.\n", + "client.create_dataset(bigquery.Dataset(f\"{PROJECT_ID}.{DATASET_ID}\"), exists_ok=True)\n", + "print(f\"Created dataset {DATASET_ID}\")\n", + "\n", + "# Prepare the fake customer data.\n", + "customer_data = {\n", + " 'customer_id': [1, 2, 3, 4, 5],\n", + " 'age': [35, 28, 45, 62, 22],\n", + " 'plan': ['Gold', 'Silver', 'Bronze', 'Gold', 'Silver'],\n", + " 'contract_length': [12, 24, 6, 36, 12]\n", + "}\n", + "\n", + "customers_df = pd.DataFrame(customer_data)\n", + "\n", + "# Insert customer data.\n", + "job_config = bigquery.LoadJobConfig(\n", + " schema=[\n", + " bigquery.SchemaField(\"customer_id\", \"INTEGER\"),\n", + " bigquery.SchemaField(\"age\", \"INTEGER\"),\n", + " bigquery.SchemaField(\"plan\", \"STRING\"),\n", + " bigquery.SchemaField(\"contract_length\", \"INTEGER\"),\n", + " ],\n", + " write_disposition=\"WRITE_TRUNCATE\",\n", + ")\n", + "\n", + "job = client.load_table_from_dataframe(\n", + " customers_df, CUSTOMERS_TABLE_ID, job_config=job_config\n", + ")\n", + "job.result() # Wait for the job to complete.\n", + "print(f\"Customers table created and populated: {CUSTOMERS_TABLE_ID}\")\n", + "\n", + "# Prepare the fake usage data.\n", + "usage_data = {\n", + " 'customer_id': [1, 1, 2, 2, 3, 3, 4, 4, 5, 5],\n", + " 'date': pd.to_datetime(['2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01']),\n", + " 'calls_made': [50, 65, 20, 18, 100, 110, 30, 28, 60, 70],\n", + " 'data_usage_gb': [10, 12, 5, 4, 20, 22, 8, 7, 15, 18]\n", + "}\n", + "usage_df = pd.DataFrame(usage_data)\n", + "\n", + "# Insert usage data.\n", + "job_config = bigquery.LoadJobConfig(\n", + " schema=[\n", + " bigquery.SchemaField(\"customer_id\", 
\"INTEGER\"),\n", + " bigquery.SchemaField(\"date\", \"DATE\"),\n", + " bigquery.SchemaField(\"calls_made\", \"INTEGER\"),\n", + " bigquery.SchemaField(\"data_usage_gb\", \"FLOAT\"),\n", + " ],\n", + " write_disposition=\"WRITE_TRUNCATE\",\n", + ")\n", + "job = client.load_table_from_dataframe(\n", + " usage_df, USAGE_TABLE_ID, job_config=job_config\n", + ")\n", + "job.result() # Wait for the job to complete.\n", + "\n", + "print(f\"Usage table created and populated: {USAGE_TABLE_ID}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PZCjCzxaLOJt" + }, + "source": [ + "### Train the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R4dIHclDLfIj" + }, + "source": [ + "Create sample data and train a simple model for churn prediction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YoMjdqJ1KxOM" + }, + "outputs": [], + "source": [ + "# Create fake training data\n", + "data = {\n", + " 'customer_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],\n", + " 'age': [35, 28, 45, 62, 22, 38, 55, 25, 40, 30],\n", + " 'plan': ['Gold', 'Silver', 'Bronze', 'Gold', 'Silver', 'Bronze', 'Gold', 'Silver', 'Bronze', 'Silver'],\n", + " 'contract_length': [12, 24, 6, 36, 12, 18, 30, 12, 24, 18],\n", + " 'avg_monthly_calls': [57.5, 19, 100, 30, 60, 45, 25, 70, 50, 35],\n", + " 'avg_monthly_data_usage_gb': [11, 4.5, 20, 8, 15, 10, 7, 18, 12, 8],\n", + " 'churned': [0, 0, 1, 0, 1, 0, 0, 1, 0, 1] # Target variable\n", + "}\n", + "plan_encoder = LabelEncoder()\n", + "plan_encoder.fit(data['plan'])\n", + "df = pd.DataFrame(data)\n", + "df['plan'] = plan_encoder.transform(data['plan'])\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EgIFJx76MF3v" + }, + "source": [ + "Preprocess the data:\n", + "\n", + "1. Convert the lists to tensors.\n", + "2. Separate the features from the expected prediction." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P-8lKzdzLnGo" + }, + "outputs": [], + "source": [ + "features = ['age', 'plan', 'contract_length', 'avg_monthly_calls', 'avg_monthly_data_usage_gb']\n", + "target = 'churned'\n", + "\n", + "X = torch.tensor(df[features].values, dtype=torch.float)\n", + "Y = torch.tensor(df[target], dtype=torch.float)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4mcNOez1MQZP" + }, + "source": [ + "Define a model that has five input features and predicts a single value." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YvdPNlzoMTtl" + }, + "outputs": [], + "source": [ + "def build_model(n_inputs, n_outputs):\n", + " \"\"\"build_model builds and returns a model that takes\n", + " `n_inputs` features and predicts `n_outputs` value\"\"\"\n", + " return torch.nn.Sequential(\n", + " torch.nn.Linear(n_inputs, 8),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(8, 16),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(16, n_outputs),\n", + " torch.nn.Sigmoid())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GaLBmcvrMOWy" + }, + "source": [ + "Train the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0XqctMiPMaim" + }, + "outputs": [], + "source": [ + "model = build_model(n_inputs=5, n_outputs=1)\n", + "\n", + "loss_fn = torch.nn.BCELoss()\n", + "optimizer = torch.optim.Adam(model.parameters())\n", + "\n", + "for epoch in range(1000):\n", + " print(f'Epoch {epoch}: ---')\n", + " optimizer.zero_grad()\n", + " for i in range(len(X)):\n", + " pred = model(X[i])\n", + " loss = loss_fn(pred, Y[i].unsqueeze(0))\n", + " loss.backward()\n", + " optimizer.step()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m7MD6RwGMdyU" + }, + "source": [ + "Save the model to the `STATE_DICT_PATH` variable." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q9WIjw53MgcR" + }, + "outputs": [], + "source": [ + "STATE_DICT_PATH = './model.pth'\n", + "torch.save(model.state_dict(), STATE_DICT_PATH)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CJVYA0N0MnZS" + }, + "source": [ + "### Publish messages to Pub/Sub\n", + "Create the Pub/Sub topic and subscription to use for data streaming." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0uwZz_ijyzL8" + }, + "outputs": [], + "source": [ + "# Replace with the name of your Pub/Sub topic.\n", + "TOPIC = \"\" # @param {type:'string'}\n", + "\n", + "# Replace with the subscription for your topic.\n", + "SUBSCRIPTION = \"\" # @param {type:'string'}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hIgsCWIozdDu" + }, + "outputs": [], + "source": [ + "from google.api_core.exceptions import AlreadyExists\n", + "\n", + "publisher = pubsub_v1.PublisherClient()\n", + "topic_path = publisher.topic_path(PROJECT_ID, TOPIC)\n", + "try:\n", + " topic = publisher.create_topic(request={\"name\": topic_path})\n", + " print(f\"Created topic: {topic.name}\")\n", + "except AlreadyExists:\n", + " print(f\"Topic {topic_path} already exists.\")\n", + "\n", + "subscriber = pubsub_v1.SubscriberClient()\n", + "subscription_path = subscriber.subscription_path(PROJECT_ID, SUBSCRIPTION)\n", + "try:\n", + " subscription = subscriber.create_subscription(\n", + " request={\"name\": subscription_path, \"topic\": topic_path}\n", + " )\n", + " print(f\"Created subscription: {subscription.name}\")\n", + "except AlreadyExists:\n", + " print(f\"Subscription {subscription_path} already exists.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VqUaFm_yywjU" + }, + "source": [ + "\n", + "Use the Pub/Sub Python client to publish messages." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fOq1uNXvMku-" + }, + "outputs": [], + "source": [ + "messages = [\n", + " {'customer_id': i}\n", + " for i in range(1, 6)\n", + "]\n", + "\n", + "for message in messages:\n", + " data = json.dumps(message).encode('utf-8')\n", + " publish_future = publisher.publish(topic_path, data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "giXOGruKM8ZL" + }, + "source": [ + "## Use the BigQuery enrichment handler\n", + "\n", + "The [`BigQueryEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigquery.html#apache_beam.transforms.enrichment_handlers.bigquery.BigQueryEnrichmentHandler) is a built-in handler included in the Apache Beam SDK versions 2.57.0 and later.\n", + "\n", + "Configure the `BigQueryEnrichmentHandler` handler with the following parameters.\n", + "\n", + "### Required parameters\n", + "\n", + "The following parameters are required.\n", + "\n", + "* `project` (str): The Google Cloud project ID for the BigQuery table\n", + "\n", + "You must also provide one of the following combinations:\n", + "* `table_name`, `row_restriction_template`, and `fields`\n", + "* `table_name`, `row_restriction_template`, and `condition_value_fn`\n", + "* `query_fn`\n", + "\n", + "### Optional parameters\n", + "\n", + "The following parameters are optional.\n", + "\n", + "* `table_name` (str): The fully qualified BigQuery table name in the format `project.dataset.table`\n", + "* `row_restriction_template` (str): A template string for the `WHERE` clause in the BigQuery query with placeholders (`{}`) to dynamically filter rows based on input data\n", + "* `fields` (Optional[List[str]]): A list of field names present in the input `beam.Row`. 
These field names are used to construct the `WHERE` clause if `condition_value_fn` is not provided.\n", + "* `column_names` (Optional[List[str]]): The names of columns to select from the BigQuery table. If not provided, all columns (`*`) are selected.\n", + "* `condition_value_fn` (Optional[Callable[[beam.Row], List[Any]]]): A function that takes a `beam.Row` and returns a list of values to populate in the placeholder `{}` of the `WHERE` clause in the query\n", + "* `query_fn` (Optional[Callable[[beam.Row], str]]): A function that takes a `beam.Row` and returns a complete BigQuery SQL query string\n", + "* `min_batch_size` (int): The minimum number of rows to batch together when querying BigQuery. Defaults to `1` if `query_fn` is not specified.\n", + "* `max_batch_size` (int): The maximum number of rows to batch together. Defaults to `10,000` if `query_fn` is not specified.\n", + "\n", + "### Parameter requirements\n", + "\n", + "When you use parameters, consider the following requirements.\n", + "\n", + "* You can't define the `min_batch_size` and `max_batch_size` parameters if you provide the `query_fn` parameter.\n", + "* You must provide either the `fields` parameter or the `condition_value_fn` parameter for query construction if you don't provide the `query_fn` parameter.\n", + "* You must grant the appropriate permissions to access BigQuery.\n", + "\n", + "### Create handlers\n", + "\n", + "In this example, you create two handlers:\n", + "\n", + "* One for customer data that specifies `table_name` and `row_restriction_template`\n", + "* One for usage data that uses a custom aggregation query by using the `query_fn` function\n", + "\n", + "These handlers are used in the Enrichment transforms in this pipeline to fetch and join data from BigQuery with the streaming data."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C8XLmBDeMyrB" + }, + "outputs": [], + "source": [ + "user_data_handler = BigQueryEnrichmentHandler(\n", + " project=PROJECT_ID,\n", + " table_name=f\"`{CUSTOMERS_TABLE_ID}`\",\n", + " row_restriction_template='customer_id = {}',\n", + " fields=['customer_id']\n", + ")\n", + "\n", + "# Define the SQL query for usage data aggregation.\n", + "usage_data_query_template = f\"\"\"\n", + "WITH monthly_aggregates AS (\n", + " SELECT\n", + " customer_id,\n", + " DATE_TRUNC(date, MONTH) as month,\n", + " SUM(calls_made) as total_calls,\n", + " SUM(data_usage_gb) as total_data_usage_gb\n", + " FROM\n", + " `{USAGE_TABLE_ID}`\n", + " WHERE\n", + " customer_id = @customer_id\n", + " GROUP BY\n", + " customer_id, month\n", + ")\n", + "SELECT\n", + " customer_id,\n", + " AVG(total_calls) as avg_monthly_calls,\n", + " AVG(total_data_usage_gb) as avg_monthly_data_usage_gb\n", + "FROM\n", + " monthly_aggregates\n", + "GROUP BY\n", + " customer_id\n", + "\"\"\"\n", + "\n", + "def usage_data_query_fn(row: beam.Row) -> str:\n", + " return usage_data_query_template.replace('@customer_id', str(row.customer_id))\n", + "\n", + "usage_data_handler = BigQueryEnrichmentHandler(\n", + " project=PROJECT_ID,\n", + " query_fn=usage_data_query_fn\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3oPYypvmPiyg" + }, + "source": [ + "In this example:\n", + "1. The `user_data_handler` handler uses the `table_name`, `row_restriction_template`, and `fields` parameter combination to fetch customer data.\n", + "2. The `usage_data_handler` handler uses the `query_fn` parameter to execute a more complex query that aggregates usage data." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ksON9uOBQbZm" + }, + "source": [ + "## Use the `PytorchModelHandlerTensor` interface to run inference\n", + "\n", + "Define functions to convert enriched data to the tensor format for the model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XgPontIVP0Cv" + }, + "outputs": [], + "source": [ + "def convert_row_to_tensor(customer_data):\n", + " import pandas as pd\n", + " customer_df = pd.DataFrame([customer_data[1].as_dict()])\n", + " customer_df['plan'] = plan_encoder.transform(customer_df['plan'])\n", + " return (customer_data[0], torch.tensor(customer_df[features].values, dtype=torch.float))\n", + "\n", + "keyed_model_handler = KeyedModelHandler(PytorchModelHandlerTensor(\n", + " state_dict_path=STATE_DICT_PATH,\n", + " model_class=build_model,\n", + " model_params={'n_inputs':5, 'n_outputs':1}\n", + ")).with_preprocess_fn(convert_row_to_tensor)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O9e7ddgGQxh2" + }, + "source": [ + "Define a `DoFn` to format the output." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NMj0V5VyQukk" + }, + "outputs": [], + "source": [ + "class PostProcessor(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " print('Customer %d churn risk: %s' % (element[0], \"High\" if element[1].inference[0].item() > 0.5 else \"Low\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-N3a1s2FQ66z" + }, + "source": [ + "## Run the pipeline\n", + "\n", + "Configure the pipeline to run in streaming mode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rgJeV-jWQ4wo" + }, + "outputs": [], + "source": [ + "options = pipeline_options.PipelineOptions()\n", + "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NRljYVR5RCMi" + }, + "source": [ + "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Bb-e3yjtQ2iU" + }, + "outputs": [], + "source": [ + "class DecodeBytes(beam.DoFn):\n", + " \"\"\"\n", + " The DecodeBytes `DoFn` converts the data read from Pub/Sub to `beam.Row`.\n", + " First, decode the encoded string. Convert the output to\n", + " a `dict` with `json.loads()`, which is used to create a `beam.Row`.\n", + " \"\"\"\n", + " def process(self, element, *args, **kwargs):\n", + " element_dict = json.loads(element.decode('utf-8'))\n", + " yield beam.Row(**element_dict)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q1HV8wH-RIbj" + }, + "source": [ + "Use the following code to run the pipeline.\n", + "\n", + "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "y6HBH8yoRFp2" + }, + "outputs": [], + "source": [ + "with beam.Pipeline(options=options) as p:\n", + " _ = (p\n", + " | \"Read from Pub/Sub\" >> beam.io.ReadFromPubSub(subscription=f\"projects/{PROJECT_ID}/subscriptions/{SUBSCRIPTION}\")\n", + " | \"ConvertToRow\" >> beam.ParDo(DecodeBytes())\n", + " | \"Enrich with customer data\" >> Enrichment(user_data_handler)\n", + " | \"Enrich with usage data\" >> Enrichment(usage_data_handler)\n", + " | \"Key data\" >> beam.Map(lambda x: (x.customer_id, x))\n", + " | \"RunInference\" >> RunInference(keyed_model_handler)\n", + " | \"Format Output\" >> beam.ParDo(PostProcessor())\n", + " )" + ] + } + ], + "metadata": { + "colab": { + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb 
b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb index 95be8b1d957c..f2e63d2e4f06 100644 --- a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb @@ -151,9 +151,9 @@ }, "outputs": [], "source": [ - "PROJECT_ID = \"\"\n", - "INSTANCE_ID = \"\"\n", - "TABLE_ID = \"\"" + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "INSTANCE_ID = \"\" # @param {type:'string'}\n", + "TABLE_ID = \"\" # @param {type:'string'}" ] }, { @@ -457,10 +457,10 @@ "outputs": [], "source": [ "# Replace with the name of your Pub/Sub topic.\n", - "TOPIC = \"\"\n", + "TOPIC = \"\" # @param {type:'string'}\n", "\n", "# Replace with the subscription for your topic.\n", - "SUBSCRIPTION = \"\"\n" + "SUBSCRIPTION = \"\" # @param {type:'string'}\n" ] }, { @@ -532,16 +532,16 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "UEpjy_IsW4P4" + }, "source": [ "The `row_key` parameter represents the field in input schema (`beam.Row`) that contains the row key for a row in the table.\n", "\n", "Starting with Apache Beam version 2.54.0, you can perform either of the following tasks when a table uses composite row keys:\n", "* Modify the input schema to contain the row key in the format required by Bigtable.\n", "* Use a custom enrichment handler. For more information, see the [example handler with composite row key support](https://www.toptal.com/developers/paste-gd/BYFGUL08#)." 
- ], - "metadata": { - "id": "UEpjy_IsW4P4" - } + ] }, { "cell_type": "code", @@ -636,6 +636,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "fe3bIclV1jZ5" + }, "source": [ "To provide a `lambda` function for using a custom join with the enrichment transform, see the following example.\n", "\n", @@ -648,13 +651,13 @@ " ...\n", " )\n", "```" - ], - "metadata": { - "id": "fe3bIclV1jZ5" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "uilxdknE3ihO" + }, "source": [ "Because the enrichment transform makes API calls to the remote service, use the `timeout` parameter to specify a timeout duration of 10 seconds:\n", "\n", @@ -667,10 +670,7 @@ " ...\n", " )\n", "```" - ], - "metadata": { - "id": "uilxdknE3ihO" - } + ] }, { "cell_type": "markdown", @@ -855,11 +855,11 @@ ], "metadata": { "colab": { - "provenance": [], - "toc_visible": true, "collapsed_sections": [ "RpqZFfFfA_Dt" - ] + ], + "provenance": [], + "toc_visible": true }, "kernelspec": { "display_name": "Python 3", diff --git a/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb b/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb index 49e2f35b13be..4d816ef97fb0 100644 --- a/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb +++ b/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb @@ -1,18 +1,4 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", @@ -44,6 +30,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "ZUSiAR62SgO8" + }, "source": [ "# Generate text embeddings by using the Vertex AI API\n", "\n", @@ -55,13 +44,13 @@ " View source on GitHub\n", " \n", "\n" - ], - "metadata": { - "id": "ZUSiAR62SgO8" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": 
"bkpSCGCWlqAf" + }, "source": [ "Text embeddings are a way to represent text as numerical vectors. This process lets computers understand and process text data, which is essential for many natural language processing (NLP) tasks.\n", "\n", @@ -84,71 +73,72 @@ "* Do one of the following tasks:\n", " * Configure credentials for your Google Cloud project. For more information, see [Google Auth Library for Python](https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth).\n", " * Store the path to a service account JSON file by using the [GOOGLE_APPLICATION_CREDENTIALS](https://cloud.google.com/docs/authentication/application-default-credentials#GAC) environment variable." - ], - "metadata": { - "id": "bkpSCGCWlqAf" - } + ] }, { "cell_type": "markdown", - "source": [ - "To use your Google Cloud account, authenticate this notebook." - ], "metadata": { "id": "W29FgO5Qv2ew" - } + }, + "source": [ + "To use your Google Cloud account, authenticate this notebook." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nYyyGYt3licq" + }, + "outputs": [], "source": [ "from google.colab import auth\n", "auth.authenticate_user()\n", "\n", - "project = '' # Replace with a valid Google Cloud project ID." - ], - "metadata": { - "id": "nYyyGYt3licq" - }, - "execution_count": null, - "outputs": [] + "# Replace with a valid Google Cloud project ID.\n", + "project = '' # @param {type:'string'}" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "UQROd16ZDN5y" + }, "source": [ "## Install dependencies\n", " Install Apache Beam and the dependencies required for the Vertex AI text-embeddings API." - ], - "metadata": { - "id": "UQROd16ZDN5y" - } + ] }, { "cell_type": "code", - "source": [ - "! pip install apache_beam[gcp]>=2.53.0 --quiet" - ], + "execution_count": null, "metadata": { "id": "BTxob7d5DLBM" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "! 
pip install apache_beam[gcp]>=2.53.0 --quiet" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SkMhR7H6n1P0" + }, + "outputs": [], "source": [ "import tempfile\n", "import apache_beam as beam\n", "from apache_beam.ml.transforms.base import MLTransform\n", "from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAITextEmbeddings" - ], - "metadata": { - "id": "SkMhR7H6n1P0" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "cokOaX2kzyke" + }, "source": [ "## Transform the data\n", "\n", @@ -156,25 +146,27 @@ "\n", "### Use MLTransform in write mode\n", "\n", - "In `write` mode, `MLTransform` saves the transforms and their attributes to an artifact location. Then, when you run `MLTransform` in `read` mode, these transforms are used. This process ensures that you're applying the same preprocessing steps when you train your model and when you serve the model in production or test its accuracy." - ], - "metadata": { - "id": "cokOaX2kzyke" - } + "In `write` mode, `MLTransform` saves the transforms and their attributes to an artifact location. Then, when you run `MLTransform` in `read` mode, these transforms are used. This process ensures that you're applying the same preprocessing steps when you train your model and when you serve the model in production or test its accuracy." + ] }, { "cell_type": "markdown", + "metadata": { + "id": "-x7fVvuy-aDs" + }, "source": [ "### Get the data\n", "\n", "`MLTransform` processes dictionaries that include column names and their associated text data. To generate embeddings for specific columns, specify these column names in the `columns` argument of `VertexAITextEmbeddings`. This transform uses the the Vertex AI text-embeddings API for online predictions to generate an embeddings vector for each sentence." 
- ], - "metadata": { - "id": "-x7fVvuy-aDs" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "be-vR159pylF" + }, + "outputs": [], "source": [ "artifact_location = tempfile.mkdtemp(prefix='vertex_ai')\n", "\n", @@ -201,32 +193,11 @@ " for key in d.keys():\n", " d[key] = d[key][:10]\n", " return d" - ], - "metadata": { - "id": "be-vR159pylF" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "embedding_transform = VertexAITextEmbeddings(\n", - " model_name=text_embedding_model_name, columns=['x'], project=project)\n", - "\n", - "with beam.Pipeline() as pipeline:\n", - " data_pcoll = (\n", - " pipeline\n", - " | \"CreateData\" >> beam.Create(content))\n", - " transformed_pcoll = (\n", - " data_pcoll\n", - " | \"MLTransform\" >> MLTransform(write_artifact_location=artifact_location).with_transform(embedding_transform))\n", - "\n", - " # Show only the first ten elements of the embeddings to prevent clutter in the output.\n", - " transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n", - "\n", - " transformed_pcoll | \"PrintEmbeddingShape\" >> beam.Map(lambda x: print(f\"Embedding shape: {len(x['x'])}\"))" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -234,11 +205,10 @@ "id": "UQGm1be3p7lM", "outputId": "b41172ca-1c73-4952-ca87-bfe45ca88a6c" }, - "execution_count": null, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "{'x': [0.041293490678071976, -0.010302993468940258, -0.048611514270305634, -0.01360565796494484, 0.06441926211118698, 0.022573700174689293, 0.016446372494101524, -0.033894773572683334, 0.004581860266625881, 0.060710687190294266]}\n", "Embedding shape: 10\n", @@ -248,23 +218,58 @@ "Embedding shape: 10\n" ] } + ], + "source": [ + "embedding_transform = VertexAITextEmbeddings(\n", + " model_name=text_embedding_model_name, columns=['x'], 
project=project)\n", + "\n", + "with beam.Pipeline() as pipeline:\n", + " data_pcoll = (\n", + " pipeline\n", + " | \"CreateData\" >> beam.Create(content))\n", + " transformed_pcoll = (\n", + " data_pcoll\n", + " | \"MLTransform\" >> MLTransform(write_artifact_location=artifact_location).with_transform(embedding_transform))\n", + "\n", + " # Show only the first ten elements of the embeddings to prevent clutter in the output.\n", + " transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n", + "\n", + " transformed_pcoll | \"PrintEmbeddingShape\" >> beam.Map(lambda x: print(f\"Embedding shape: {len(x['x'])}\"))" ] }, { "cell_type": "markdown", + "metadata": { + "id": "JLkmQkiLx_6h" + }, "source": [ "### Use MLTransform in read mode\n", "\n", "In `read` mode, `MLTransform` uses the artifacts saved during `write` mode. In this example, the transform and its attributes are loaded from the saved artifacts. You don't need to specify artifacts again during `read` mode.\n", "\n", "In this way, `MLTransform` provides consistent preprocessing steps for training and inference workloads." 
- ], - "metadata": { - "id": "JLkmQkiLx_6h" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "r8Y5vgfLx_Xu", + "outputId": "e7cbf6b7-5c31-4efa-90cf-7a8a108ecc77" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'x': [0.04782044142484665, -0.010078949853777885, -0.05793016776442528, -0.026060665026307106, 0.05756739526987076, 0.02292264811694622, 0.014818413183093071, -0.03718176111578941, -0.005486017093062401, 0.04709304869174957]}\n", + "{'x': [0.042911216616630554, -0.007554919924587011, -0.08996245265007019, -0.02607591263949871, 0.0008614308317191899, -0.023671219125390053, 0.03999944031238556, -0.02983051724731922, -0.015057179145514965, 0.022963201627135277]}\n" + ] + } + ], "source": [ "test_content = [\n", " {\n", @@ -284,25 +289,21 @@ " | \"MLTransform\" >> MLTransform(read_artifact_location=artifact_location))\n", "\n", " transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n" - ], - "metadata": { - "id": "r8Y5vgfLx_Xu", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "e7cbf6b7-5c31-4efa-90cf-7a8a108ecc77" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'x': [0.04782044142484665, -0.010078949853777885, -0.05793016776442528, -0.026060665026307106, 0.05756739526987076, 0.02292264811694622, 0.014818413183093071, -0.03718176111578941, -0.005486017093062401, 0.04709304869174957]}\n", - "{'x': [0.042911216616630554, -0.007554919924587011, -0.08996245265007019, -0.02607591263949871, 0.0008614308317191899, -0.023671219125390053, 0.03999944031238556, -0.02983051724731922, -0.015057179145514965, 0.022963201627135277]}\n" - ] - } ] } - ] + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + 
"nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb b/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb index a488caf7d3ac..3af7455222a9 100644 --- a/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb +++ b/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb @@ -2,6 +2,12 @@ "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "sARMhsXz8yR1" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -21,16 +27,13 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "sARMhsXz8yR1" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "A8xNRyZMW1yK" + }, "source": [ "# Preprocessing with the Apache Beam DataFrames API\n", "\n", @@ -44,13 +47,13 @@ " View source on GitHub\n", " \n", "\n" - ], - "metadata": { - "id": "A8xNRyZMW1yK" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "iFZC1inKuUCy" + }, "source": [ "For rapid execution, Pandas loads all of the data into memory on a single machine (one node). This configuration works well when dealing with small-scale datasets. However, many projects involve datasets that are too big to fit in memory. 
These use cases generally require parallel data processing frameworks, such as Apache Beam.\n", "\n", @@ -71,21 +74,18 @@ "\n", "In this example, the first section demonstrates how to build and execute a pipeline locally using the interactive runner.\n", "The second section uses a distributed runner to demonstrate how to run the pipeline on the full dataset.\n" - ], - "metadata": { - "id": "iFZC1inKuUCy" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "A0f2HJ22D4lt" + }, "source": [ "## Install Apache Beam\n", "\n", "To explore the elements within a `PCollection`, install Apache Beam with the `interactive` component to use the Interactive runner. The DataFrames API methods invoked in this example are available in Apache Beam SDK versions 2.43 and later.\n" - ], - "metadata": { - "id": "A0f2HJ22D4lt" - } + ] }, { "cell_type": "markdown", @@ -100,8 +100,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "-OJC0Xn5Um-C", - "beam:comment": "TODO(https://github.com/apache/issues/23961): Just install 2.43.0 once it's released, [`issue 23276`](https://github.com/apache/beam/issues/23276) is currently not implemented for Beam 2.42 (required fix for implementing `str.get_dummies()`" + "beam:comment": "TODO(https://github.com/apache/issues/23961): Just install 2.43.0 once it's released, [`issue 23276`](https://github.com/apache/beam/issues/23276) is currently not implemented for Beam 2.42 (required fix for implementing `str.get_dummies()`", + "id": "-OJC0Xn5Um-C" }, "outputs": [], "source": [ @@ -114,6 +114,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "3NO6RgB7GkkE" + }, "source": [ "## Local exploration with the Interactive Beam runner\n", "Use the [Interactive Beam](https://beam.apache.org/releases/pydoc/2.20.0/apache_beam.runners.interactive.interactive_beam.html) runner to explore and develop your pipeline.\n", @@ -121,10 +124,7 @@ "\n", "\n", "This section uses a subset of the original dataset, because the notebook 
instance has limited compute resources.\n" - ], - "metadata": { - "id": "3NO6RgB7GkkE" - } + ] }, { "cell_type": "markdown", @@ -186,13 +186,13 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "cvAu5T0ENjuQ" + }, "source": [ "\n", "Inspect the dataset columns and their types." - ], - "metadata": { - "id": "cvAu5T0ENjuQ" - } + ] }, { "cell_type": "code", @@ -206,7 +206,6 @@ }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "spk_id int64\n", @@ -225,8 +224,9 @@ "dtype: object" ] }, + "execution_count": 27, "metadata": {}, - "execution_count": 27 + "output_type": "execute_result" } ], "source": [ @@ -235,12 +235,12 @@ }, { "cell_type": "markdown", - "source": [ - "When using Interactive Beam, to bring a Beam DataFrame into local memory as a Pandas DataFrame, use `ib.collect()`." - ], "metadata": { "id": "1Wa6fpbyQige" - } + }, + "source": [ + "When using Interactive Beam, to bring a Beam DataFrame into local memory as a Pandas DataFrame, use `ib.collect()`." + ] }, { "cell_type": "code", @@ -255,11 +255,7 @@ }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -268,101 +264,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " spk_id full_name near_earth_object \\\n", - "0 2000001 1 Ceres N \n", - "1 2000002 2 Pallas N \n", - "2 2000003 3 Juno N \n", - "3 2000004 4 Vesta N \n", - "4 2000005 5 Astraea N \n", - "... ... ... ... \n", - "9994 2009995 9995 Alouette (4805 P-L) N \n", - "9995 2009996 9996 ANS (9070 P-L) N \n", - "9996 2009997 9997 COBE (1217 T-1) N \n", - "9997 2009998 9998 ISO (1293 T-1) N \n", - "9998 2009999 9999 Wiles (4196 T-2) N \n", - "\n", - " absolute_magnitude diameter albedo diameter_sigma eccentricity \\\n", - "0 3.40 939.400 0.0900 0.200 0.076009 \n", - "1 4.20 545.000 0.1010 18.000 0.229972 \n", - "2 5.33 246.596 0.2140 10.594 0.256936 \n", - "3 3.00 525.400 0.4228 0.200 0.088721 \n", - "4 6.90 106.699 0.2740 3.140 0.190913 \n", - "... ... ... ... ... ... \n", - "9994 15.10 2.564 0.2450 0.550 0.160610 \n", - "9995 13.60 8.978 0.1130 0.376 0.235174 \n", - "9996 14.30 NaN NaN NaN 0.113059 \n", - "9997 15.10 2.235 0.3880 0.373 0.093852 \n", - "9998 13.00 7.148 0.2620 0.065 0.071351 \n", - "\n", - " inclination moid_ld object_class semi_major_axis_au_unit \\\n", - "0 10.594067 620.640533 MBA 2.769165 \n", - "1 34.832932 480.348639 MBA 2.773841 \n", - "2 12.991043 402.514639 MBA 2.668285 \n", - "3 7.141771 443.451432 MBA 2.361418 \n", - "4 5.367427 426.433027 MBA 2.574037 \n", - "... ... ... ... ... 
\n", - "9994 2.311731 388.723233 MBA 2.390249 \n", - "9995 7.657713 444.194746 MBA 2.796605 \n", - "9996 2.459643 495.460110 MBA 2.545674 \n", - "9997 3.912263 373.848377 MBA 2.160961 \n", - "9998 3.198839 632.144398 MBA 2.839917 \n", - "\n", - " hazardous_flag \n", - "0 N \n", - "1 N \n", - "2 N \n", - "3 N \n", - "4 N \n", - "... ... \n", - "9994 N \n", - "9995 N \n", - "9996 N \n", - "9997 N \n", - "9998 N \n", - "\n", - "[9999 rows x 13 columns]" - ], "text/html": [ "\n", "

\n", @@ -657,10 +575,66 @@ "
\n", " \n", " " + ], + "text/plain": [ + " spk_id full_name near_earth_object \\\n", + "0 2000001 1 Ceres N \n", + "1 2000002 2 Pallas N \n", + "2 2000003 3 Juno N \n", + "3 2000004 4 Vesta N \n", + "4 2000005 5 Astraea N \n", + "... ... ... ... \n", + "9994 2009995 9995 Alouette (4805 P-L) N \n", + "9995 2009996 9996 ANS (9070 P-L) N \n", + "9996 2009997 9997 COBE (1217 T-1) N \n", + "9997 2009998 9998 ISO (1293 T-1) N \n", + "9998 2009999 9999 Wiles (4196 T-2) N \n", + "\n", + " absolute_magnitude diameter albedo diameter_sigma eccentricity \\\n", + "0 3.40 939.400 0.0900 0.200 0.076009 \n", + "1 4.20 545.000 0.1010 18.000 0.229972 \n", + "2 5.33 246.596 0.2140 10.594 0.256936 \n", + "3 3.00 525.400 0.4228 0.200 0.088721 \n", + "4 6.90 106.699 0.2740 3.140 0.190913 \n", + "... ... ... ... ... ... \n", + "9994 15.10 2.564 0.2450 0.550 0.160610 \n", + "9995 13.60 8.978 0.1130 0.376 0.235174 \n", + "9996 14.30 NaN NaN NaN 0.113059 \n", + "9997 15.10 2.235 0.3880 0.373 0.093852 \n", + "9998 13.00 7.148 0.2620 0.065 0.071351 \n", + "\n", + " inclination moid_ld object_class semi_major_axis_au_unit \\\n", + "0 10.594067 620.640533 MBA 2.769165 \n", + "1 34.832932 480.348639 MBA 2.773841 \n", + "2 12.991043 402.514639 MBA 2.668285 \n", + "3 7.141771 443.451432 MBA 2.361418 \n", + "4 5.367427 426.433027 MBA 2.574037 \n", + "... ... ... ... ... \n", + "9994 2.311731 388.723233 MBA 2.390249 \n", + "9995 7.657713 444.194746 MBA 2.796605 \n", + "9996 2.459643 495.460110 MBA 2.545674 \n", + "9997 3.912263 373.848377 MBA 2.160961 \n", + "9998 3.198839 632.144398 MBA 2.839917 \n", + "\n", + " hazardous_flag \n", + "0 N \n", + "1 N \n", + "2 N \n", + "3 N \n", + "4 N \n", + "... ... 
\n", + "9994 N \n", + "9995 N \n", + "9996 N \n", + "9997 N \n", + "9998 N \n", + "\n", + "[9999 rows x 13 columns]" ] }, + "execution_count": 28, "metadata": {}, - "execution_count": 28 + "output_type": "execute_result" } ], "source": [ @@ -669,34 +643,29 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "8jV9odKhNyF2" + }, "source": [ "The datasets contain the following two types of columns:\n", "\n", "* **Numerical columns:** Use [normalization](https://developers.google.com/machine-learning/data-prep/transform/normalization) to transform these columns so that they can be used to train a machine learning model.\n", "\n", "* **Categorical columns:** Transform those columns with [one-hot encoding](https://developers.google.com/machine-learning/data-prep/transform/transform-categorical) to use them during training. \n" - ], - "metadata": { - "id": "8jV9odKhNyF2" - } + ] }, { "cell_type": "markdown", - "source": [ - "Use the standard pandas command `DataFrame.describe()` to generate descriptive statistics for the numerical columns, such as percentile, mean, std, and so on. " - ], "metadata": { "id": "MGAErO0lAYws" - } + }, + "source": [ + "Use the standard pandas command `DataFrame.describe()` to generate descriptive statistics for the numerical columns, such as percentile, mean, std, and so on. " + ] }, { "cell_type": "code", - "source": [ - "with dataframe.allow_non_parallel_operations():\n", - " beam_df_description = ib.collect(beam_df.describe())\n", - "\n", - "beam_df_description" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -705,14 +674,9 @@ "id": "Befv697VBGM7", "outputId": "bb465020-94e4-4b3c-fda6-6e43da199be1" }, - "execution_count": null, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -721,77 +685,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " spk_id absolute_magnitude diameter albedo \\\n", - "count 9.999000e+03 9999.000000 8688.000000 8672.000000 \n", - "mean 2.005000e+06 12.675380 19.245446 0.197723 \n", - "std 2.886607e+03 1.639609 30.190191 0.138819 \n", - "min 2.000001e+06 3.000000 0.300000 0.008000 \n", - "25% 2.002500e+06 11.900000 5.614000 0.074000 \n", - "50% 2.005000e+06 12.900000 9.814000 0.187000 \n", - "75% 2.007500e+06 13.700000 19.156750 0.283000 \n", - "max 2.009999e+06 20.700000 939.400000 1.000000 \n", - "\n", - " diameter_sigma eccentricity inclination moid_ld \\\n", - "count 8591.000000 9999.000000 9999.000000 9999.000000 \n", - "mean 0.454072 0.148716 7.890742 509.805237 \n", - "std 1.093676 0.083803 6.336244 205.046582 \n", - "min 0.006000 0.001003 0.042716 0.131028 \n", - "25% 0.120000 0.093780 3.220137 377.829197 \n", - "50% 0.201000 0.140335 6.018836 470.650523 \n", - "75% 0.375000 0.187092 10.918176 636.010802 \n", - "max 39.297000 0.889831 68.018875 4241.524913 \n", - "\n", - " semi_major_axis_au_unit \n", - "count 9999.000000 \n", - "mean 2.689836 \n", - "std 0.607190 \n", - "min 0.832048 \n", - "25% 2.340816 \n", - "50% 2.614468 \n", - "75% 3.005449 \n", - "max 24.667968 " - ], "text/html": [ "\n", "
\n", @@ -1001,27 +911,65 @@ "
\n", " \n", " " - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D9uJtHLSSAMC" - }, - "source": [ - "Before running any transformations, verify that all of the columns need to be used for model training. Start by looking at the column description provided by the [JPL website](https://ssd.jpl.nasa.gov/sbdb_query.cgi):\n", - "\n", - "* **spk_id:** Object primary SPK-ID.\n", - "* **full_name:** Asteroid name.\n", - "* **near_earth_object:** Near-earth object flag.\n", - "* **absolute_magnitude:** The apparent magnitude an object would have if it were located at a distance of 10 parsecs.\n", - "* **diameter:** Object diameter (from equivalent sphere) km unit.\n", - "* **albedo:** A measure of the diffuse reflection of solar radiation out of the total solar radiation, measured on a scale from 0 to 1.\n", + ], + "text/plain": [ + " spk_id absolute_magnitude diameter albedo \\\n", + "count 9.999000e+03 9999.000000 8688.000000 8672.000000 \n", + "mean 2.005000e+06 12.675380 19.245446 0.197723 \n", + "std 2.886607e+03 1.639609 30.190191 0.138819 \n", + "min 2.000001e+06 3.000000 0.300000 0.008000 \n", + "25% 2.002500e+06 11.900000 5.614000 0.074000 \n", + "50% 2.005000e+06 12.900000 9.814000 0.187000 \n", + "75% 2.007500e+06 13.700000 19.156750 0.283000 \n", + "max 2.009999e+06 20.700000 939.400000 1.000000 \n", + "\n", + " diameter_sigma eccentricity inclination moid_ld \\\n", + "count 8591.000000 9999.000000 9999.000000 9999.000000 \n", + "mean 0.454072 0.148716 7.890742 509.805237 \n", + "std 1.093676 0.083803 6.336244 205.046582 \n", + "min 0.006000 0.001003 0.042716 0.131028 \n", + "25% 0.120000 0.093780 3.220137 377.829197 \n", + "50% 0.201000 0.140335 6.018836 470.650523 \n", + "75% 0.375000 0.187092 10.918176 636.010802 \n", + "max 39.297000 0.889831 68.018875 4241.524913 \n", + "\n", + " semi_major_axis_au_unit \n", + "count 9999.000000 \n", + "mean 2.689836 \n", + "std 0.607190 \n", + "min 
0.832048 \n", + "25% 2.340816 \n", + "50% 2.614468 \n", + "75% 3.005449 \n", + "max 24.667968 " + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with dataframe.allow_non_parallel_operations():\n", + " beam_df_description = ib.collect(beam_df.describe())\n", + "\n", + "beam_df_description" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D9uJtHLSSAMC" + }, + "source": [ + "Before running any transformations, verify that all of the columns need to be used for model training. Start by looking at the column description provided by the [JPL website](https://ssd.jpl.nasa.gov/sbdb_query.cgi):\n", + "\n", + "* **spk_id:** Object primary SPK-ID.\n", + "* **full_name:** Asteroid name.\n", + "* **near_earth_object:** Near-earth object flag.\n", + "* **absolute_magnitude:** The apparent magnitude an object would have if it were located at a distance of 10 parsecs.\n", + "* **diameter:** Object diameter (from equivalent sphere) km unit.\n", + "* **albedo:** A measure of the diffuse reflection of solar radiation out of the total solar radiation, measured on a scale from 0 to 1.\n", "* **diameter_sigma:** 1-sigma uncertainty in object diameter km unit.\n", "* **eccentricity:** A value between 0 and 1 that refers to how flat or round the asteroid is.\n", "* **inclination:** The angle with respect to the x-y ecliptic plane.\n", @@ -1073,19 +1021,15 @@ }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/content/beam/sdks/python/apache_beam/dataframe/frame_base.py:145: RuntimeWarning: invalid value encountered in long_scalars\n", " lambda left, right: getattr(left, op)(right), name=op, args=[other])\n" ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1094,45 +1038,22 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { "text/plain": [ "near_earth_object 0.000000\n", @@ -1149,8 +1070,9 @@ "dtype: float64" ] }, + "execution_count": 30, "metadata": {}, - "execution_count": 30 + "output_type": "execute_result" } ], "source": [ @@ -1170,20 +1092,16 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "tHYeCHREwvyB", "colab": { "base_uri": "https://localhost:8080/", "height": 538 }, + "id": "tHYeCHREwvyB", "outputId": "3be686d0-f56a-4054-a71a-d3019bf379e8" }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1192,75 +1110,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " near_earth_object absolute_magnitude eccentricity inclination \\\n", - "0 N 3.40 0.076009 10.594067 \n", - "1 N 4.20 0.229972 34.832932 \n", - "2 N 5.33 0.256936 12.991043 \n", - "3 N 3.00 0.088721 7.141771 \n", - "4 N 6.90 0.190913 5.367427 \n", - "... ... ... ... ... \n", - "9994 N 15.10 0.160610 2.311731 \n", - "9995 N 13.60 0.235174 7.657713 \n", - "9996 N 14.30 0.113059 2.459643 \n", - "9997 N 15.10 0.093852 3.912263 \n", - "9998 N 13.00 0.071351 3.198839 \n", - "\n", - " moid_ld object_class semi_major_axis_au_unit hazardous_flag \n", - "0 620.640533 MBA 2.769165 N \n", - "1 480.348639 MBA 2.773841 N \n", - "2 402.514639 MBA 2.668285 N \n", - "3 443.451432 MBA 2.361418 N \n", - "4 426.433027 MBA 2.574037 N \n", - "... ... ... ... ... \n", - "9994 388.723233 MBA 2.390249 N \n", - "9995 444.194746 MBA 2.796605 N \n", - "9996 495.460110 MBA 2.545674 N \n", - "9997 373.848377 MBA 2.160961 N \n", - "9998 632.144398 MBA 2.839917 N \n", - "\n", - "[9999 rows x 8 columns]" - ], "text/html": [ "\n", "
\n", @@ -1495,10 +1361,40 @@ "
\n", " \n", " " + ], + "text/plain": [ + " near_earth_object absolute_magnitude eccentricity inclination \\\n", + "0 N 3.40 0.076009 10.594067 \n", + "1 N 4.20 0.229972 34.832932 \n", + "2 N 5.33 0.256936 12.991043 \n", + "3 N 3.00 0.088721 7.141771 \n", + "4 N 6.90 0.190913 5.367427 \n", + "... ... ... ... ... \n", + "9994 N 15.10 0.160610 2.311731 \n", + "9995 N 13.60 0.235174 7.657713 \n", + "9996 N 14.30 0.113059 2.459643 \n", + "9997 N 15.10 0.093852 3.912263 \n", + "9998 N 13.00 0.071351 3.198839 \n", + "\n", + " moid_ld object_class semi_major_axis_au_unit hazardous_flag \n", + "0 620.640533 MBA 2.769165 N \n", + "1 480.348639 MBA 2.773841 N \n", + "2 402.514639 MBA 2.668285 N \n", + "3 443.451432 MBA 2.361418 N \n", + "4 426.433027 MBA 2.574037 N \n", + "... ... ... ... ... \n", + "9994 388.723233 MBA 2.390249 N \n", + "9995 444.194746 MBA 2.796605 N \n", + "9996 495.460110 MBA 2.545674 N \n", + "9997 373.848377 MBA 2.160961 N \n", + "9998 632.144398 MBA 2.839917 N \n", + "\n", + "[9999 rows x 8 columns]" ] }, + "execution_count": 31, "metadata": {}, - "execution_count": 31 + "output_type": "execute_result" } ], "source": [ @@ -1559,19 +1455,15 @@ }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/content/beam/sdks/python/apache_beam/dataframe/frame_base.py:145: RuntimeWarning: invalid value encountered in double_scalars\n", " lambda left, right: getattr(left, op)(right), name=op, args=[other])\n" ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1580,75 +1472,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " absolute_magnitude eccentricity inclination moid_ld \\\n", - "306 -1.570727 -0.062543 -0.278518 0.373194 \n", - "310 -1.631718 -1.724526 -0.736389 1.087833 \n", - "546 -1.753698 1.028793 1.415303 -0.339489 \n", - "635 -1.875678 0.244869 0.005905 0.214107 \n", - "701 -3.278451 -1.570523 2.006145 1.542754 \n", - "... ... ... ... ... \n", - "9697 0.807888 -1.151809 -0.082944 -0.129556 \n", - "9813 1.722740 0.844551 -0.583247 -1.006447 \n", - "9868 0.807888 -0.207399 -0.784665 -0.462136 \n", - "9903 0.868878 0.460086 0.092258 -0.107597 \n", - "9956 0.746898 -0.234132 -0.161116 -0.601379 \n", - "\n", - " semi_major_axis_au_unit \n", - "306 0.357201 \n", - "310 0.344233 \n", - "546 0.139080 \n", - "635 0.367559 \n", - "701 0.829337 \n", - "... ... \n", - "9697 -0.533538 \n", - "9813 -0.677961 \n", - "9868 -0.539794 \n", - "9903 0.071794 \n", - "9956 -0.664887 \n", - "\n", - "[9999 rows x 5 columns]" - ], "text/html": [ "\n", "
\n", @@ -1847,10 +1687,40 @@ "
\n", " \n", " " + ], + "text/plain": [ + " absolute_magnitude eccentricity inclination moid_ld \\\n", + "306 -1.570727 -0.062543 -0.278518 0.373194 \n", + "310 -1.631718 -1.724526 -0.736389 1.087833 \n", + "546 -1.753698 1.028793 1.415303 -0.339489 \n", + "635 -1.875678 0.244869 0.005905 0.214107 \n", + "701 -3.278451 -1.570523 2.006145 1.542754 \n", + "... ... ... ... ... \n", + "9697 0.807888 -1.151809 -0.082944 -0.129556 \n", + "9813 1.722740 0.844551 -0.583247 -1.006447 \n", + "9868 0.807888 -0.207399 -0.784665 -0.462136 \n", + "9903 0.868878 0.460086 0.092258 -0.107597 \n", + "9956 0.746898 -0.234132 -0.161116 -0.601379 \n", + "\n", + " semi_major_axis_au_unit \n", + "306 0.357201 \n", + "310 0.344233 \n", + "546 0.139080 \n", + "635 0.367559 \n", + "701 0.829337 \n", + "... ... \n", + "9697 -0.533538 \n", + "9813 -0.677961 \n", + "9868 -0.539794 \n", + "9903 0.071794 \n", + "9956 -0.664887 \n", + "\n", + "[9999 rows x 5 columns]" ] }, + "execution_count": 33, "metadata": {}, - "execution_count": 33 + "output_type": "execute_result" } ], "source": [ @@ -1895,12 +1765,7 @@ }, { "cell_type": "code", - "source": [ - "for categorical_col in categorical_cols:\n", - " beam_df_categorical = get_one_hot_encoding(df=beam_df, categorical_col=categorical_col)\n", - " beam_df_numericals = beam_df_numericals.merge(beam_df_categorical, left_index = True, right_index = True)\n", - "ib.collect(beam_df_numericals)" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1909,14 +1774,9 @@ "id": "k9rvtWqHf6Qw", "outputId": "b8d8ae57-6dba-45b4-e7ae-e4b14084eede" }, - "execution_count": null, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1925,49 +1785,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1976,49 +1810,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2027,49 +1835,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2078,127 +1860,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " absolute_magnitude eccentricity inclination moid_ld \\\n", - "0 -5.657067 -0.867596 0.426645 0.540537 \n", - "12 -3.583402 -0.756931 1.364340 0.238610 \n", - "47 -3.400432 -0.912290 -0.211925 1.136060 \n", - "381 -2.363599 0.271412 -0.078826 0.535299 \n", - "515 -2.729540 1.469775 0.799915 -0.602881 \n", - "... ... ... ... ... \n", - "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", - "9657 1.478779 0.487849 -0.637779 -0.648240 \n", - "9704 0.380957 -0.238383 0.443053 0.670490 \n", - "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", - "9980 0.746898 -1.455992 -0.849144 0.592902 \n", - "\n", - " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", - "0 0.130649 1 0 \n", - "12 -0.187375 1 0 \n", - "47 0.691182 1 0 \n", - "381 0.712755 1 0 \n", - "515 -0.014654 1 0 \n", - "... ... ... ... 
\n", - "9146 -0.820638 1 0 \n", - "9657 -0.468778 1 0 \n", - "9704 0.587128 1 0 \n", - "9879 -0.662602 1 0 \n", - "9980 -0.022726 1 0 \n", - "\n", - " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", - "0 0 0 0 ... \n", - "12 0 0 0 ... \n", - "47 0 0 0 ... \n", - "381 0 0 0 ... \n", - "515 0 0 0 ... \n", - "... ... ... ... ... \n", - "9146 0 0 0 ... \n", - "9657 0 0 0 ... \n", - "9704 0 0 0 ... \n", - "9879 0 0 0 ... \n", - "9980 0 0 0 ... \n", - "\n", - " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", - "0 0 0 1 0 \n", - "12 0 0 1 0 \n", - "47 0 0 1 0 \n", - "381 0 0 1 0 \n", - "515 0 0 1 0 \n", - "... ... ... ... ... \n", - "9146 0 0 1 0 \n", - "9657 0 0 1 0 \n", - "9704 0 0 1 0 \n", - "9879 0 0 1 0 \n", - "9980 0 0 1 0 \n", - "\n", - " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", - "0 0 0 0 1 \n", - "12 0 0 0 1 \n", - "47 0 0 0 1 \n", - "381 0 0 0 1 \n", - "515 0 0 0 1 \n", - "... ... ... ... ... \n", - "9146 0 0 0 1 \n", - "9657 0 0 0 1 \n", - "9704 0 0 0 1 \n", - "9879 0 0 0 1 \n", - "9980 0 0 0 1 \n", - "\n", - " hazardous_flag_Y hazardous_flag_nan \n", - "0 0 0 \n", - "12 0 0 \n", - "47 0 0 \n", - "381 0 0 \n", - "515 0 0 \n", - "... ... ... \n", - "9146 0 0 \n", - "9657 0 0 \n", - "9704 0 0 \n", - "9879 0 0 \n", - "9980 0 0 \n", - "\n", - "[9999 rows x 22 columns]" - ], "text/html": [ "\n", "
\n", @@ -2589,11 +2267,99 @@ "
\n", " \n", " " + ], + "text/plain": [ + " absolute_magnitude eccentricity inclination moid_ld \\\n", + "0 -5.657067 -0.867596 0.426645 0.540537 \n", + "12 -3.583402 -0.756931 1.364340 0.238610 \n", + "47 -3.400432 -0.912290 -0.211925 1.136060 \n", + "381 -2.363599 0.271412 -0.078826 0.535299 \n", + "515 -2.729540 1.469775 0.799915 -0.602881 \n", + "... ... ... ... ... \n", + "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", + "9657 1.478779 0.487849 -0.637779 -0.648240 \n", + "9704 0.380957 -0.238383 0.443053 0.670490 \n", + "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", + "9980 0.746898 -1.455992 -0.849144 0.592902 \n", + "\n", + " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", + "0 0.130649 1 0 \n", + "12 -0.187375 1 0 \n", + "47 0.691182 1 0 \n", + "381 0.712755 1 0 \n", + "515 -0.014654 1 0 \n", + "... ... ... ... \n", + "9146 -0.820638 1 0 \n", + "9657 -0.468778 1 0 \n", + "9704 0.587128 1 0 \n", + "9879 -0.662602 1 0 \n", + "9980 -0.022726 1 0 \n", + "\n", + " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", + "0 0 0 0 ... \n", + "12 0 0 0 ... \n", + "47 0 0 0 ... \n", + "381 0 0 0 ... \n", + "515 0 0 0 ... \n", + "... ... ... ... ... \n", + "9146 0 0 0 ... \n", + "9657 0 0 0 ... \n", + "9704 0 0 0 ... \n", + "9879 0 0 0 ... \n", + "9980 0 0 0 ... \n", + "\n", + " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", + "0 0 0 1 0 \n", + "12 0 0 1 0 \n", + "47 0 0 1 0 \n", + "381 0 0 1 0 \n", + "515 0 0 1 0 \n", + "... ... ... ... ... \n", + "9146 0 0 1 0 \n", + "9657 0 0 1 0 \n", + "9704 0 0 1 0 \n", + "9879 0 0 1 0 \n", + "9980 0 0 1 0 \n", + "\n", + " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", + "0 0 0 0 1 \n", + "12 0 0 0 1 \n", + "47 0 0 0 1 \n", + "381 0 0 0 1 \n", + "515 0 0 0 1 \n", + "... ... ... ... ... 
\n", + "9146 0 0 0 1 \n", + "9657 0 0 0 1 \n", + "9704 0 0 0 1 \n", + "9879 0 0 0 1 \n", + "9980 0 0 0 1 \n", + "\n", + " hazardous_flag_Y hazardous_flag_nan \n", + "0 0 0 \n", + "12 0 0 \n", + "47 0 0 \n", + "381 0 0 \n", + "515 0 0 \n", + "... ... ... \n", + "9146 0 0 \n", + "9657 0 0 \n", + "9704 0 0 \n", + "9879 0 0 \n", + "9980 0 0 \n", + "\n", + "[9999 rows x 22 columns]" ] }, + "execution_count": 35, "metadata": {}, - "execution_count": 35 + "output_type": "execute_result" } + ], + "source": [ + "for categorical_col in categorical_cols:\n", + " beam_df_categorical = get_one_hot_encoding(df=beam_df, categorical_col=categorical_col)\n", + " beam_df_numericals = beam_df_numericals.merge(beam_df_categorical, left_index = True, right_index = True)\n", + "ib.collect(beam_df_numericals)" ] }, { @@ -2613,28 +2379,24 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "ndaSNond0v8Q", "colab": { "base_uri": "https://localhost:8080/", "height": 651 }, + "id": "ndaSNond0v8Q", "outputId": "b265e915-e649-44e4-a31a-95ac85c0ebf6" }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/content/beam/sdks/python/apache_beam/dataframe/frame_base.py:145: RuntimeWarning: invalid value encountered in double_scalars\n", " lambda left, right: getattr(left, op)(right), name=op, args=[other])\n" ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2643,49 +2405,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2694,49 +2430,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2745,49 +2455,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2796,127 +2480,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " absolute_magnitude eccentricity inclination moid_ld \\\n", - "0 -5.657067 -0.867596 0.426645 0.540537 \n", - "12 -3.583402 -0.756931 1.364340 0.238610 \n", - "47 -3.400432 -0.912290 -0.211925 1.136060 \n", - "381 -2.363599 0.271412 -0.078826 0.535299 \n", - "515 -2.729540 1.469775 0.799915 -0.602881 \n", - "... ... ... ... ... \n", - "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", - "9657 1.478779 0.487849 -0.637779 -0.648240 \n", - "9704 0.380957 -0.238383 0.443053 0.670490 \n", - "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", - "9980 0.746898 -1.455992 -0.849144 0.592902 \n", - "\n", - " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", - "0 0.130649 1 0 \n", - "12 -0.187375 1 0 \n", - "47 0.691182 1 0 \n", - "381 0.712755 1 0 \n", - "515 -0.014654 1 0 \n", - "... ... ... ... 
\n", - "9146 -0.820638 1 0 \n", - "9657 -0.468778 1 0 \n", - "9704 0.587128 1 0 \n", - "9879 -0.662602 1 0 \n", - "9980 -0.022726 1 0 \n", - "\n", - " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", - "0 0 0 0 ... \n", - "12 0 0 0 ... \n", - "47 0 0 0 ... \n", - "381 0 0 0 ... \n", - "515 0 0 0 ... \n", - "... ... ... ... ... \n", - "9146 0 0 0 ... \n", - "9657 0 0 0 ... \n", - "9704 0 0 0 ... \n", - "9879 0 0 0 ... \n", - "9980 0 0 0 ... \n", - "\n", - " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", - "0 0 0 1 0 \n", - "12 0 0 1 0 \n", - "47 0 0 1 0 \n", - "381 0 0 1 0 \n", - "515 0 0 1 0 \n", - "... ... ... ... ... \n", - "9146 0 0 1 0 \n", - "9657 0 0 1 0 \n", - "9704 0 0 1 0 \n", - "9879 0 0 1 0 \n", - "9980 0 0 1 0 \n", - "\n", - " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", - "0 0 0 0 1 \n", - "12 0 0 0 1 \n", - "47 0 0 0 1 \n", - "381 0 0 0 1 \n", - "515 0 0 0 1 \n", - "... ... ... ... ... \n", - "9146 0 0 0 1 \n", - "9657 0 0 0 1 \n", - "9704 0 0 0 1 \n", - "9879 0 0 0 1 \n", - "9980 0 0 0 1 \n", - "\n", - " hazardous_flag_Y hazardous_flag_nan \n", - "0 0 0 \n", - "12 0 0 \n", - "47 0 0 \n", - "381 0 0 \n", - "515 0 0 \n", - "... ... ... \n", - "9146 0 0 \n", - "9657 0 0 \n", - "9704 0 0 \n", - "9879 0 0 \n", - "9980 0 0 \n", - "\n", - "[9999 rows x 22 columns]" - ], "text/html": [ "\n", "
\n", @@ -3307,10 +2887,92 @@ "
\n", " \n", " " + ], + "text/plain": [ + " absolute_magnitude eccentricity inclination moid_ld \\\n", + "0 -5.657067 -0.867596 0.426645 0.540537 \n", + "12 -3.583402 -0.756931 1.364340 0.238610 \n", + "47 -3.400432 -0.912290 -0.211925 1.136060 \n", + "381 -2.363599 0.271412 -0.078826 0.535299 \n", + "515 -2.729540 1.469775 0.799915 -0.602881 \n", + "... ... ... ... ... \n", + "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", + "9657 1.478779 0.487849 -0.637779 -0.648240 \n", + "9704 0.380957 -0.238383 0.443053 0.670490 \n", + "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", + "9980 0.746898 -1.455992 -0.849144 0.592902 \n", + "\n", + " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", + "0 0.130649 1 0 \n", + "12 -0.187375 1 0 \n", + "47 0.691182 1 0 \n", + "381 0.712755 1 0 \n", + "515 -0.014654 1 0 \n", + "... ... ... ... \n", + "9146 -0.820638 1 0 \n", + "9657 -0.468778 1 0 \n", + "9704 0.587128 1 0 \n", + "9879 -0.662602 1 0 \n", + "9980 -0.022726 1 0 \n", + "\n", + " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", + "0 0 0 0 ... \n", + "12 0 0 0 ... \n", + "47 0 0 0 ... \n", + "381 0 0 0 ... \n", + "515 0 0 0 ... \n", + "... ... ... ... ... \n", + "9146 0 0 0 ... \n", + "9657 0 0 0 ... \n", + "9704 0 0 0 ... \n", + "9879 0 0 0 ... \n", + "9980 0 0 0 ... \n", + "\n", + " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", + "0 0 0 1 0 \n", + "12 0 0 1 0 \n", + "47 0 0 1 0 \n", + "381 0 0 1 0 \n", + "515 0 0 1 0 \n", + "... ... ... ... ... \n", + "9146 0 0 1 0 \n", + "9657 0 0 1 0 \n", + "9704 0 0 1 0 \n", + "9879 0 0 1 0 \n", + "9980 0 0 1 0 \n", + "\n", + " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", + "0 0 0 0 1 \n", + "12 0 0 0 1 \n", + "47 0 0 0 1 \n", + "381 0 0 0 1 \n", + "515 0 0 0 1 \n", + "... ... ... ... ... 
\n", + "9146 0 0 0 1 \n", + "9657 0 0 0 1 \n", + "9704 0 0 0 1 \n", + "9879 0 0 0 1 \n", + "9980 0 0 0 1 \n", + "\n", + " hazardous_flag_Y hazardous_flag_nan \n", + "0 0 0 \n", + "12 0 0 \n", + "47 0 0 \n", + "381 0 0 \n", + "515 0 0 \n", + "... ... ... \n", + "9146 0 0 \n", + "9657 0 0 \n", + "9704 0 0 \n", + "9879 0 0 \n", + "9980 0 0 \n", + "\n", + "[9999 rows x 22 columns]" ] }, + "execution_count": 36, "metadata": {}, - "execution_count": 36 + "output_type": "execute_result" } ], "source": [ @@ -3356,31 +3018,36 @@ }, { "cell_type": "code", - "source": [ - "PROJECT_ID = \"\"\n", - "REGION = \"us-central1\"\n", - "TEMP_DIR = \"gs:///tmp\"\n", - "OUTPUT_DIR = \"gs:///dataframe-result\"" - ], + "execution_count": null, "metadata": { "id": "dDBYbMEWbL4t" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "REGION = \"us-central1\"\n", + "TEMP_DIR = \"gs:///tmp\" # @param {type:'string'}\n", + "OUTPUT_DIR = \"gs:///dataframe-result\" # @param {type:'string'}" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Qk1GaYoSc9-1" + }, "source": [ "These steps process the full dataset, `full.csv`, which contains approximately one million rows. To materialize the deferred DataFrame, these steps also write the results to a CSV file instead of using `ib.collect()`.\n", "\n", "To switch from an interactive runner to a distributed runner, update the pipeline options. The rest of the pipeline steps don't change." 
- ], - "metadata": { - "id": "Qk1GaYoSc9-1" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1XovR0gKbMlK" + }, + "outputs": [], "source": [ "# Specify the location of the source CSV file (the full dataset).\n", "source_csv_file = 'gs://apache-beam-samples/nasa_jpl_asteroid/full.csv'\n", @@ -3417,44 +3084,42 @@ "\n", "# Write the preprocessed dataset to a CSV file.\n", "beam_df_numericals.to_csv(os.path.join(OUTPUT_DIR, \"preprocessed_data.csv\"))" - ], - "metadata": { - "id": "1XovR0gKbMlK" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Submit and run the pipeline." - ], "metadata": { "id": "a789u4Yecs_g" - } + }, + "source": [ + "Submit and run the pipeline." + ] }, { "cell_type": "code", - "source": [ - "p.run().wait_until_finish()" - ], + "execution_count": null, "metadata": { "id": "pbUlC102bPaZ" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "p.run().wait_until_finish()" + ] }, { "cell_type": "markdown", - "source": [ - "Wait while the pipeline job runs." - ], "metadata": { "id": "dzdqmzKzTOng" - } + }, + "source": [ + "Wait while the pipeline job runs." + ] }, { "cell_type": "markdown", + "metadata": { + "id": "UOLr6YgOOSVQ" + }, "source": [ "## What's next \n", "\n", @@ -3464,13 +3129,13 @@ "[Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/).\n", "\n", "To continue learning, find another dataset to use with the Apache Beam DataFrames API processing. 
Think carefully about which features to include in your model and how to represent them.\n" - ], - "metadata": { - "id": "UOLr6YgOOSVQ" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "nG9WXXVcMCe_" + }, "source": [ "## Resources\n", "\n", @@ -3479,10 +3144,7 @@ "* [10 minutes to Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/10min.html) - A quickstart guide to the Pandas DataFrames.\n", "* [Pandas DataFrame API](https://pandas.pydata.org/pandas-docs/stable/reference/frame.html) - The API reference for the Pandas DataFrames.\n", "* [Data preparation and feature training in ML](https://developers.google.com/machine-learning/data-prep) - A guideline about data transformation for ML training." - ], - "metadata": { - "id": "nG9WXXVcMCe_" - } + ] } ], "metadata": { diff --git a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb index d7b2b157f613..1b20270f327a 100644 --- a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb +++ b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb @@ -174,7 +174,8 @@ "WORKDIR /workspace\n", "\n", "COPY gemma2 gemma2\n", - "RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim" + "RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim\n", + "```" ] }, { @@ -208,7 +209,8 @@ "apache_beam[gcp]==2.54.0\n", "keras_nlp==0.14.3\n", "keras==3.4.1\n", - "jax[cuda12]" + "jax[cuda12]\n", + "```" ] }, { @@ -261,7 +263,8 @@ "\n", "\n", "# Set the entrypoint to the Apache Beam SDK launcher.\n", - "ENTRYPOINT [\"/opt/apache/beam/boot\"]" + "ENTRYPOINT [\"/opt/apache/beam/boot\"]\n", + "```" ] }, { @@ -364,7 +367,7 @@ "# options.view_as(WorkerOptions).disk_size_gb=200\n", "# options.view_as(GoogleCloudOptions).dataflow_service_options=[\"worker_accelerator=type:nvidia-l4;count:1;install-nvidia-driver\"]\n", "\n", - "topic_reviews=\"\"" + "topic_reviews=\"\" # 
@param {type:'string'}" ] }, { diff --git a/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb b/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb index f9a263e39030..6f5048e7e8ee 100644 --- a/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb +++ b/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb @@ -496,23 +496,23 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Epoch 1/10\n", "25/25 [==============================] - ETA: 0s - loss: 0.5931 - accuracy: 0.7650" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as _update_step_xla, lstm_cell_7_layer_call_fn, lstm_cell_7_layer_call_and_return_conditional_losses, lstm_cell_8_layer_call_fn, lstm_cell_8_layer_call_and_return_conditional_losses while saving (showing 5 of 9). These functions will not be directly callable after loading.\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r25/25 [==============================] - 60s 2s/step - loss: 0.5931 - accuracy: 0.7650 - val_loss: 0.3625 - val_accuracy: 0.8900\n", "Epoch 2/10\n", @@ -536,14 +536,14 @@ ] }, { - "output_type": "execute_result", "data": { "text/plain": [ "" ] }, + "execution_count": 57, "metadata": {}, - "execution_count": 57 + "output_type": "execute_result" } ], "source": [ @@ -604,14 +604,14 @@ }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "" ] }, + "execution_count": 59, "metadata": {}, - "execution_count": 59 + "output_type": "execute_result" } ], "source": [ @@ -641,8 +641,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as _update_step_xla, 
lstm_cell_7_layer_call_fn, lstm_cell_7_layer_call_and_return_conditional_losses, lstm_cell_8_layer_call_fn, lstm_cell_8_layer_call_and_return_conditional_losses while saving (showing 5 of 9). These functions will not be directly callable after loading.\n" ] @@ -706,8 +706,10 @@ "source": [ "import os\n", "from google.cloud import pubsub_v1\n", - "PROJECT_ID = '' # Add your project ID here\n", - "TOPIC = '' # Add your topic name here\n", + "# Add your project ID here\n", + "PROJECT_ID = '' # @param {type:'string'}\n", + "# Add your topic name here\n", + "TOPIC = '' # @param {type:'string'}\n", "publisher = pubsub_v1.PublisherClient()\n", "topic_name = 'projects/{project_id}/topics/{topic}'.format(\n", " project_id = PROJECT_ID,\n", @@ -739,8 +741,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Can’t wait to watch you guys grow . Harmonies are on point and the oversized early 90’s blazers are a great touch.\n", "Amazing performance! Such an inspiring group ❤\n", @@ -908,8 +910,8 @@ }, "outputs": [], "source": [ - "# path to the topic\n", - "TOPIC_PATH = '' # Add the path to your topic here" + "# Add the path to your topic here\n", + "TOPIC_PATH = '' # @param {type:'string'}" ] }, { @@ -920,18 +922,18 @@ }, "outputs": [], "source": [ - "# path to the subscription\n", - "SUBS_PATH = '' # Add the path to your subscription here" + "# Add the path to your subscription here\n", + "SUBS_PATH = '' # @param {type:'string'}" ] }, { "cell_type": "markdown", - "source": [ - "Importing InteractiveRunner" - ], "metadata": { "id": "UliBhojEfxhq" - } + }, + "source": [ + "Importing InteractiveRunner" + ] }, { "cell_type": "code", @@ -986,8 +988,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Can’t wait to watch you guys grow . Harmonies are on point and the oversized early 90’s blazers are a great touch.\n", "Amazing performance! 
Such an inspiring group ❤\n", @@ -1048,7 +1050,6 @@ }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "[[[0.852806806564331, 0.14719319343566895], 'positive'],\n", @@ -1059,8 +1060,9 @@ " [[0.8648154735565186, 0.13518451154232025], 'positive']]" ] }, + "execution_count": 38, "metadata": {}, - "execution_count": 38 + "output_type": "execute_result" } ], "source": [ diff --git a/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb b/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb index cc31ff678fe4..aae86e31aa44 100644 --- a/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb +++ b/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb @@ -209,7 +209,7 @@ "\n", "3. Create the index.\n", "\n", - "4. Index creation is neeeded only once." + "4. Index creation is needed only once." ] }, { diff --git a/examples/notebooks/beam-ml/run_inference_pytorch.ipynb b/examples/notebooks/beam-ml/run_inference_pytorch.ipynb index eaf46be16bbd..93dd12dd20ab 100644 --- a/examples/notebooks/beam-ml/run_inference_pytorch.ipynb +++ b/examples/notebooks/beam-ml/run_inference_pytorch.ipynb @@ -1,22 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", + "execution_count": 3, + "metadata": { + "cellView": "form", + "id": "C1rAsD2L-hSO" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -36,13 +27,7 @@ "# KIND, either express or implied. 
See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "C1rAsD2L-hSO" - }, - "execution_count": 3, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -95,23 +80,23 @@ }, { "cell_type": "code", - "source": [ - "!pip install apache_beam[gcp,dataframe] --quiet" - ], + "execution_count": null, "metadata": { "id": "loxD-rOVchRn" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip install apache_beam[gcp,dataframe] --quiet" + ] }, { "cell_type": "code", "execution_count": 39, "metadata": { - "id": "7f841596-f217-46d2-b64e-1952db4de4cb", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "7f841596-f217-46d2-b64e-1952db4de4cb", "outputId": "09e0026a-cf8e-455c-9580-bfaef44683ce" }, "outputs": [], @@ -151,15 +136,15 @@ }, { "cell_type": "code", - "source": [ - "from google.colab import auth\n", - "auth.authenticate_user()" - ], + "execution_count": 41, "metadata": { "id": "V0E35R5Ka2cE" }, - "execution_count": 41, - "outputs": [] + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user()" + ] }, { "cell_type": "code", @@ -170,8 +155,8 @@ "outputs": [], "source": [ "# Constants\n", - "project = \"\"\n", - "bucket = \"\"\n", + "project = \"\" # @param {type:'string'}\n", + "bucket = \"\" # @param {type:'string'}\n", "\n", "# To avoid warnings, set the project.\n", "os.environ['GOOGLE_CLOUD_PROJECT'] = project\n", @@ -183,8 +168,8 @@ { "cell_type": "markdown", "metadata": { - "tags": [], - "id": "b2b7cedc-79f5-4599-8178-e5da35dba032" + "id": "b2b7cedc-79f5-4599-8178-e5da35dba032", + "tags": [] }, "source": [ "## Create data and PyTorch models for the RunInference transform\n", @@ -294,16 +279,16 @@ "cell_type": "code", "execution_count": 46, "metadata": { - "id": "882bbada-4f6d-4370-a047-c5961e564ee8", "colab": { "base_uri": "https://localhost:8080/" }, + "id": 
"882bbada-4f6d-4370-a047-c5961e564ee8", "outputId": "ab7242a9-76eb-4760-d74e-c725261e2a34" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "True\n" ] @@ -384,16 +369,16 @@ "cell_type": "code", "execution_count": 49, "metadata": { - "id": "42b2ca0f-5d44-4d15-a313-f3d56ae7f675", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "42b2ca0f-5d44-4d15-a313-f3d56ae7f675", "outputId": "9cb2f268-a500-4ad5-a075-856c87b8e3be" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "True\n" ] @@ -430,16 +415,16 @@ "cell_type": "code", "execution_count": 50, "metadata": { - "id": "e488a821-3b70-4284-96f3-ddee4dcb9d71", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "e488a821-3b70-4284-96f3-ddee4dcb9d71", "outputId": "add9af31-1cc6-496f-a6e4-3fb185c0de25" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "PredictionResult(example=tensor([20.]), inference=tensor([102.0095], grad_fn=))\n", "PredictionResult(example=tensor([40.]), inference=tensor([201.2056], grad_fn=))\n", @@ -483,16 +468,16 @@ "cell_type": "code", "execution_count": 51, "metadata": { - "id": "96f38a5a-4db0-4c39-8ce7-80d9f9911b48", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "96f38a5a-4db0-4c39-8ce7-80d9f9911b48", "outputId": "b1d689a2-9336-40b2-a984-538bec888cc9" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "input is 20.0 output is 102.00947570800781\n", "input is 40.0 output is 201.20559692382812\n", @@ -576,7 +561,7 @@ " yield (f\"key: {key}, input: {input_value.item()} output: {output_value.item()}\" )" ] }, - { + { "cell_type": "markdown", "metadata": { "id": "f22da313-5bf8-4334-865b-bbfafc374e63" @@ -592,7 +577,7 @@ "id": "c9b0fb49-d605-4f26-931a-57f42b0ad253" }, "source": [ - "#### Use BigQuery as the source", + "#### Use BigQuery as the source\n", "Follow these steps to 
use BigQuery as your source." ] }, @@ -627,47 +612,47 @@ }, { "cell_type": "code", - "source": [ - "!gcloud config set project $project" - ], + "execution_count": 54, "metadata": { - "id": "7mgnryX-Zlfs", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "7mgnryX-Zlfs", "outputId": "6e608e98-8369-45aa-c983-e62296202c52" }, - "execution_count": 54, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Updated property [core/project].\n" ] } + ], + "source": [ + "!gcloud config set project $project" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { - "id": "a6a984cd-2e92-4c44-821b-9bf1dd52fb7d", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "a6a984cd-2e92-4c44-821b-9bf1dd52fb7d", "outputId": "a50ab0fd-4f4e-4493-b506-41d3f7f08966" }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "" ] }, + "execution_count": 55, "metadata": {}, - "execution_count": 55 + "output_type": "execute_result" } ], "source": [ @@ -715,16 +700,16 @@ "cell_type": "code", "execution_count": 56, "metadata": { - "id": "34331897-23f5-4850-8974-67e522e956dc", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "34331897-23f5-4850-8974-67e522e956dc", "outputId": "9d2b0ba5-97a2-46bf-c9d3-e023afbd3122" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: third_question, input: 1000.0 output: 4962.61962890625\n", "key: second_question, input: 108.0 output: 538.472412109375\n", @@ -761,7 +746,7 @@ "id": "53ee7f24-5625-475a-b8cc-9c031591f304" }, "source": [ - "#### Use a CSV file as the source", + "#### Use a CSV file as the source\n", "Follow these steps to use a CSV file as your source." 
] }, @@ -776,6 +761,11 @@ }, { "cell_type": "code", + "execution_count": 62, + "metadata": { + "id": "exAZjP7cYAFv" + }, + "outputs": [], "source": [ "# creates a CSV file with the values.\n", "csv_values = [(\"first_question\", 105.00),\n", @@ -791,27 +781,22 @@ " writer.writerow(row)\n", "\n", "assert os.path.exists(input_csv_file) == True" - ], - "metadata": { - "id": "exAZjP7cYAFv" - }, - "execution_count": 62, - "outputs": [] + ] }, { "cell_type": "code", "execution_count": 66, "metadata": { - "id": "9a054c2d-4d84-4b37-b067-1dda5347e776", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "9a054c2d-4d84-4b37-b067-1dda5347e776", "outputId": "2f2ea8b7-b425-48ae-e857-fe214c7eced2" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: first_question, input: 105.0 output: 523.5929565429688\n", "key: second_question, input: 108.0 output: 538.472412109375\n", @@ -890,16 +875,16 @@ "cell_type": "code", "execution_count": 68, "metadata": { - "id": "629d070e-9902-42c9-a1e7-56c3d1864f13", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "629d070e-9902-42c9-a1e7-56c3d1864f13", "outputId": "0b4d7f3c-4696-422f-b031-ee5a03e90e03" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: third_question * 10, input: 1000.0 output: 9889.59765625\n", "key: second_question * 10, input: 108.0 output: 1075.4891357421875\n", @@ -966,16 +951,16 @@ "cell_type": "code", "execution_count": 69, "metadata": { - "id": "8db9d649-5549-4b58-a9ad-7b8592c2bcbf", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "8db9d649-5549-4b58-a9ad-7b8592c2bcbf", "outputId": "328ba32b-40d4-445b-8b4e-5568258b8a26" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: original input is `third_question tensor([1000.])`, input: 4962.61962890625 output: 49045.37890625\n", "key: original input is `second_question tensor([108.])`, 
input: 538.472412109375 output: 5329.11083984375\n", @@ -1015,5 +1000,20 @@ " inference_result | beam.Map(print)" ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/run_inference_sklearn.ipynb b/examples/notebooks/beam-ml/run_inference_sklearn.ipynb index cf896a18981a..1b76f76df292 100644 --- a/examples/notebooks/beam-ml/run_inference_sklearn.ipynb +++ b/examples/notebooks/beam-ml/run_inference_sklearn.ipynb @@ -1,22 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "C1rAsD2L-hSO" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -36,13 +27,7 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "C1rAsD2L-hSO" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -87,24 +72,20 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "zzwnMzzgdyPB" + }, "source": [ "## Before you begin\n", "Complete the following setup steps:\n", "1. Install dependencies for Apache Beam.\n", "1. Authenticate with Google Cloud.\n", "1. Specify your project and bucket. You use the project and bucket to save and load models." 
- ], - "metadata": { - "id": "zzwnMzzgdyPB" - } + ] }, { "cell_type": "code", - "source": [ - "!pip install google-api-core --quiet\n", - "!pip install google-cloud-pubsub google-cloud-bigquery-storage --quiet\n", - "!pip install apache-beam[gcp,dataframe] --quiet" - ], + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -112,8 +93,12 @@ "id": "6vlKcT-Wev20", "outputId": "336e8afc-6716-41dd-a438-500353189c62" }, - "execution_count": 1, - "outputs": [] + "outputs": [], + "source": [ + "!pip install google-api-core --quiet\n", + "!pip install google-cloud-pubsub google-cloud-bigquery-storage --quiet\n", + "!pip install apache-beam[gcp,dataframe] --quiet" + ] }, { "cell_type": "markdown", @@ -128,15 +113,15 @@ }, { "cell_type": "code", - "source": [ - "from google.colab import auth\n", - "auth.authenticate_user()" - ], + "execution_count": 2, "metadata": { "id": "V0E35R5Ka2cE" }, - "execution_count": 2, - "outputs": [] + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user()" + ] }, { "cell_type": "code", @@ -174,8 +159,8 @@ "import os\n", "\n", "# Constants\n", - "project = \"\"\n", - "bucket = \"\" \n", + "project = \"\" # @param {type:'string'}\n", + "bucket = \"\" # @param {type:'string'}\n", "\n", "# To avoid warnings, set the project.\n", "os.environ['GOOGLE_CLOUD_PROJECT'] = project\n" @@ -240,20 +225,18 @@ }, { "cell_type": "code", - "source": [ - "%pip install --upgrade google-cloud-bigquery --quiet" - ], + "execution_count": 9, "metadata": { "id": "AEGaqpMVqgRP" }, - "execution_count": 9, - "outputs": [] + "outputs": [], + "source": [ + "%pip install --upgrade google-cloud-bigquery --quiet" + ] }, { "cell_type": "code", - "source": [ - "!gcloud config set project $project" - ], + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -261,19 +244,41 @@ "id": "xq5AKtRrqlUx", "outputId": "fba8fb42-4958-451a-8aaa-9a838052a2f8" }, - "execution_count": 
10, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Updated property [core/project].\n" ] } + ], + "source": [ + "!gcloud config set project $project" ] }, { "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QCIjN__rpoVF", + "outputId": "0ded224f-2272-482e-80f5-bb2d21b6f5d8" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Populated BigQuery table\n", "\n", @@ -306,42 +311,22 @@ "\n", "create_job = client.query(query)\n", "create_job.result()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "QCIjN__rpoVF", - "outputId": "0ded224f-2272-482e-80f5-bb2d21b6f5d8" - }, - "execution_count": 22, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 22 - } ] }, { "cell_type": "code", "execution_count": 23, "metadata": { - "id": "50a648a3-794a-4286-ab2b-fc0458db04ca", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "50a648a3-794a-4286-ab2b-fc0458db04ca", "outputId": "8eab34b4-dcc7-4df1-ec0e-8c86a34d31c6" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "PredictionResult(example=[1000.0], inference=array([5000.]))\n", "PredictionResult(example=[1013.0], inference=array([5065.]))\n", @@ -388,16 +373,16 @@ "cell_type": "code", "execution_count": 25, "metadata": { - "id": "c212916d-b517-4589-ad15-a3a1df926fb3", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "c212916d-b517-4589-ad15-a3a1df926fb3", "outputId": "61db2d76-4dfa-4b38-cf9a-645790b4c5aa" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "('third_example', PredictionResult(example=[1000.0], inference=array([5000.])))\n", 
"('fourth_example', PredictionResult(example=[1013.0], inference=array([5065.])))\n", @@ -424,17 +409,41 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "JQ4zvlwsRK1W" + }, "source": [ "## Run multiple models\n", "\n", "This code creates a pipeline that takes two RunInference transforms with different models and then combines the output." - ], - "metadata": { - "id": "JQ4zvlwsRK1W" - } + ] }, { "cell_type": "code", + "execution_count": 86, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0qMlX6SeR68D", + "outputId": "5e4a0852-3761-47da-aa08-0386fd524a78" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "key = third_example * 10, example = 1000.0 -> predictions 10000.0\n", + "key = fourth_example * 10, example = 1013.0 -> predictions 10130.0\n", + "key = second_example * 10, example = 108.0 -> predictions 1080.0\n", + "key = first_example * 10, example = 105.0 -> predictions 1050.0\n", + "key = third_example * 5, example = 1000.0 -> predictions 5000.0\n", + "key = fourth_example * 5, example = 1013.0 -> predictions 5065.0\n", + "key = second_example * 5, example = 108.0 -> predictions 540.0\n", + "key = first_example * 5, example = 105.0 -> predictions 525.0\n" + ] + } + ], "source": [ "from typing import Tuple\n", "\n", @@ -464,31 +473,22 @@ " _ = ((five_times, ten_times) | \"Flattened\" >> beam.Flatten()\n", " | \"format output\" >> beam.Map(format_output)\n", " | \"Print\" >> beam.Map(print))\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0qMlX6SeR68D", - "outputId": "5e4a0852-3761-47da-aa08-0386fd524a78" - }, - "execution_count": 86, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "key = third_example * 10, example = 1000.0 -> predictions 10000.0\n", - "key = fourth_example * 10, example = 1013.0 -> predictions 10130.0\n", - "key = second_example * 10, example = 108.0 -> predictions 1080.0\n", - "key = 
first_example * 10, example = 105.0 -> predictions 1050.0\n", - "key = third_example * 5, example = 1000.0 -> predictions 5000.0\n", - "key = fourth_example * 5, example = 1013.0 -> predictions 5065.0\n", - "key = second_example * 5, example = 108.0 -> predictions 540.0\n", - "key = first_example * 5, example = 105.0 -> predictions 525.0\n" - ] - } ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb b/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb index ad5bb671cce2..c15e9b21ecf9 100644 --- a/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb +++ b/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb @@ -168,8 +168,8 @@ "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n", "from apache_beam.options.pipeline_options import PipelineOptions\n", "\n", - "project = \"PROJECT_ID\"\n", - "bucket = \"BUCKET_NAME\"\n", + "project = \"PROJECT_ID\" # @param {type:'string'}\n", + "bucket = \"BUCKET_NAME\" # @param {type:'string'}\n", "\n", "save_model_dir_multiply = f'gs://{bucket}/tf-inference/model/multiply_five/v1/'\n", "save_weights_dir_multiply = f'gs://{bucket}/tf-inference/weights/multiply_five/v1/'\n" diff --git a/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb b/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb index 9a9c6f5d6e92..2c2f6460651b 100644 --- a/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb +++ b/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb @@ -1,32 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [ - "X80jy3FqHjK4", - "40qtP6zJuMXm", - "YzvZWEv-1oiK", - "rIwD_qEpX7Gu", - 
"O_a0-4Gb19cy", - "G-sAu3cf31f3", - "r4dpR6dQ4JwX", - "P2UMmbNW4YQV" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "fFjof1NgAJwu" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -46,13 +27,7 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "fFjof1NgAJwu" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -74,6 +49,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "HrCtxslBGK8Z" + }, "source": [ "This notebook demonstrates how to use the Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform with TensorFlow and [TFX Basic Shared Libraries](https://github.com/tensorflow/tfx-bsl) (`tfx-bsl`).\n", "\n", @@ -89,69 +67,69 @@ "- Use the `tfx-bsl` model handler with the example data, and get a prediction inside an Apache Beam pipeline.\n", "\n", "For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation." - ], - "metadata": { - "id": "HrCtxslBGK8Z" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "HrCtxslBGK8A" + }, "source": [ "## Before you begin\n", "Set up your environment and download dependencies." 
- ], - "metadata": { - "id": "HrCtxslBGK8A" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "HrCtxslBGK8A" + }, "source": [ "### Import `tfx-bsl`\n", "First, import `tfx-bsl`.\n", "Creating a model handler is supported in `tfx-bsl` versions 1.10 and later." - ], - "metadata": { - "id": "HrCtxslBGK8A" - } + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "jBakpNZnAhqk" }, + "outputs": [], "source": [ "!pip install tfx_bsl==1.10.0 --quiet\n", "!pip install protobuf --quiet\n", "!pip install apache_beam --quiet" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "X80jy3FqHjK4" + }, "source": [ "### Authenticate with Google Cloud\n", "This notebook relies on saving your model to Google Cloud. To use your Google Cloud account, authenticate this notebook." - ], - "metadata": { - "id": "X80jy3FqHjK4" - } + ] }, { "cell_type": "code", + "execution_count": 2, "metadata": { "id": "Kz9sccyGBqz3" }, + "outputs": [], "source": [ "from google.colab import auth\n", "auth.authenticate_user()" - ], - "execution_count": 2, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "40qtP6zJuMXm" + }, "source": [ "### Import dependencies and set up your bucket\n", "Use the following code to import dependencies and to set up your Google Cloud Storage bucket.\n", @@ -159,16 +137,15 @@ "Replace `PROJECT_ID` and `BUCKET_NAME` with the ID of your project and the name of your bucket.\n", "\n", "**Important**: If an error occurs, restart your runtime." 
- ], - "metadata": { - "id": "40qtP6zJuMXm" - } + ] }, { "cell_type": "code", + "execution_count": 12, "metadata": { "id": "eEle839_Akqx" }, + "outputs": [], "source": [ "import argparse\n", "\n", @@ -190,24 +167,22 @@ "\n", "from apache_beam.options.pipeline_options import PipelineOptions\n", "\n", - "project = \"PROJECT_ID\"\n", - "bucket = \"BUCKET_NAME\"\n", + "project = \"PROJECT_ID\" # @param {type:'string'}\n", + "bucket = \"BUCKET_NAME\" # @param {type:'string'}\n", "\n", "save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n" - ], - "execution_count": 12, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "YzvZWEv-1oiK" + }, "source": [ "## Create and test a simple model\n", "\n", "This section creates and tests a model that predicts the 5 times multiplication table." - ], - "metadata": { - "id": "YzvZWEv-1oiK" - } + ] }, { "cell_type": "markdown", @@ -221,6 +196,7 @@ }, { "cell_type": "code", + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -228,26 +204,10 @@ "id": "SH7iq3zeBBJ-", "outputId": "c5adb7ec-285b-401e-f9be-1e9b83c6d0ba" }, - "source": [ - "# Create training data that represents the 5 times multiplication table for the numbers 0 to 99.\n", - "# x is the data and y is the labels.\n", - "x = numpy.arange(0, 100) # Examples\n", - "y = x * 5 # Labels\n", - "\n", - "# Build a simple linear regression model.\n", - "# Note that the model has a shape of (1) for its input layer and expects a single int64 value.\n", - "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n", - "output_layer= keras.layers.Dense(1)(input_layer)\n", - "\n", - "model = keras.Model(input_layer, output_layer)\n", - "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n", - "model.summary()" - ], - "execution_count": 4, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Model: \"model\"\n", 
"_________________________________________________________________\n", @@ -264,21 +224,37 @@ "_________________________________________________________________\n" ] } + ], + "source": [ + "# Create training data that represents the 5 times multiplication table for the numbers 0 to 99.\n", + "# x is the data and y is the labels.\n", + "x = numpy.arange(0, 100) # Examples\n", + "y = x * 5 # Labels\n", + "\n", + "# Build a simple linear regression model.\n", + "# Note that the model has a shape of (1) for its input layer and expects a single int64 value.\n", + "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n", + "output_layer= keras.layers.Dense(1)(input_layer)\n", + "\n", + "model = keras.Model(input_layer, output_layer)\n", + "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n", + "model.summary()" ] }, { "cell_type": "markdown", + "metadata": { + "id": "O_a0-4Gb19cy" + }, "source": [ "### Test the model\n", "\n", "This step tests the model that you created." 
- ], - "metadata": { - "id": "O_a0-4Gb19cy" - } + ] }, { "cell_type": "code", + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -286,20 +262,10 @@ "id": "5XkIYXhJBFmS", "outputId": "e3bb5079-5cb8-4fe4-eb8d-d3d13d5f9f0c" }, - "source": [ - "model.fit(x, y, epochs=500, verbose=0)\n", - "test_examples =[20, 40, 60, 90]\n", - "value_to_predict = numpy.array(test_examples, dtype=numpy.float32)\n", - "predictions = model.predict(value_to_predict)\n", - "\n", - "print('Test Examples ' + str(test_examples))\n", - "print('Predictions ' + str(predictions))" - ], - "execution_count": 6, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "1/1 [==============================] - 0s 94ms/step\n", "Test Examples [20, 40, 60, 90]\n", @@ -309,10 +275,22 @@ " [34.41496 ]]\n" ] } + ], + "source": [ + "model.fit(x, y, epochs=500, verbose=0)\n", + "test_examples =[20, 40, 60, 90]\n", + "value_to_predict = numpy.array(test_examples, dtype=numpy.float32)\n", + "predictions = model.predict(value_to_predict)\n", + "\n", + "print('Test Examples ' + str(test_examples))\n", + "print('Predictions ' + str(predictions))" ] }, { "cell_type": "markdown", + "metadata": { + "id": "dEmleqiH3t71" + }, "source": [ "## RunInference with Tensorflow using `tfx-bsl`\n", "In versions 1.10.0 and later of `tfx-bsl`, you can\n", @@ -321,16 +299,15 @@ "### Populate the data in a TensorFlow proto\n", "\n", "Tensorflow data uses protos. If you are loading from a file, helpers exist for this step. Because this example uses generated data, this code populates a proto." 
- ], - "metadata": { - "id": "dEmleqiH3t71" - } + ] }, { "cell_type": "code", + "execution_count": 7, "metadata": { "id": "XvKc9kQilPjx" }, + "outputs": [], "source": [ "# This example shows a proto that converts the samples and labels into\n", "# tensors usable by TensorFlow.\n", @@ -371,23 +348,22 @@ " for i in value_to_predict:\n", " example = ExampleProcessor().create_example(feature=i)\n", " writer.write(example.SerializeToString())" - ], - "execution_count": 7, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "G-sAu3cf31f3" + }, "source": [ "### Fit the model\n", "\n", "This step builds a model. Because RunInference requires pretrained models, this segment builds a usable model." - ], - "metadata": { - "id": "G-sAu3cf31f3" - } + ] }, { "cell_type": "code", + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -395,6 +371,18 @@ "id": "AnbrxXPKeAOQ", "outputId": "42439aac-3a10-4e86-829f-44332aad6173" }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "RAW_DATA_TRAIN_SPEC = {\n", "'x': tf.io.FixedLenFeature([], tf.float32),\n", @@ -408,37 +396,26 @@ "dataset = dataset.repeat()\n", "\n", "model.fit(dataset, epochs=5000, steps_per_epoch=1, verbose=0)" - ], - "execution_count": 8, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 8 - } ] }, { "cell_type": "markdown", + "metadata": { + "id": "r4dpR6dQ4JwX" + }, "source": [ "### Save the model\n", "\n", "This step shows how to save your model." 
- ], - "metadata": { - "id": "r4dpR6dQ4JwX" - } + ] }, { "cell_type": "code", + "execution_count": 9, "metadata": { "id": "fYvrIYO3qiJx" }, + "outputs": [], "source": [ "RAW_DATA_PREDICT_SPEC = {\n", "'x': tf.io.FixedLenFeature([], tf.float32),\n", @@ -461,25 +438,24 @@ "# programs that consume SavedModels, such as serving APIs.\n", "# See https://www.tensorflow.org/api_docs/python/tf/saved_model/save\n", "tf.keras.models.save_model(model, save_model_dir_multiply, signatures=signature)" - ], - "execution_count": 9, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "P2UMmbNW4YQV" + }, "source": [ "## Run the pipeline\n", "Use the following code to run the pipeline.\n", "\n", "* `FormatOutput` demonstrates how to extract values from the output protos.\n", "* `CreateModelHandler` demonstrates the model handler that needs to be passed into the Apache Beam RunInference API." - ], - "metadata": { - "id": "P2UMmbNW4YQV" - } + ] }, { "cell_type": "code", + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -488,72 +464,24 @@ "id": "PzjmXM_KvqHY", "outputId": "0aa60bef-52a0-4ce2-d228-3fac977d59e0" }, - "source": [ - "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n", - "\n", - "class FormatOutput(beam.DoFn):\n", - " def process(self, element: prediction_log_pb2.PredictionLog):\n", - " predict_log = element.predict_log\n", - " input_value = tf.train.Example.FromString(predict_log.request.inputs['examples'].string_val[0])\n", - " input_float_value = input_value.features.feature['x'].float_list.value[0]\n", - " output_value = predict_log.response.outputs\n", - " output_float_value = output_value['output_0'].float_val[0]\n", - " yield (f\"example is {input_float_value:.2f} prediction is {output_float_value:.2f}\")\n", - "\n", - "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n", - "saved_model_spec = 
model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n", - "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n", - "model_handler = CreateModelHandler(inference_spec_type)\n", - "with beam.Pipeline() as p:\n", - " _ = (p | tfexample_beam_record.RawRecordBeamSource()\n", - " | RunInference(model_handler)\n", - " | beam.ParDo(FormatOutput())\n", - " | beam.Map(print)\n", - " )" - ], - "execution_count": 10, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n" ] }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:tensorflow:From /usr/local/lib/python3.9/dist-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", @@ -562,8 +490,8 @@ ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "example is 20.00 prediction is 104.36\n", "example is 40.00 prediction is 202.62\n", @@ -571,10 +499,36 @@ "example is 90.00 prediction is 448.26\n" ] } + ], + "source": [ + "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n", + "\n", + "class FormatOutput(beam.DoFn):\n", + " def process(self, element: prediction_log_pb2.PredictionLog):\n", + " predict_log = element.predict_log\n", + " input_value = tf.train.Example.FromString(predict_log.request.inputs['examples'].string_val[0])\n", + " input_float_value = input_value.features.feature['x'].float_list.value[0]\n", + " output_value = predict_log.response.outputs\n", + " output_float_value = output_value['output_0'].float_val[0]\n", + " yield (f\"example is {input_float_value:.2f} 
prediction is {output_float_value:.2f}\")\n", + "\n", + "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n", + "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n", + "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n", + "model_handler = CreateModelHandler(inference_spec_type)\n", + "with beam.Pipeline() as p:\n", + " _ = (p | tfexample_beam_record.RawRecordBeamSource()\n", + " | RunInference(model_handler)\n", + " | beam.ParDo(FormatOutput())\n", + " | beam.Map(print)\n", + " )" ] }, { "cell_type": "markdown", + "metadata": { + "id": "IXikjkGdHm9n" + }, "source": [ "## Use `KeyedModelHandler` with `tfx-bsl`\n", "\n", @@ -584,13 +538,30 @@ "* If you don't know whether keys are associated with your examples, use `beam.MaybeKeyedModelHandler`.\n", "\n", "In addition to demonstrating how to use a keyed model handler, this step demonstrates how to use `tfx-bsl` examples." 
- ], - "metadata": { - "id": "IXikjkGdHm9n" - } + ] }, { "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KPtE3fmdJQry", + "outputId": "c33558fc-fb12-4c20-b828-b5520721f279" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "key 5.0 : example is 5.00 prediction is 30.67\n", + "key 50.0 : example is 50.00 prediction is 251.75\n", + "key 40.0 : example is 40.00 prediction is 202.62\n", + "key 100.0 : example is 100.00 prediction is 497.38\n" + ] + } + ], "source": [ "from apache_beam.ml.inference.base import KeyedModelHandler\n", "from google.protobuf import text_format\n", @@ -632,27 +603,32 @@ " | beam.ParDo(FormatOutputKeyed())\n", " | beam.Map(print)\n", " )" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KPtE3fmdJQry", - "outputId": "c33558fc-fb12-4c20-b828-b5520721f279" - }, - "execution_count": 11, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "key 5.0 : example is 5.00 prediction is 30.67\n", - "key 50.0 : example is 50.00 prediction is 251.75\n", - "key 40.0 : example is 40.00 prediction is 202.62\n", - "key 100.0 : example is 100.00 prediction is 497.38\n" - ] - } ] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "X80jy3FqHjK4", + "40qtP6zJuMXm", + "YzvZWEv-1oiK", + "rIwD_qEpX7Gu", + "O_a0-4Gb19cy", + "G-sAu3cf31f3", + "r4dpR6dQ4JwX", + "P2UMmbNW4YQV" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb b/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb index 46bfc0f2fc00..2ab45e0491a7 100644 --- a/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb +++ 
b/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb @@ -151,6 +151,17 @@ "Replace `PROJECT_ID`, `LOCATION_NAME`, and `ENDPOINT_ID` with the ID of your project, the GCP region where your model is deployed, and the ID of your Vertex AI endpoint." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "LOCATION_NAME = \"\" # @param {type:'string'}\n", + "ENDPOINT_ID = \"> beam.Create([IMG_URL])\n", " | beam.Map(lambda img_name: (img_name, download_image(img_name)))\n", diff --git a/examples/notebooks/beam-ml/run_inference_vllm.ipynb b/examples/notebooks/beam-ml/run_inference_vllm.ipynb new file mode 100644 index 000000000000..40eff1af5155 --- /dev/null +++ b/examples/notebooks/beam-ml/run_inference_vllm.ipynb @@ -0,0 +1,645 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OsFaZscKSPvo" + }, + "outputs": [], + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. 
See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NrHRIznKp3nS" + }, + "source": [ + "# Run ML inference by using vLLM on GPUs\n", + "\n", + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H0ZFs9rDvtJm" + }, + "source": [ + "[vLLM](https://github.com/vllm-project/vllm) is a fast and user-friendly library for LLM inference and serving. vLLM optimizes LLM inference with mechanisms like PagedAttention for memory management and continuous batching for increasing throughput. For popular models, vLLM has been shown to increase throughput by a multiple of 2 to 4. With Apache Beam, you can serve models with vLLM and scale that serving with just a few lines of code.\n", + "\n", + "This notebook demonstrates how to run machine learning inference by using vLLM and GPUs in three ways:\n", + "\n", + "* locally without Apache Beam\n", + "* locally with the Apache Beam local runner\n", + "* remotely with the Dataflow runner\n", + "\n", + "It also shows how to swap in a different model without modifying your pipeline structure by changing the configuration." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6x41tnbTvQM1" + }, + "source": [ + "## Requirements\n", + "\n", + "This notebook assumes that a GPU is enabled in Colab. If this setting isn't enabled, the locally executed sections of this notebook might not work. To enable a GPU, in the Colab menu, click **Runtime** > **Change runtime type**. For **Hardware accelerator**, choose a GPU accelerator. If you can't access a GPU in Colab, you can run the Dataflow section of this notebook.\n", + "\n", + "To run the Dataflow section, you need access to the following resources:\n", + "\n", + "- a computer with Docker installed\n", + "- a [Google Cloud](https://cloud.google.com/) account" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8PSjyDIavRcn" + }, + "source": [ + "## Install dependencies\n", + "\n", + "Before creating your pipeline, download and install the dependencies required to develop with Apache Beam and vLLM. vLLM is supported in Apache Beam versions 2.60.0 and later." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "irCKNe42p22r" + }, + "outputs": [], + "source": [ + "!pip install openai>=1.52.2\n", + "!pip install vllm>=0.6.3\n", + "!pip install triton>=3.1.0\n", + "!pip install apache-beam[gcp]==2.61.0\n", + "!pip install nest_asyncio # only needed in colab\n", + "!pip check" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Colab only: allow nested asyncio\n", + "\n", + "The vLLM model handler logic below uses asyncio to feed vLLM records. This only works if we are not already in an asyncio event loop. Most of the time, this is fine, but colab already operates in an event loop. To work around this, we can use nest_asyncio to make things work smoothly in colab. Do not include this step outside of colab." + ], + "metadata": { + "id": "3xz8zuA7vcS3" + } + }, + { + "cell_type": "code", + "source": [ + "# This should not be necessary outside of colab.\n", + "import nest_asyncio\n", + "nest_asyncio.apply()\n" + ], + "metadata": { + "id": "sUqjOzw3wpI3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sUqjOzw3wpI4" + }, + "source": [ + "## Run locally without Apache Beam\n", + "\n", + "In this section, you run a vLLM server without using Apache Beam. Use the `facebook/opt-125m` model. This model is small enough to fit in Colab memory and doesn't require any extra authentication.\n", + "\n", + "First, start the vLLM server. This step might take a minute or two, because the model needs to download before vLLM starts running inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GbJGzINNt5sG" + }, + "outputs": [], + "source": [ + "! 
python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n35LXTS3uzIC" + }, + "source": [ + "Next, while the vLLM server is running, open a separate terminal to communicate with the vLLM serving process. To open a terminal in Colab, in the sidebar, click **Terminal**. In the terminal, run the following commands.\n", + "\n", + "```\n", + "pip install openai\n", + "python\n", + "\n", + "from openai import OpenAI\n", + "\n", + "# Modify OpenAI's API key and API base to use vLLM's API server.\n", + "openai_api_key = \"EMPTY\"\n", + "openai_api_base = \"http://localhost:8000/v1\"\n", + "client = OpenAI(\n", + " api_key=openai_api_key,\n", + " base_url=openai_api_base,\n", + ")\n", + "completion = client.completions.create(model=\"facebook/opt-125m\",\n", + " prompt=\"San Francisco is a\")\n", + "print(\"Completion result:\", completion)\n", + "```\n", + "\n", + "This code runs against the server running in the cell. You can experiment with different prompts." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hbxi83BfwbBa" + }, + "source": [ + "## Run locally with Apache Beam\n", + "\n", + "In this section, you set up an Apache Beam pipeline to run a job with an embedded vLLM instance.\n", + "\n", + "First, define the `VLLMCompletionsModelHandler` object. This configuration object gives Apache Beam the information that it needs to create a dedicated vLLM process in the middle of the pipeline. Apache Beam then provides examples to the pipeline. No additional code is needed."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sUqjOzw3wpI4" + }, + "outputs": [], + "source": [ + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n", + "from apache_beam.ml.inference.base import PredictionResult\n", + "import apache_beam as beam\n", + "\n", + "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N06lXRKRxCz5" + }, + "source": [ + "Next, define examples to run inference against, and define a helper function to print out the inference results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3a1PznmtxNR_" + }, + "outputs": [], + "source": [ + "class FormatOutput(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n", + "\n", + "prompts = [\n", + " \"Hello, my name is\",\n", + " \"The president of the United States is\",\n", + " \"The capital of France is\",\n", + " \"The future of AI is\",\n", + " \"Emperor penguins are\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Njl0QfrLxQ0m" + }, + "source": [ + "Finally, run the pipeline.\n", + "\n", + "This step might take a minute or two, because the model needs to download before Apache Beam can start running inference." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9yXbzV0ZmZcJ" + }, + "outputs": [], + "source": [ + "with beam.Pipeline() as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(print) # Print the formatted output.\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jv7be6Pk9Hlx" + }, + "source": [ + "## Run remotely on Dataflow\n", + "\n", + "After you validate that the pipeline can run against vLLM locally, you can productionalize the workflow on a remote runner. This notebook runs the pipeline on the Dataflow runner." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J1LMrl1Yy6QB" + }, + "source": [ + "### Build a Docker image\n", + "\n", + "To run a pipeline with vLLM on Dataflow, you must create a Docker image that contains your dependencies and is compatible with a GPU runtime. For more information about building GPU compatible Dataflow containers, see [Build a custom container image](https://cloud.google.com/dataflow/docs/gpu/use-gpus#custom-container) in the Dataflow documentation.\n", + "\n", + "First, define and save your Dockerfile. This file uses an Nvidia GPU-compatible base image. In the Dockerfile, install the Python dependencies needed to run the job.\n", + "\n", + "Before proceeding, make sure that your configuration meets the following requirements:\n", + "\n", + "- The Python version in the following cell matches the Python version defined in the Dockerfile.\n", + "- The Apache Beam version defined in your dependencies matches the Apache Beam version defined in the Dockerfile."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jCQ6-D55gqfl" + }, + "outputs": [], + "source": [ + "!python --version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7QyNq_gygHLO" + }, + "outputs": [], + "source": [ + "cell_str='''\n", + "FROM nvidia/cuda:12.4.1-devel-ubuntu22.04\n", + "\n", + "RUN apt update\n", + "RUN apt install software-properties-common -y\n", + "RUN add-apt-repository ppa:deadsnakes/ppa\n", + "RUN apt update\n", + "RUN apt-get update\n", + "\n", + "ARG DEBIAN_FRONTEND=noninteractive\n", + "\n", + "RUN apt install python3.10-full -y\n", + "# RUN apt install python3.10-venv -y\n", + "# RUN apt install python3.10-dev -y\n", + "RUN rm /usr/bin/python3\n", + "RUN ln -s python3.10 /usr/bin/python3\n", + "RUN python3 --version\n", + "RUN apt-get install -y curl\n", + "RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && pip install --upgrade pip\n", + "\n", + "# Copy the Apache Beam worker dependencies from the Beam Python 3.10 SDK image.\n", + "COPY --from=apache/beam_python3.10_sdk:2.60.0 /opt/apache/beam /opt/apache/beam\n", + "\n", + "RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.60.0\n", + "RUN pip install openai>=1.52.2 vllm>=0.6.3 triton>=3.1.0\n", + "\n", + "RUN apt install libcairo2-dev pkg-config python3-dev -y\n", + "RUN pip install pycairo\n", + "\n", + "# Set the entrypoint to Apache Beam SDK worker launcher.\n", + "ENTRYPOINT [ \"/opt/apache/beam/boot\" ]\n", + "'''\n", + "\n", + "with open('VllmDockerfile', 'w') as f:\n", + " f.write(cell_str)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zWma0YetiEn5" + }, + "source": [ + "After you save the Dockerfile, build and push your Docker image. Because Docker is not accessible from Colab, you need to complete this step in a separate environment.\n", + "\n", + "1. In the sidebar, click **Files** to open the **Files** pane.\n", + "2. 
In an environment with Docker installed, download the **VllmDockerfile** file to an empty folder.\n", + "3. Run the following commands. Replace `<REPOSITORY_NAME>:<IMAGE_TAG>` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository and tag.\n", + "\n", + " ```\n", + " docker build -t \"<REPOSITORY_NAME>:<IMAGE_TAG>\" -f VllmDockerfile ./\n", + " docker image push \"<REPOSITORY_NAME>:<IMAGE_TAG>\"\n", + " ```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NjZyRjte0g0Q" + }, + "source": [ + "### Define and run the pipeline\n", + "\n", + "When you have a working Docker image, define and run your pipeline.\n", + "\n", + "First, define the pipeline options that you want to use to launch the Dataflow job. Before running the next cell, replace the following variables:\n", + "\n", + "- `<BUCKET_NAME>`: the name of a valid [Google Cloud Storage](https://cloud.google.com/storage?e=48754805&hl=en) bucket. Don't include a `gs://` prefix or trailing slashes.\n", + "- `<REPOSITORY_NAME>`: the name of the Google Artifact Registry repository that you used in the previous step. \n", + "- `<IMAGE_TAG>`: image tag used in the previous step. Prefer a versioned tag or SHA instead of the `:latest` tag or other mutable tags.\n", + "- `<PROJECT_NAME>`: the name of the Google Cloud project that you created your bucket and Artifact Registry repository in.\n", + "\n", + "This workflow uses the following Dataflow service option: `worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx`. When you use this service option, Dataflow installs a T4 GPU that uses a `5xx` series Nvidia driver on each worker machine. The 5xx driver is required to run vLLM jobs."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kXy9FRYVCSjq" + }, + "outputs": [], + "source": [ + "\n", + "from apache_beam.options.pipeline_options import GoogleCloudOptions\n", + "from apache_beam.options.pipeline_options import PipelineOptions\n", + "from apache_beam.options.pipeline_options import SetupOptions\n", + "from apache_beam.options.pipeline_options import StandardOptions\n", + "from apache_beam.options.pipeline_options import WorkerOptions\n", + "\n", + "\n", + "options = PipelineOptions()\n", + "\n", + "# Replace with your bucket name.\n", + "BUCKET_NAME = '' # @param {type:'string'}\n", + "# Replace with the image repository and tag from the previous step.\n", + "CONTAINER_IMAGE = ':' # @param {type:'string'}\n", + "# Replace with your GCP project\n", + "PROJECT_NAME = '' # @param {type:'string'}\n", + "\n", + "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n", + "\n", + "# Provide required pipeline options for the Dataflow Runner.\n", + "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n", + "\n", + "# Set the Google Cloud region that you want to run Dataflow in.\n", + "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", + "\n", + "# IMPORTANT: Replace BUCKET_NAME with the name of your Cloud Storage bucket.\n", + "dataflow_gcs_location = \"gs://%s/dataflow\" % BUCKET_NAME\n", + "\n", + "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", + "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", + "\n", + "\n", + "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", + "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", + "\n", + "# The Dataflow temp location. 
This location is used to store temporary files or intermediate results before outputting to the sink.\n", + "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n", + "\n", + "# Enable GPU runtime. Make sure to enable 5xx driver since vLLM only works with 5xx drivers, not 4xx\n", + "options.view_as(GoogleCloudOptions).dataflow_service_options = [\"worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx\"]\n", + "\n", + "options.view_as(SetupOptions).save_main_session = True\n", + "\n", + "# Choose a machine type compatible with GPU type\n", + "options.view_as(WorkerOptions).machine_type = \"n1-standard-4\"\n", + "\n", + "options.view_as(WorkerOptions).sdk_container_image = CONTAINER_IMAGE" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xPhe597P1-QJ" + }, + "source": [ + "Next, authenticate Colab so that it can submit a job on your behalf." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Xkf6yIVlFB8-" + }, + "outputs": [], + "source": [ + "def auth_to_colab():\n", + " from google.colab import auth\n", + " auth.authenticate_user()\n", + "\n", + "auth_to_colab()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MJtEI6Ux2eza" + }, + "source": [ + "Finally, run the pipeline on Dataflow. The pipeline definition is almost exactly the same as the definition used for local execution. The pipeline options are the only change to the pipeline.\n", + "\n", + "The following code creates a Dataflow job in your project. You can view the results in Colab or in the Google Cloud console. Creating a Dataflow job and downloading the model might take a few minutes. After the job starts performing inference, it quickly runs through the inputs."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8gjDdru_9Dii" + }, + "outputs": [], + "source": [ + "import logging\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n", + "from apache_beam.ml.inference.base import PredictionResult\n", + "import apache_beam as beam\n", + "\n", + "class FormatOutput(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n", + "\n", + "logging.getLogger().setLevel(logging.INFO) # Output additional Dataflow Job metadata and launch logs. \n", + "prompts = [\n", + " \"Hello, my name is\",\n", + " \"The president of the United States is\",\n", + " \"The capital of France is\",\n", + " \"The future of AI is\",\n", + " \"John cena is\",\n", + "]\n", + "\n", + "# Specify the model handler, providing a path and the custom inference function.\n", + "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')\n", + "\n", + "with beam.Pipeline(options=options) as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(logging.info) # Print the formatted output.\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "22cEHPCc28fH" + }, + "source": [ + "## Run vLLM with a Gemma model\n", + "\n", + "After you configure your pipeline, switching the model used by the pipeline is relatively straightforward. You can run the same pipeline, but switch the model name defined in the model handler. 
This example runs the pipeline created previously but uses a [Gemma](https://ai.google.dev/gemma) model.\n", + "\n", + "Before you start, sign in to HuggingFace, and make sure that you can access the Gemma models. To access Gemma models, you must accept the terms and conditions.\n", + "\n", + "1. Navigate to the [Gemma Model Card](https://huggingface.co/google/gemma-2b).\n", + "2. Sign in, or sign up for a free HuggingFace account.\n", + "3. Follow the prompts to agree to the conditions\n", + "\n", + "When you complete these steps, the following message appears on the model card page: `You have been granted access to this model`.\n", + "\n", + "Next, sign in to your account from this notebook by running the following code and then following the prompts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JHwIsFI9kd9j" + }, + "outputs": [], + "source": [ + "! huggingface-cli login" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IjX2If8rnCol" + }, + "source": [ + "Verify that the notebook can now access the Gemma model. Run the following code, which starts a vLLM server to serve the Gemma 2b model. Because the default T4 Colab runtime doesn't support the full data type precision needed to run Gemma models, the `--dtype=half` parameter is required.\n", + "\n", + "When successful, the following cell runs indefinitely. After it starts the server process, you can shut it down. When the server process starts, the Gemma 2b model is successfully downloaded, and the server is ready to serve traffic." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LH_oCFWMiwFs" + }, + "outputs": [], + "source": [ + "! python -m vllm.entrypoints.openai.api_server --model google/gemma-2b --dtype=half" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "31BmdDUAn-SW" + }, + "source": [ + "To run the pipeline in Apache Beam, run the following code. 
Update the `VLLMCompletionsModelHandler` object with the new parameters, which match the command from the previous cell. Reuse all of the pipeline logic from the previous pipelines." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DyC2ikXg237p" + }, + "outputs": [], + "source": [ + "model_handler = VLLMCompletionsModelHandler('google/gemma-2b', vllm_server_kwargs={'dtype': 'half'})\n", + "\n", + "with beam.Pipeline() as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(print) # Print the formatted output.\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "C6OYfub6ovFK" + }, + "source": [ + "### Run Gemma on Dataflow\n", + "\n", + "As a next step, run this pipeline on Dataflow. Follow the same steps described in the \"Run remotely on Dataflow\" section of this page:\n", + "\n", + "1. Construct a Dockerfile and push a new Docker image. You can use the same Dockerfile that you created previously, but you need to add a step to set your HuggingFace authentication key. In your Dockerfile, add the following line before the entrypoint:\n", + "\n", + " ```\n", + " RUN python3 -c 'from huggingface_hub import HfFolder; HfFolder.save_token(\"\")'\n", + " ```\n", + "\n", + "2. Set pipeline options. You can reuse the options defined in this notebook. Replace the Docker image location with your new Docker image.\n", + "3. Run the pipeline. 
Copy the pipeline that you ran on Dataflow, and replace the pipeline options with the pipeline options that you just defined.\n", + "\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb b/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb index ebfcca34b94c..03feb96cbf68 100644 --- a/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb +++ b/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb @@ -197,8 +197,8 @@ }, "outputs": [], "source": [ - "PROJECT_ID = \"\"\n", - "LOCATION = \"\"" + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "LOCATION = \"\" # @param {type:'string'}" ] }, { @@ -1790,10 +1790,10 @@ "outputs": [], "source": [ "# Replace with the name of your Pub/Sub topic.\n", - "TOPIC = \" \"\n", + "TOPIC = \"\" # @param {type:'string'}\n", "\n", "# Replace with the subscription path for your topic.\n", - "SUBSCRIPTION = \"\"" + "SUBSCRIPTION = \"\" # @param {type:'string'}" ] }, { diff --git a/examples/notebooks/healthcare/beam_nlp.ipynb b/examples/notebooks/healthcare/beam_nlp.ipynb index c2061bc4d75f..bbcbb6254024 100644 --- a/examples/notebooks/healthcare/beam_nlp.ipynb +++ b/examples/notebooks/healthcare/beam_nlp.ipynb @@ -1,25 +1,10 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" @@ -27,6 +12,12 @@ }, { 
"cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "lBuUTzxD2mvJ" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -46,16 +37,13 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "id": "lBuUTzxD2mvJ", - "cellView": "form" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "nEUAYCTx4Ijj" + }, "source": [ "# **Natural Language Processing Pipeline**\n", "\n", @@ -70,101 +58,103 @@ "For details about Apache Beam pipelines, including PTransforms and PCollections, visit the [Beam Programming Guide](https://beam.apache.org/documentation/programming-guide/).\n", "\n", "You'll be able to use this notebook to explore the data in each PCollection." - ], - "metadata": { - "id": "nEUAYCTx4Ijj" - } + ] }, { "cell_type": "markdown", - "source": [ - "First, lets install the necessary packages." - ], "metadata": { "id": "ZLBB0PTG5CHw" - } + }, + "source": [ + "First, lets install the necessary packages." + ] }, { "cell_type": "code", - "source": [ - "!pip install apache-beam[gcp]" - ], + "execution_count": null, "metadata": { "id": "O7hq2sse8K4u" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip install apache-beam[gcp]" + ] }, { "cell_type": "markdown", - "source": [ - " **GCP Setup**" - ], "metadata": { "id": "5vQDhIv0E-LR" - } + }, + "source": [ + " **GCP Setup**" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "DGYiBYfxsSCw" + }, "source": [ "1. Authenticate your notebook by `gcloud auth application-default login` in the Colab terminal.\n", "\n", "2. 
Run `gcloud config set project `" - ], - "metadata": { - "id": "DGYiBYfxsSCw" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "D7lJqW2PRFcN" + }, "source": [ "Set the variables in the next cell based upon your project and preferences. The files referred to in this notebook nlpsample*.csv are in the format with one\n", "blurb of clinical note.\n", "\n", "Note that below, **us-central1** is hardcoded as the location. This is because of the limited number of [locations](https://cloud.google.com/healthcare-api/docs/how-tos/nlp) the API currently supports." - ], - "metadata": { - "id": "D7lJqW2PRFcN" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s9lhe5CZ5F3o" + }, + "outputs": [], "source": [ - "DATASET=\"\"\n", - "TEMP_LOCATION=\"\"\n", - "PROJECT=''\n", + "DATASET=\"\" # @param {type:'string'}\n", + "TEMP_LOCATION=\"\" # @param {type:'string'}\n", + "PROJECT=''# @param {type:'string'}\n", "LOCATION='us-central1'\n", "URL=f'https://healthcare.googleapis.com/v1/projects/{PROJECT}/locations/{LOCATION}/services/nlp:analyzeEntities'\n", "NLP_SERVICE=f'projects/{PROJECT}/locations/{LOCATION}/services/nlp'" - ], - "metadata": { - "id": "s9lhe5CZ5F3o" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Then, download [this raw CSV file](https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading." - ], "metadata": { "id": "1IArtEm8QuCR" - } + }, + "source": [ + "Then, download [this raw CSV file](https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading." 
+ ] }, { "cell_type": "markdown", + "metadata": { + "id": "DI_Qkyn75LO-" + }, "source": [ "**BigQuery Setup**\n", "\n", "We will be using BigQuery to warehouse the structured data revealed in the output of the Healthcare NLP API. For this purpose, we create 3 tables to organize the data. Specifically, these will be table entities, table relations, and table entity mentions, which are all outputs of interest from the Healthcare NLP API." - ], - "metadata": { - "id": "DI_Qkyn75LO-" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bZDqtFVE5Wd_" + }, + "outputs": [], "source": [ "from google.cloud import bigquery\n", "\n", @@ -198,15 +188,15 @@ "print(\n", " \"Created table {}.{}.{}\".format(table.project, table.dataset_id, table.table_id)\n", ")" - ], - "metadata": { - "id": "bZDqtFVE5Wd_" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YK-G7uV5APuP" + }, + "outputs": [], "source": [ "from google.cloud import bigquery\n", "\n", @@ -240,15 +230,15 @@ ")\n", "\n", "\n" - ], - "metadata": { - "id": "YK-G7uV5APuP" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R9IHgZKoAQWj" + }, + "outputs": [], "source": [ "from google.cloud import bigquery\n", "\n", @@ -324,26 +314,26 @@ "print(\n", " \"Created table {}.{}.{}\".format(table.project, table.dataset_id, table.table_id)\n", ")" - ], - "metadata": { - "id": "R9IHgZKoAQWj" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "jc_iS_BP5aS4" + }, "source": [ "**Pipeline Setup**\n", "\n", "We will use InteractiveRunner in this notebook." 
- ], - "metadata": { - "id": "jc_iS_BP5aS4" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "07ct6kf55ihP" + }, + "outputs": [], "source": [ "# Python's regular expression library\n", "import re\n", @@ -365,24 +355,24 @@ " job_name=\"my-healthcare-nlp-job\",\n", " temp_location=TEMP_LOCATION,\n", " region=LOCATION)" - ], - "metadata": { - "id": "07ct6kf55ihP" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "The following defines a `PTransform` named `ReadLinesFromText`, that extracts lines from a file." - ], "metadata": { "id": "dO1A9_WK5lb4" - } + }, + "source": [ + "The following defines a `PTransform` named `ReadLinesFromText`, that extracts lines from a file." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t5iDRKMK5n_B" + }, + "outputs": [], "source": [ "class ReadLinesFromText(beam.PTransform):\n", "\n", @@ -392,74 +382,73 @@ " def expand(self, pcoll):\n", " return (pcoll.pipeline\n", " | beam.io.ReadFromText(self._file_pattern))" - ], - "metadata": { - "id": "t5iDRKMK5n_B" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "The following sets up an Apache Beam pipeline with the *Interactive Runner*. The *Interactive Runner* is the runner suitable for running in notebooks. A runner is an execution engine for Apache Beam pipelines." - ], "metadata": { "id": "HI_HVB185sMQ" - } + }, + "source": [ + "The following sets up an Apache Beam pipeline with the *Interactive Runner*. The *Interactive Runner* is the runner suitable for running in notebooks. A runner is an execution engine for Apache Beam pipelines." 
+ ] }, { "cell_type": "code", - "source": [ - "p = beam.Pipeline(options = options)" - ], + "execution_count": null, "metadata": { "id": "7osCZ1om5ql0" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "p = beam.Pipeline(options = options)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "EaF8NfC_521y" + }, "source": [ "The following sets up a PTransform that extracts words from a Google Cloud Storage file that contains lines with each line containing a In our example, each line is a medical notes excerpt that will be passed through the Healthcare NLP API\n", "\n", "**\"|\"** is an overloaded operator that applies a PTransform to a PCollection to produce a new PCollection. Together with |, >> allows you to optionally name a PTransform.\n", "\n", "Usage:[PCollection] | [PTransform], **or** [PCollection] | [name] >> [PTransform]" - ], - "metadata": { - "id": "EaF8NfC_521y" - } + ] }, { "cell_type": "code", - "source": [ - "lines = p | 'read' >> ReadLinesFromText(\"test.csv\")" - ], + "execution_count": null, "metadata": { - "id": "2APAh6XQ6NYd", "colab": { "base_uri": "https://localhost:8080/", "height": 72 }, + "id": "2APAh6XQ6NYd", "outputId": "033c5110-fd5a-4da0-b59b-801a1ce9d3b1" }, - "execution_count": null, - "outputs": [ + "outputs": [], + "source": [ + "lines = p | 'read' >> ReadLinesFromText(\"test.csv\")" ] }, { "cell_type": "markdown", - "source": [ - "We then write a **DoFn** that will invoke the [NLP API](https://cloud.google.com/healthcare-api/docs/how-tos/nlp)." - ], "metadata": { "id": "vM_FbhkbGI-E" - } + }, + "source": [ + "We then write a **DoFn** that will invoke the [NLP API](https://cloud.google.com/healthcare-api/docs/how-tos/nlp)." 
+ ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3ZJ-0dex9WE5" + }, + "outputs": [], "source": [ "class InvokeNLP(beam.DoFn):\n", "\n", @@ -486,24 +475,24 @@ " pcoll\n", " | \"Invoke NLP API\" >> beam.ParDo(InvokeNLP())\n", " )" - ], - "metadata": { - "id": "3ZJ-0dex9WE5" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "From our elements, being processed, we will get the entity mentions, relationships, and entities respectively." - ], "metadata": { "id": "TeYxIlNgGdK0" - } + }, + "source": [ + "From our elements, being processed, we will get the entity mentions, relationships, and entities respectively." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3KZgUv3d6haf" + }, + "outputs": [], "source": [ "import json\n", "from apache_beam import pvalue\n", @@ -529,15 +518,15 @@ " for e in element['entityMentions']:\n", " e['id'] = element['id']\n", " yield e\n" - ], - "metadata": { - "id": "3KZgUv3d6haf" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OkxgB2a-6iYN" + }, + "outputs": [], "source": [ "from apache_beam.io.gcp.internal.clients import bigquery\n", "\n", @@ -550,24 +539,24 @@ "nlp_annotations = (lines\n", " | \"Analyze\" >> AnalyzeLines()\n", " )\n" - ], - "metadata": { - "id": "OkxgB2a-6iYN" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "We then write these results to [BigQuery](https://cloud.google.com/bigquery), a cloud data warehouse." - ], "metadata": { "id": "iTh65CXIGoQn" - } + }, + "source": [ + "We then write these results to [BigQuery](https://cloud.google.com/bigquery), a cloud data warehouse." 
+ ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q9GIyLeS6oAe" + }, + "outputs": [], "source": [ "resultsEntities = ( nlp_annotations\n", " | \"Break\" >> beam.ParDo(breakUpEntities())\n", @@ -576,15 +565,15 @@ " write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,\n", " create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER)\n", " )" - ], - "metadata": { - "id": "Q9GIyLeS6oAe" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yOlHfkcT6s4y" + }, + "outputs": [], "source": [ "table_spec = bigquery.TableReference(\n", " projectId=PROJECT,\n", @@ -598,15 +587,15 @@ " write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,\n", " create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER)\n", " )" - ], - "metadata": { - "id": "yOlHfkcT6s4y" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a6QxxnY890Za" + }, + "outputs": [], "source": [ "table_spec = bigquery.TableReference(\n", " projectId=PROJECT,\n", @@ -620,43 +609,31 @@ " write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,\n", " create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER)\n", " )" - ], - "metadata": { - "id": "a6QxxnY890Za" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "You can see the job graph for the pipeline by doing:" - ], "metadata": { "id": "6rP2nO6Z60bt" - } + }, + "source": [ + "You can see the job graph for the pipeline by doing:" + ] }, { "cell_type": "code", - "source": [ - "ib.show_graph(p)" - ], + "execution_count": null, "metadata": { - "id": "zQB5h1Zq6x8d", "colab": { "base_uri": "https://localhost:8080/", "height": 806 }, + "id": "zQB5h1Zq6x8d", "outputId": "7885e493-fee8-402e-baf2-cbbf406a3eb9" }, - "execution_count": null, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], 
"text/html": [ "\n", " \n", @@ -665,16 +642,16 @@ " Processing... show_graph\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", "\n", "\n", "\n" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 
'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "ib.show_graph(p)" ] } - ] + ], + "metadata": { + "colab": { + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/gradle.properties b/gradle.properties index f2a0b05eca09..3923dc204272 100644 --- a/gradle.properties +++ b/gradle.properties @@ -30,8 +30,8 @@ signing.gnupg.useLegacyGpg=true # buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy. # To build a custom Beam version make sure you change it in both places, see # https://github.com/apache/beam/issues/21302. 
-version=2.61.0-SNAPSHOT -sdk_version=2.61.0.dev +version=2.62.0-SNAPSHOT +sdk_version=2.62.0.dev javaVersion=1.8 @@ -39,6 +39,6 @@ docker_image_default_repo_root=apache docker_image_default_repo_prefix=beam_ # supported flink versions -flink_versions=1.15,1.16,1.17,1.18 +flink_versions=1.17,1.18,1.19 # supported python versions -python_versions=3.8,3.9,3.10,3.11,3.12 +python_versions=3.9,3.10,3.11,3.12 diff --git a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/WordCountIT.java b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/WordCountIT.java index 511c322b8c4f..b561c0c71fb6 100644 --- a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/WordCountIT.java +++ b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/WordCountIT.java @@ -19,15 +19,19 @@ import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatPipeline; import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatResult; +import static org.apache.beam.sdk.util.construction.resources.PipelineResources.detectClassPathResourcesToStage; import java.io.IOException; import java.time.Duration; import java.util.Arrays; +import java.util.List; import org.apache.beam.it.common.PipelineLauncher.LaunchConfig; import org.apache.beam.it.common.PipelineLauncher.LaunchInfo; import org.apache.beam.it.common.PipelineLauncher.Sdk; import org.apache.beam.it.common.PipelineOperator.Result; +import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Filter; @@ -67,6 +71,29 @@ public void testWordCountDataflow() throws IOException { assertThatResult(result).isLaunchFinished(); } + @Test + public void testWordCountDataflowWithGCSFilesToStage() throws IOException { + + PipelineOptions pipelineOptions = wcPipeline.getOptions(); + List 
filesToStage = + detectClassPathResourcesToStage(DataflowRunner.class.getClassLoader(), pipelineOptions); + filesToStage.add("gs://apache-beam-samples/shakespeare/kinglear.txt"); + + LaunchConfig options = + LaunchConfig.builder("test-wordcount") + .setSdk(Sdk.JAVA) + .setPipeline(wcPipeline) + .addParameter("runner", "DataflowRunner") + .addParameter("filesToStage", String.join(",", filesToStage)) + .build(); + + LaunchInfo launchInfo = pipelineLauncher.launch(project, region, options); + assertThatPipeline(launchInfo).isRunning(); + Result result = + pipelineOperator.waitUntilDone(createConfig(launchInfo, Duration.ofMinutes(20))); + assertThatResult(result).isLaunchFinished(); + } + /** Build WordCount pipeline. */ private void buildPipeline() { wcPipeline diff --git a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOST.java b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOST.java index 22ff94e293b6..ddb300d74f66 100644 --- a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOST.java +++ b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOST.java @@ -84,6 +84,7 @@ public final class BigQueryIOST extends IOStressTestBase { private static final String READ_ELEMENT_METRIC_NAME = "read_count"; private static final String STORAGE_WRITE_API_METHOD = "STORAGE_WRITE_API"; private static final String STORAGE_API_AT_LEAST_ONCE_METHOD = "STORAGE_API_AT_LEAST_ONCE"; + private static final double STORAGE_API_AT_LEAST_ONCE_MAX_ALLOWED_DIFFERENCE_FRACTION = 0.00001; private static BigQueryResourceManager resourceManager; private static String tableName; @@ -334,11 +335,14 @@ private void generateDataAndWrite(BigQueryIO.Write writeIO) throws IOExc // Depending on writing method there might be duplicates on different sides (read or write). 
if (configuration.writeMethod.equals(STORAGE_API_AT_LEAST_ONCE_METHOD)) { + long allowedDifference = + (long) (numRecords * STORAGE_API_AT_LEAST_ONCE_MAX_ALLOWED_DIFFERENCE_FRACTION); + long actualDifference = (long) numRecords - rowCount; assertTrue( String.format( - "Number of rows in the table (%d) is less than the expected number (%d). Missing records: %d", - rowCount, (long) numRecords, (long) numRecords - rowCount), - rowCount >= numRecords); + "Row difference (%d) exceeds the limit of %d. Rows: %d, Expected: %d", + actualDifference, allowedDifference, rowCount, (long) numRecords), + actualDifference <= allowedDifference); } else { assertTrue( String.format( diff --git a/learning/tour-of-beam/learning-content/common-transforms/aggregation/count/description.md b/learning/tour-of-beam/learning-content/common-transforms/aggregation/count/description.md index 43ab5503240c..60fd0cc9f216 100644 --- a/learning/tour-of-beam/learning-content/common-transforms/aggregation/count/description.md +++ b/learning/tour-of-beam/learning-content/common-transforms/aggregation/count/description.md @@ -238,11 +238,11 @@ PCollection> input = pipeline.apply( And replace `Count.globally` with `Count.perKey` it will output the count numbers by key. 
It is also necessary to replace the generic type: ``` -PCollection> output = applyTransform(input); +PCollection> output = applyTransform(input); ``` ``` -static PCollection> applyTransform(PCollection> input) { +static PCollection> applyTransform(PCollection> input) { return input.apply(Count.globally()); } ``` diff --git a/learning/tour-of-beam/learning-content/common-transforms/aggregation/max/description.md b/learning/tour-of-beam/learning-content/common-transforms/aggregation/max/description.md index 92a0ace73b4d..9b8cfbea8b11 100644 --- a/learning/tour-of-beam/learning-content/common-transforms/aggregation/max/description.md +++ b/learning/tour-of-beam/learning-content/common-transforms/aggregation/max/description.md @@ -42,11 +42,11 @@ func ApplyTransform(s beam.Scope, input beam.PCollection) beam.PCollection { ``` {{end}} {{if (eq .Sdk "java")}} -You can find the global maximum value from the `PCollection` by using `Max.doublesGlobally()` +You can find the global maximum value from the `PCollection` by using `Max.integersGlobally()` ``` PCollection input = pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); -PCollection max = input.apply(Max.doublesGlobally()); +PCollection max = input.apply(Max.integersGlobally()); ``` Output diff --git a/learning/tour-of-beam/learning-content/common-transforms/aggregation/min/description.md b/learning/tour-of-beam/learning-content/common-transforms/aggregation/min/description.md index 1343b9d8c85f..138c8aef640e 100644 --- a/learning/tour-of-beam/learning-content/common-transforms/aggregation/min/description.md +++ b/learning/tour-of-beam/learning-content/common-transforms/aggregation/min/description.md @@ -41,11 +41,11 @@ func ApplyTransform(s beam.Scope, input beam.PCollection) beam.PCollection { ``` {{end}} {{if (eq .Sdk "java")}} -You can find the global minimum value from the `PCollection` by using `Min.doublesGlobally()` +You can find the global minimum value from the `PCollection` by using 
`Min.integersGlobally()` ``` PCollection input = pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); -PCollection min = input.apply(Min.doublesGlobally()); +PCollection min = input.apply(Min.integersGlobally()); ``` Output @@ -165,7 +165,7 @@ PCollection> output = applyTransform(input); ``` static PCollection> applyTransform(PCollection> input) { - return input.apply(Sum.integersPerKey()); + return input.apply(Min.integersPerKey()); } ``` {{end}} diff --git a/learning/tour-of-beam/learning-content/common-transforms/filter/description.md b/learning/tour-of-beam/learning-content/common-transforms/filter/description.md index 7a5b9522926d..96f4b549625b 100644 --- a/learning/tour-of-beam/learning-content/common-transforms/filter/description.md +++ b/learning/tour-of-beam/learning-content/common-transforms/filter/description.md @@ -51,7 +51,7 @@ world ### Built-in filters -The Java SDK has several filter methods built-in, like `Filter.greaterThan` and `Filter.lessThen` With `Filter.greaterThan`, the input `PCollection` can be filtered so that only the elements whose values are greater than the specified amount remain. Similarly, you can use `Filter.lessThen` to filter out elements of the input `PCollection` whose values are greater than the specified amount. +The Java SDK has several filter methods built-in, like `Filter.greaterThan` and `Filter.lessThan`. With `Filter.greaterThan`, the input `PCollection` can be filtered so that only the elements whose values are greater than the specified amount remain. Similarly, you can use `Filter.lessThan` to filter out elements of the input `PCollection` whose values are greater than or equal to the specified amount.
Other built-in filters are: @@ -62,7 +62,7 @@ Other built-in filters are: * Filter.equal -## Example 2: Filtering with a built-in methods +## Example 2: Filtering with built-in methods ``` // List of integers @@ -404,11 +404,3 @@ func applyTransform(s beam.Scope, input beam.PCollection) beam.PCollection { }) } ``` - -### Playground exercise - -You can find the complete code of the above example using 'Filter' in the playground window, which you can run and experiment with. - -Filter transform can be used with both text and numerical collection. For example, let's try filtering the input collection that contains words so that only words that start with the letter 'a' are returned. - -You can also chain several filter transforms to form more complex filtering based on several simple filters or implement more complex filtering logic within a single filter transform. For example, try both approaches to filter the same list of words such that only ones that start with a letter 'a' (regardless of the case) and containing more than three symbols are returned. diff --git a/learning/tour-of-beam/learning-content/core-transforms/additional-outputs/description.md b/learning/tour-of-beam/learning-content/core-transforms/additional-outputs/description.md index d228e6537498..7c4922314521 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/additional-outputs/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/additional-outputs/description.md @@ -104,7 +104,7 @@ func extractWordsFn(pn beam.PaneInfo, line string, emitWords func(string)) { ``` {{end}} {{if (eq .Sdk "java")}} -While `ParDo` always outputs the main output of `PCollection` (as a return value from apply), you can also force your `ParDo` to output any number of additional `PCollection` outputs. If you decide to have multiple outputs, your `ParDo` will return all the `PCollection` output (including the main output) combined. 
This will be useful when you are working with big data or a database that needs to be divided into different collections. You get a combined `PCollectionTuple`, you can use `TupleTag` to get a `PCollection`. +While `ParDo` always outputs the main output of `PCollection` (as a return value from apply), you can also force your `ParDo` to output any number of additional `PCollection` outputs. If you decide to have multiple outputs, your `ParDo` will return all the `PCollection` outputs (including the main output) combined. This will be useful when you are working with big data or a database that needs to be divided into different collections. You get a combined `PCollectionTuple`, from which you can use a `TupleTag` to get each `PCollection`. A `PCollectionTuple` is an immutable tuple of heterogeneously typed `PCollection`, "with keys" `TupleTags`. A `PCollectionTuple` can be used as input or output for `PTransform` receiving or creating multiple `PCollection` inputs or outputs, which can be of different types, for example, `ParDo` with multiple outputs. @@ -202,7 +202,7 @@ tens = results[None] # the undeclared main output You can find the full code of this example in the playground window, which you can run and experiment with. -The `applyTransform()` accepts a list of integers at the output two `PCollection` one `PCollection` above 100 and second below 100. +The `applyTransform()` accepts a list of integers and outputs two `PCollections`: one `PCollection` with the numbers above 100 and a second with the numbers below 100.
You can also work with strings: {{if (eq .Sdk "go")}} diff --git a/learning/tour-of-beam/learning-content/core-transforms/branching/description.md b/learning/tour-of-beam/learning-content/core-transforms/branching/description.md index a01022de97ab..8adff770cdd5 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/branching/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/branching/description.md @@ -86,7 +86,7 @@ starts_with_b = input | beam.Filter(lambda x: x.startswith('B')) You can find the full code of this example in the playground window, which you can run and experiment with. -Accepts a `PCollection` consisting of strings. Without modification, it returns a new "PCollection". In this case, one `PCollection` includes elements in uppercase. The other `PCollection' stores inverted elements. +Accepts a `PCollection` consisting of strings. Without modification, it returns a new `PCollection`. In this case, one `PCollection` includes elements in uppercase. The other `PCollection` stores inverted elements. You can use a different method of branching. Since `applyTransforms` performs 2 conversions, it takes a lot of time. It is possible to convert `PCollection` separately. {{if (eq .Sdk "go")}} diff --git a/learning/tour-of-beam/learning-content/core-transforms/combine/combine-per-key/description.md b/learning/tour-of-beam/learning-content/core-transforms/combine/combine-per-key/description.md index 1f44151337dd..af6f9e7aa197 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/combine/combine-per-key/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/combine/combine-per-key/description.md @@ -14,9 +14,9 @@ limitations under the License. # CombinePerKey -CombinePerKey is a transform in Apache Beam that applies a `CombineFn` function to each key of a PCollection of key-value pairs. 
The `CombineFn` function can be used to aggregate, sum, or combine the values associated with each key in the input `PCollection`. +`CombinePerKey` is a transform in Apache Beam that applies a `CombineFn` function to each key of a `PCollection` of key-value pairs. The `CombineFn` function can be used to aggregate, sum, or combine the values associated with each key in the input `PCollection`. -The `CombinePerKey` transform takes in an instance of a `CombineFn` class and applies it to the input `PCollection`. The output of the transform is a new PCollection where each element is a key-value pair, where the key is the same as the input key, and the value is the result of applying the `CombineFn` function to all the values associated with that key in the input `PCollection`. +The `CombinePerKey` transform takes in an instance of a `CombineFn` class and applies it to the input `PCollection`. The output of the transform is a new `PCollection` where each element is a key-value pair, where the key is the same as the input key, and the value is the result of applying the `CombineFn` function to all the values associated with that key in the input `PCollection`. 
{{if (eq .Sdk "go")}} ``` @@ -81,7 +81,7 @@ func applyTransform(s beam.Scope, input beam.PCollection) beam.PCollection { {{if (eq .Sdk "java")}} ``` PCollection> input = pipeline - .apply("ParseCitiesToTimeKV", Create.of( + .apply(Create.of( KV.of("a", "apple"), KV.of("o", "orange"), KV.of("a", "avocado), @@ -93,7 +93,7 @@ static PCollection> applyTransform(PCollection { + static class SumStringBinaryCombineFn extends BinaryCombineFn { @Override public String apply(String left, String right) { diff --git a/learning/tour-of-beam/learning-content/core-transforms/combine/simple-function/description.md b/learning/tour-of-beam/learning-content/core-transforms/combine/simple-function/description.md index 8eda136e7931..1946265fa66c 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/combine/simple-function/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/combine/simple-function/description.md @@ -14,7 +14,7 @@ limitations under the License. # Combine -`Combine` is a Beam transform for combining collections of elements or values in your data. Combine has variants that work on entire PCollections, and some that combine the values for each key in `PCollections` of **key/value** pairs. +`Combine` is a Beam transform for combining collections of elements or values in your data. Combine has variants that work on entire `PCollections`, and some that combine the values for each key in `PCollections` of **key/value** pairs. When you apply a `Combine` transform, you must provide the function that contains the logic for combining the elements or values. The combining function should be commutative and associative, as the function is not necessarily invoked exactly once on all values with a given key. Because the input data (including the value collection) may be distributed across multiple workers, the combining function might be called multiple times to perform partial combining on subsets of the value collection. 
The Beam SDK also provides some pre-built combine functions for common numeric combination operations such as sum, min, and max. diff --git a/learning/tour-of-beam/learning-content/core-transforms/composite/description.md b/learning/tour-of-beam/learning-content/core-transforms/composite/description.md index 774f5deae924..c61dcee8950b 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/composite/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/composite/description.md @@ -195,7 +195,7 @@ func extractWords(s beam.Scope, input beam.PCollection) beam.PCollection { } ``` -You can use other transformations you can replace `Count` with `Filter` to output words starting with **p**: +You can use other transformations; for example, you can replace `Count` with `Filter` to output words starting with **p**: ``` func applyTransform(s beam.Scope, input beam.PCollection) beam.PCollection { @@ -224,7 +224,7 @@ PCollection words = input })); ``` -You can use other transformations you can replace `Count` with `Filter` to output words starting with **p**: +You can use other transformations; for example, you can replace `Count` with `Filter` to output words starting with **p**: ``` PCollection filtered = input @@ -252,7 +252,7 @@ PCollection filtered = input words = input | 'ExtractWords' >> beam.FlatMap(lambda line: [word for word in line.split() if word]) ``` -You can use other transformations you can replace `Count` with `Filter` to output words starting with **p**: +You can use other transformations; for example,
you can replace `Count` with `Filter` to output words starting with **p**: ``` filtered = (input | 'ExtractNonSpaceCharacters' >> beam.FlatMap(lambda line: [word for word in line.split() if word]) diff --git a/learning/tour-of-beam/learning-content/core-transforms/flatten/description.md b/learning/tour-of-beam/learning-content/core-transforms/flatten/description.md index 19618f02c9e3..db9cbc31dc65 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/flatten/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/flatten/description.md @@ -56,7 +56,7 @@ By default, the coder for the output `PCollection` is the same as the coder for When using `Flatten` to merge `PCollection` objects that have a windowing strategy applied, all of the `PCollection` objects you want to merge must use a compatible windowing strategy and window sizing. For example, all the collections you’re merging must all use (hypothetically) identical 5-minute fixed windows or 4-minute sliding windows starting every 30 seconds. -If your pipeline attempts to use `Flatten` to merge `PCollection` objects with incompatible windows, Beam generates an IllegalStateException error when your pipeline is constructed. +If your pipeline attempts to use `Flatten` to merge `PCollection` objects with incompatible windows, Beam generates an `IllegalStateException` error when your pipeline is constructed.
### Playground exercise diff --git a/learning/tour-of-beam/learning-content/core-transforms/map/co-group-by-key/description.md b/learning/tour-of-beam/learning-content/core-transforms/map/co-group-by-key/description.md index a15321cd8ea3..739f43f96201 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/map/co-group-by-key/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/map/co-group-by-key/description.md @@ -75,7 +75,7 @@ func formatCoGBKResults(key string, emailIter, phoneIter func(*string) bool) str {{if (eq .Sdk "java")}} You can use the `CoGroupByKey` transformation for a tuple of tables. `CoGroupByKey` groups results from all tables by similar keys in `CoGbkResults`, from which results for any particular table can be accessed using the `TupleTag` tag supplied with the source table. -For type safety, the Jav SDK requires you to pass each `PCollection` as part of a `KeyedPCollectionTuple`. You must declare a `TupleTag` for each input `PCollection` in the `KeyedPCollectionTuple` that you want to pass to `CoGroupByKey`. As output, `CoGroupByKey` returns a `PCollection>`, which groups values from all the input `PCollections` by their common keys. Each key (all of type K) will have a different `CoGbkResult`, which is a map from `TupleTag to Iterable`. You can access a specific collection in an `CoGbkResult` object by using the `TupleTag` that you supplied with the initial collection. +For type safety, the Java SDK requires you to pass each `PCollection` as part of a `KeyedPCollectionTuple`. You must declare a `TupleTag` for each input `PCollection` in the `KeyedPCollectionTuple` that you want to pass to `CoGroupByKey`. As output, `CoGroupByKey` returns a `PCollection>`, which groups values from all the input `PCollections` by their common keys. Each key (all of type K) will have a different `CoGbkResult`, which is a map from `TupleTag to Iterable`. 
You can access a specific collection in a `CoGbkResult` object by using the `TupleTag` that you supplied with the initial collection. ``` // Mock data diff --git a/learning/tour-of-beam/learning-content/core-transforms/map/group-by-key/description.md b/learning/tour-of-beam/learning-content/core-transforms/map/group-by-key/description.md index c6041905a2d8..d8c396bf3a75 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/map/group-by-key/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/map/group-by-key/description.md @@ -11,7 +11,7 @@ limitations under the License. --> # GroupByKey -`GroupByKey` is a transform that is used to group elements in a `PCollection` by key. The input to `GroupByKey` is a `PCollection` of key-value pairs, where the keys are used to group the elements. The output of `GroupByKey` is a PCollection of key-value pairs, where the keys are the same as the input, and the values are lists of all the elements with that key. +`GroupByKey` is a transform that is used to group elements in a `PCollection` by key. The input to `GroupByKey` is a `PCollection` of key-value pairs, where the keys are used to group the elements. The output of `GroupByKey` is a `PCollection` of key-value pairs, where the keys are the same as the input, and the values are lists of all the elements with that key. Let’s examine the mechanics of `GroupByKey` with a simple example case, where our data set consists of words from a text file and the line number on which they appear. We want to group together all the line numbers (values) that share the same word (key), letting us see all the places in the text where a particular word appears. @@ -30,7 +30,7 @@ cat, 9 and, 6 ``` -`GroupByKey` gathers up all the values with the same key and outputs a new pair consisting of the unique key and a collection of all of the values that were associated with that key in the input collection.
If we apply GroupByKey to our input collection above, the output collection would look like this: +`GroupByKey` gathers up all the values with the same key and outputs a new pair consisting of the unique key and a collection of all of the values that were associated with that key in the input collection. If we apply `GroupByKey` to our input collection above, the output collection would look like this: ``` cat, [1,5,9] @@ -61,11 +61,11 @@ PCollection> input = ...; // Apply GroupByKey to the PCollection input. // Save the result as the PCollection reduced. -PCollection>> reduced = mapped.apply(GroupByKey.create()); +PCollection>> reduced = input.apply(GroupByKey.create()); ``` {{end}} {{if (eq .Sdk "python")}} -While all SDKs have a GroupByKey transform, using GroupBy is generally more natural. The `GroupBy` transform can be parameterized by the name(s) of properties on which to group the elements of the PCollection, or a function taking the each element as input that maps to a key on which to do grouping. +While all SDKs have a `GroupByKey` transform, using `GroupBy` is generally more natural. The `GroupBy` transform can be parameterized by the name(s) of properties on which to group the elements of the `PCollection`, or a function taking each element as input that maps it to a key on which to do grouping. ``` input = ... @@ -77,11 +77,11 @@ grouped_words = input | beam.GroupByKey() If you are using unbounded `PCollections`, you must use either non-global windowing or an aggregation trigger in order to perform a `GroupByKey` or `CoGroupByKey`. This is because a bounded `GroupByKey` or `CoGroupByKey` must wait for all the data with a certain key to be collected, but with unbounded collections, the data is unlimited. Windowing and/or triggers allow grouping to operate on logical, finite bundles of data within the unbounded data streams.
-If you do apply `GroupByKey` or `CoGroupByKey` to a group of unbounded `PCollections` without setting either a non-global windowing strategy, a trigger strategy, or both for each collection, Beam generates an IllegalStateException error at pipeline construction time. +If you do apply `GroupByKey` or `CoGroupByKey` to a group of unbounded `PCollections` without setting either a non-global windowing strategy, a trigger strategy, or both for each collection, Beam generates an `IllegalStateException` error at pipeline construction time. -When using `GroupByKey` or `CoGroupByKey` to group PCollections that have a windowing strategy applied, all of the `PCollections` you want to group must use the same windowing strategy and window sizing. For example, all the collections you are merging must use (hypothetically) identical 5-minute fixed windows, or 4-minute sliding windows starting every 30 seconds. +When using `GroupByKey` or `CoGroupByKey` to group `PCollections` that have a windowing strategy applied, all of the `PCollections` you want to group must use the same windowing strategy and window sizing. For example, all the collections you are merging must use (hypothetically) identical 5-minute fixed windows, or 4-minute sliding windows starting every 30 seconds. -If your pipeline attempts to use `GroupByKey` or `CoGroupByKey` to merge `PCollections` with incompatible windows, Beam generates an IllegalStateException error at pipeline construction time. +If your pipeline attempts to use `GroupByKey` or `CoGroupByKey` to merge `PCollections` with incompatible windows, Beam generates an `IllegalStateException` error at pipeline construction time. 
### Playground exercise @@ -118,7 +118,7 @@ func applyTransform(s beam.Scope, input beam.PCollection) beam.PCollection { {{if (eq .Sdk "java")}} ``` PCollection> input = pipeline - .apply("ParseCitiesToTimeKV", Create.of( + .apply(Create.of( KV.of("banana", 2), KV.of("apple", 4), KV.of("lemon", 3), diff --git a/learning/tour-of-beam/learning-content/core-transforms/map/map-elements/description.md b/learning/tour-of-beam/learning-content/core-transforms/map/map-elements/description.md index 27a07e4849fb..51c38a44b771 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/map/map-elements/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/map/map-elements/description.md @@ -28,7 +28,7 @@ PCollection wordLengths = input.apply( })); ``` -If your `ParDo` performs a one-to-one mapping of input elements to output elements–that is, for each input element, it applies a function that produces exactly one output element, you can use the higher-level `MapElements` transform.MapElements can accept an anonymous Java 8 lambda function for additional brevity. +If your `ParDo` performs a one-to-one mapping of input elements to output elements–that is, for each input element, it applies a function that produces exactly one output element, you can use the higher-level `MapElements` transform. `MapElements` can accept an anonymous Java 8 lambda function for additional brevity. 
Here’s the previous example using `MapElements` : diff --git a/learning/tour-of-beam/learning-content/core-transforms/map/pardo-one-to-one/description.md b/learning/tour-of-beam/learning-content/core-transforms/map/pardo-one-to-one/description.md index e7bdbc6565ed..0f3dd6f580a5 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/map/pardo-one-to-one/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/map/pardo-one-to-one/description.md @@ -264,9 +264,9 @@ wordLengths := beam.ParDo(s, func(word string) int { {{if (eq .Sdk "java")}} ### Accessing additional parameters in your DoFn -In addition to the element and the OutputReceiver, Beam will populate other parameters to your DoFn’s @ProcessElement method. Any combination of these parameters can be added to your process method in any order. +In addition to the element and the `OutputReceiver`, Beam will populate other parameters to your DoFn’s `@ProcessElement` method. Any combination of these parameters can be added to your process method in any order. -**Timestamp**: To access the timestamp of an input element, add a parameter annotated with @Timestamp of type Instant. For example: +**Timestamp**: To access the timestamp of an input element, add a parameter annotated with `@Timestamp` of type `Instant`. For example: ``` .of(new DoFn() { @@ -274,7 +274,7 @@ In addition to the element and the OutputReceiver, Beam will populate other para }}) ``` -**Window**: To access the window an input element falls into, add a parameter of the type of the window used for the input `PCollection`. If the parameter is a window type (a subclass of BoundedWindow) that does not match the input `PCollection`, then an error will be raised. If an element falls in multiple windows (for example, this will happen when using `SlidingWindows`), then the `@ProcessElement` method will be invoked multiple time for the element, once for each window. 
For example, when fixed windows are being used, the window is of type `IntervalWindow`. +**Window**: To access the window an input element falls into, add a parameter of the type of the window used for the input `PCollection`. If the parameter is a window type (a subclass of `BoundedWindow`) that does not match the input `PCollection`, then an error will be raised. If an element falls in multiple windows (for example, this will happen when using `SlidingWindows`), then the `@ProcessElement` method will be invoked multiple times for the element, once for each window. For example, when fixed windows are being used, the window is of type `IntervalWindow`. ``` .of(new DoFn() { @@ -298,7 +298,7 @@ In addition to the element and the OutputReceiver, Beam will populate other para }}) ``` -`@OnTimer` methods can also access many of these parameters. Timestamp, Window, key, `PipelineOptions`, `OutputReceiver`, and `MultiOutputReceiver` parameters can all be accessed in an @OnTimer method. In addition, an `@OnTimer` method can take a parameter of type `TimeDomain` which tells whether the timer is based on event time or processing time. Timers are explained in more detail in the Timely (and Stateful) Processing with Apache Beam blog post. +`@OnTimer` methods can also access many of these parameters. Timestamp, Window, key, `PipelineOptions`, `OutputReceiver`, and `MultiOutputReceiver` parameters can all be accessed in an `@OnTimer` method. In addition, an `@OnTimer` method can take a parameter of type `TimeDomain` which tells whether the timer is based on event time or processing time. Timers are explained in more detail in the [Timely (and Stateful) Processing with Apache Beam blog post](https://beam.apache.org/blog/timely-processing/). {{end}} @@ -401,7 +401,7 @@ class StatefulDoFn(beam.DoFn): You can find the full code of this example in the playground window, which you can run and experiment with.
-You can work with any type of object.For example String: +You can work with any type of object. For example, `String`: {{if (eq .Sdk "go")}} ``` diff --git a/learning/tour-of-beam/learning-content/core-transforms/partition/description.md b/learning/tour-of-beam/learning-content/core-transforms/partition/description.md index 59f6ad5f5c33..e951a38c1750 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/partition/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/partition/description.md @@ -70,7 +70,7 @@ fortieth_percentile = by_decile[4] You can find the full code of this example in the playground window, which you can run and experiment with. -The `applyTransforms` returns a slice of the PCollection, you can access it by index. In this case, we have two `PCollections`, one consists of numbers that are less than 100, the second is more than 100. +The `applyTransforms` function returns a slice of `PCollection`s that you can access by index. In this case, we have two `PCollections`: one consists of numbers that are less than 100, and the other of numbers that are greater than 100. You can also divide other types into parts, for example: "strings" and others. diff --git a/learning/tour-of-beam/learning-content/core-transforms/side-inputs/description.md b/learning/tour-of-beam/learning-content/core-transforms/side-inputs/description.md index a64ccde52950..ae5bdffc827d 100644 --- a/learning/tour-of-beam/learning-content/core-transforms/side-inputs/description.md +++ b/learning/tour-of-beam/learning-content/core-transforms/side-inputs/description.md @@ -11,7 +11,7 @@ limitations under the License. --> # Side inputs -In addition to the main input `PCollection`, you can provide additional inputs to a `ParDo` transform in the form of side inputs. A side input is an additional input that your `DoFn` can access each time it processes an element in the input PCollection.
When you specify a side input, you create a view of some other data that can be read from within the `ParDo` transform’s `DoFn` while processing each element. +In addition to the main input `PCollection`, you can provide additional inputs to a `ParDo` transform in the form of side inputs. A side input is an additional input that your `DoFn` can access each time it processes an element in the input `PCollection`. When you specify a side input, you create a view of some other data that can be read from within the `ParDo` transform’s `DoFn` while processing each element. Side inputs are useful if your `ParDo` needs to inject additional data when processing each element in the input `PCollection`, but the additional data needs to be determined at runtime (and not hard-coded). Such values might be determined by the input data, or depend on a different branch of your pipeline. {{if (eq .Sdk "go")}} @@ -172,7 +172,7 @@ If the side input has multiple trigger firings, Beam uses the value from the lat You can find the full code of this example in the playground window, which you can run and experiment with. -At the entrance we have a map whose key is the city of the country value. And we also have a `Person` structure with his name and city. We can compare cities and embed countries in `Person`. +As input, we have a map whose keys are cities and whose values are the corresponding countries. We also have a `Person` structure with a name and a city. We can match on the city and embed the country in `Person`. You can also use it as a variable for mathematical calculations.
diff --git a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md index 6eb1c04e966a..d7c4cb4137dc 100644 --- a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md +++ b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md @@ -53,7 +53,7 @@ When using Java, you must specify your dependency on the Direct Runner in your p #### Set runner -In java, you need to set runner to `args` when you start the program. +In Java, you need to set runner to `args` when you start the program. ``` --runner=DirectRunner @@ -191,8 +191,8 @@ $ wordcount --input gs://dataflow-samples/shakespeare/kinglear.txt \ {{if (eq .Sdk "java")}} ##### Portable -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.15`, `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` @@ -233,8 +233,8 @@ mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount \ {{end}} {{if (eq .Sdk "python")}} -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.10`, `Flink 1.11`, `Flink 1.12`, `Flink 1.13`, `Flink 1.14`. -2. 
Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` diff --git a/local-env-setup.sh b/local-env-setup.sh index f13dc88432a6..ba30813b2bcc 100755 --- a/local-env-setup.sh +++ b/local-env-setup.sh @@ -55,7 +55,7 @@ if [ "$kernelname" = "Linux" ]; then exit fi - for ver in 3.8 3.9 3.10 3.11 3.12 3; do + for ver in 3.9 3.10 3.11 3.12 3; do apt install --yes python$ver-venv done @@ -89,7 +89,7 @@ elif [ "$kernelname" = "Darwin" ]; then echo "Installing openjdk@8" brew install openjdk@8 fi - for ver in 3.8 3.9 3.10 3.11 3.12; do + for ver in 3.9 3.10 3.11 3.12; do if brew ls --versions python@$ver > /dev/null; then echo "python@$ver already installed. Skipping" brew info python@$ver diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index 429371e11055..f102e82bafa6 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -59,6 +59,24 @@ message ExpansionMethods { } } +// Defines the URNs for managed transforms. 
+message ManagedTransforms { + enum Urns { + ICEBERG_READ = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:iceberg_read:v1"]; + ICEBERG_WRITE = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:iceberg_write:v1"]; + KAFKA_READ = 2 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:kafka_read:v1"]; + KAFKA_WRITE = 3 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:kafka_write:v1"]; + BIGQUERY_READ = 4 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:bigquery_storage_read:v1"]; + BIGQUERY_WRITE = 5 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:bigquery_write:v1"]; + } +} + // A configuration payload for an external transform. // Used to define a Java transform that can be directly instantiated by a Java // expansion service. diff --git a/playground/backend/internal/preparers/python_preparers.go b/playground/backend/internal/preparers/python_preparers.go index f050237492b1..96a4ed32910a 100644 --- a/playground/backend/internal/preparers/python_preparers.go +++ b/playground/backend/internal/preparers/python_preparers.go @@ -26,7 +26,7 @@ import ( ) const ( - addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.ERROR,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" oneIndentation = " " findWithPipelinePattern = `(\s*)with.+Pipeline.+as (.+):` indentationPattern = `^(%s){0,1}\w+` diff --git a/playground/backend/internal/preparers/python_preparers_test.go b/playground/backend/internal/preparers/python_preparers_test.go index b2cfa7eccaac..f333a1639b7c 100644 
--- a/playground/backend/internal/preparers/python_preparers_test.go +++ b/playground/backend/internal/preparers/python_preparers_test.go @@ -53,7 +53,7 @@ func TestGetPythonPreparers(t *testing.T) { } func Test_addCodeToFile(t *testing.T) { - wantCode := "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode + wantCode := "import logging\nlogging.basicConfig(\n level=logging.ERROR,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode type args struct { args []interface{} diff --git a/playground/infrastructure/cloudbuild/playground_cd_examples.sh b/playground/infrastructure/cloudbuild/playground_cd_examples.sh index d05773656b30..e571bc9fc9d9 100644 --- a/playground/infrastructure/cloudbuild/playground_cd_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_cd_examples.sh @@ -97,15 +97,15 @@ LogOutput "Installing python and dependencies." 
export DEBIAN_FRONTEND=noninteractive apt install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null 2>&1 add-apt-repository -y ppa:deadsnakes/ppa > /dev/null 2>&1 && apt update > /dev/null 2>&1 -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null 2>&1 -apt install -y --reinstall python3.8-distutils > /dev/null 2>&1 +apt install -y python3.9 python3-distutils python3-pip > /dev/null 2>&1 +apt install -y --reinstall python3-distutils > /dev/null 2>&1 apt install -y python3-virtualenv virtualenv play_venv source play_venv/bin/activate pip install --upgrade google-api-python-client > /dev/null 2>&1 -python3.8 -m pip install pip --upgrade > /dev/null 2>&1 -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null 2>&1 -apt install -y python3.8-venv > /dev/null 2>&1 +python3.9 -m pip install pip --upgrade > /dev/null 2>&1 +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null 2>&1 +apt install -y python3.9-venv > /dev/null 2>&1 LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" cd $BEAM_ROOT_DIR diff --git a/playground/infrastructure/cloudbuild/playground_ci_examples.sh b/playground/infrastructure/cloudbuild/playground_ci_examples.sh index 437cc337faf7..2a63382615a5 100755 --- a/playground/infrastructure/cloudbuild/playground_ci_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_ci_examples.sh @@ -94,12 +94,12 @@ export DEBIAN_FRONTEND=noninteractive LogOutput "Installing Python environment" apt-get install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null add-apt-repository -y ppa:deadsnakes/ppa > /dev/null && apt update > /dev/null -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null -apt install --reinstall python3.8-distutils > /dev/null +apt install -y python3.9 python3-distutils python3-pip > /dev/null +apt install --reinstall python3-distutils > /dev/null pip install --upgrade 
google-api-python-client > /dev/null -python3.8 -m pip install pip --upgrade > /dev/null -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null -apt install python3.8-venv > /dev/null +python3.9 -m pip install pip --upgrade > /dev/null +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null +apt install python3.9-venv > /dev/null LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" pip install -r $BEAM_ROOT_DIR/playground/infrastructure/requirements.txt diff --git a/release/build.gradle.kts b/release/build.gradle.kts index ca1c152c9eb5..7ec49b86aac2 100644 --- a/release/build.gradle.kts +++ b/release/build.gradle.kts @@ -39,7 +39,7 @@ task("runJavaExamplesValidationTask") { dependsOn(":runners:direct-java:runQuickstartJavaDirect") dependsOn(":runners:google-cloud-dataflow-java:runQuickstartJavaDataflow") dependsOn(":runners:spark:3:runQuickstartJavaSpark") - dependsOn(":runners:flink:1.18:runQuickstartJavaFlinkLocal") + dependsOn(":runners:flink:1.19:runQuickstartJavaFlinkLocal") dependsOn(":runners:direct-java:runMobileGamingJavaDirect") dependsOn(":runners:google-cloud-dataflow-java:runMobileGamingJavaDataflow") dependsOn(":runners:twister2:runQuickstartJavaTwister2") diff --git a/release/src/main/Dockerfile b/release/src/main/Dockerfile index 14fe6fdb5a49..6503c5c42ba8 100644 --- a/release/src/main/Dockerfile +++ b/release/src/main/Dockerfile @@ -42,12 +42,11 @@ RUN curl https://pyenv.run | bash && \ echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> /root/.bashrc && \ echo ''eval "$(pyenv init -)"'' >> /root/.bashrc && \ source /root/.bashrc && \ - pyenv install 3.8.9 && \ pyenv install 3.9.4 && \ pyenv install 3.10.7 && \ pyenv install 3.11.3 && \ pyenv install 3.12.3 && \ - pyenv global 3.8.9 3.9.4 3.10.7 3.11.3 3.12.3 + pyenv global 3.9.4 3.10.7 3.11.3 3.12.3 # Install a Go version >= 1.16 so we can bootstrap higher # Go versions diff --git a/release/src/main/groovy/TestScripts.groovy 
b/release/src/main/groovy/TestScripts.groovy index d5042aa61941..dc2438007ac1 100644 --- a/release/src/main/groovy/TestScripts.groovy +++ b/release/src/main/groovy/TestScripts.groovy @@ -36,6 +36,7 @@ class TestScripts { static String gcsBucket static String bqDataset static String pubsubTopic + static String mavenLocalPath } def TestScripts(String[] args) { @@ -47,6 +48,7 @@ class TestScripts { cli.gcsBucket(args:1, 'Google Cloud Storage Bucket') cli.bqDataset(args:1, "BigQuery Dataset") cli.pubsubTopic(args:1, "PubSub Topic") + cli.mavenLocalPath(args:1, "Maven local path") def options = cli.parse(args) var.repoUrl = options.repourl @@ -73,6 +75,10 @@ class TestScripts { var.pubsubTopic = options.pubsubTopic println "PubSub Topic: ${var.pubsubTopic}" } + if (options.mavenLocalPath) { + var.mavenLocalPath = options.mavenLocalPath + println "Maven local path: ${var.mavenLocalPath}" + } } def ver() { @@ -189,11 +195,16 @@ class TestScripts { } } - // Run a maven command, setting up a new local repository and a settings.xml with a custom repository + // Run a maven command, setting up a new local repository and a settings.xml with a custom repository if needed private String _mvn(String args) { - def m2 = new File(var.startDir, ".m2/repository") + String mvnlocalPath = var.mavenLocalPath + if (!(var.mavenLocalPath)) { + mvnlocalPath = var.startDir + } + def m2 = new File(mvnlocalPath, ".m2/repository") m2.mkdirs() - def settings = new File(var.startDir, "settings.xml") + def settings = new File(mvnlocalPath, "settings.xml") + if(!settings.exists()) { settings.write """ ${m2.absolutePath} @@ -209,16 +220,17 @@ class TestScripts { - """ - def cmd = "mvn ${args} -s ${settings.absolutePath} -Ptestrel -B" - String path = System.getenv("PATH"); - // Set the path on jenkins executors to use a recent maven - // MAVEN_HOME is not set on some executors, so default to 3.5.2 - String maven_home = System.getenv("MAVEN_HOME") ?: '/home/jenkins/tools/maven/apache-maven-3.5.4' - 
println "Using maven ${maven_home}" - def mvnPath = "${maven_home}/bin" - def setPath = "export PATH=\"${mvnPath}:${path}\" && " - return _execute(setPath + cmd) + """ + } + def cmd = "mvn ${args} -s ${settings.absolutePath} -Ptestrel -B" + String path = System.getenv("PATH"); + // Set the path on jenkins executors to use a recent maven + // MAVEN_HOME is not set on some executors, so default to 3.5.2 + String maven_home = System.getenv("MAVEN_HOME") ?: '/usr/local/maven' + println "Using maven ${maven_home}" + def mvnPath = "${maven_home}/bin" + def setPath = "export PATH=\"${mvnPath}:${path}\" && " + return _execute(setPath + cmd) } // Clean up and report error diff --git a/release/src/main/python-release/python_release_automation.sh b/release/src/main/python-release/python_release_automation.sh index 2f6986885a96..248bdd9b65ac 100755 --- a/release/src/main/python-release/python_release_automation.sh +++ b/release/src/main/python-release/python_release_automation.sh @@ -19,7 +19,7 @@ source release/src/main/python-release/run_release_candidate_python_quickstart.sh source release/src/main/python-release/run_release_candidate_python_mobile_gaming.sh -for version in 3.8 3.9 3.10 3.11 3.12 +for version in 3.9 3.10 3.11 3.12 do run_release_candidate_python_quickstart "tar" "python${version}" run_release_candidate_python_mobile_gaming "tar" "python${version}" diff --git a/release/src/main/scripts/download_github_actions_artifacts.py b/release/src/main/scripts/download_github_actions_artifacts.py index 99526f1ac7d1..5c553efc6500 100644 --- a/release/src/main/scripts/download_github_actions_artifacts.py +++ b/release/src/main/scripts/download_github_actions_artifacts.py @@ -279,8 +279,11 @@ def fetch_github_artifacts(run_id, repo_url, artifacts_dir, github_token, rc_num print("Starting downloading artifacts ... 
(it may take a while)") run_data = get_single_workflow_run_data(run_id, repo_url, github_token) artifacts_url = safe_get(run_data, "artifacts_url") - data_artifacts = request_url(artifacts_url, github_token) + data_artifacts = request_url(artifacts_url + '?per_page=100', github_token) artifacts = safe_get(data_artifacts, "artifacts", artifacts_url) + total_count = safe_get(data_artifacts, "total_count", artifacts_url) + if int(total_count) != len(artifacts): + raise RuntimeError(f"Expected total count {total_count} different than returned list length {len(data_artifacts)}") print('Filtering ', len(artifacts), ' artifacts') filtered_artifacts = filter_artifacts(artifacts, rc_number) print('Preparing to download ', len(filtered_artifacts), ' artifacts') diff --git a/runners/core-java/build.gradle b/runners/core-java/build.gradle index b477dde91212..0f55c10b97f2 100644 --- a/runners/core-java/build.gradle +++ b/runners/core-java/build.gradle @@ -42,6 +42,7 @@ dependencies { implementation project(path: ":model:pipeline", configuration: "shadow") implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(path: ":model:job-management", configuration: "shadow") + implementation library.java.google_api_services_dataflow implementation library.java.vendored_guava_32_1_2_jre implementation library.java.joda_time implementation library.java.vendored_grpc_1_60_1 @@ -52,5 +53,6 @@ dependencies { testImplementation library.java.junit testImplementation library.java.mockito_core testImplementation library.java.slf4j_api + testImplementation(library.java.google_api_services_dataflow) testRuntimeOnly library.java.slf4j_simple } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateNamespaces.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateNamespaces.java index 6c0ed7740489..a68ab6c913ce 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateNamespaces.java +++ 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateNamespaces.java @@ -90,8 +90,8 @@ public void appendTo(Appendable sb) throws IOException { /** {@link StateNamespace} that is scoped to a specific window. */ public static class WindowNamespace implements StateNamespace { - private Coder windowCoder; - private W window; + private final Coder windowCoder; + private final W window; private WindowNamespace(Coder windowCoder, W window) { this.windowCoder = windowCoder; diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java index 433e7f4fb20b..88136a864d03 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodings.java @@ -17,8 +17,19 @@ */ package org.apache.beam.runners.core.metrics; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.api.services.dataflow.model.Base2Exponent; +import com.google.api.services.dataflow.model.BucketOptions; +import com.google.api.services.dataflow.model.DataflowHistogramValue; +import com.google.api.services.dataflow.model.Linear; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; import java.util.Set; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.DoubleCoder; @@ -26,10 +37,14 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.util.HistogramData; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; 
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; import org.joda.time.Instant; +// TODO(#33093): Refactor out DataflowHistogramValue to be runner agnostic, and rename to +// remove Dataflow reference. + /** A set of functions used to encode and decode common monitoring info types. */ public class MonitoringInfoEncodings { private static final Coder VARINT_CODER = VarLongCoder.of(); @@ -163,4 +178,98 @@ public static double decodeDoubleCounter(ByteString payload) { throw new RuntimeException(e); } } + + /** Encodes to {@link MonitoringInfoConstants.TypeUrns#PER_WORKER_HISTOGRAM}. */ + public static ByteString encodeInt64Histogram(HistogramData inputHistogram) { + try { + int numberOfBuckets = inputHistogram.getBucketType().getNumBuckets(); + + DataflowHistogramValue outputHistogram2 = new DataflowHistogramValue(); + + if (inputHistogram.getBucketType() instanceof HistogramData.LinearBuckets) { + HistogramData.LinearBuckets buckets = + (HistogramData.LinearBuckets) inputHistogram.getBucketType(); + Linear linear = new Linear(); + linear.setNumberOfBuckets(numberOfBuckets); + linear.setWidth(buckets.getWidth()); + linear.setStart(buckets.getStart()); + outputHistogram2.setBucketOptions(new BucketOptions().setLinear(linear)); + } else if (inputHistogram.getBucketType() instanceof HistogramData.ExponentialBuckets) { + HistogramData.ExponentialBuckets buckets = + (HistogramData.ExponentialBuckets) inputHistogram.getBucketType(); + Base2Exponent base2Exp = new Base2Exponent(); + base2Exp.setNumberOfBuckets(numberOfBuckets); + base2Exp.setScale(buckets.getScale()); + outputHistogram2.setBucketOptions(new BucketOptions().setExponential(base2Exp)); + } else { + throw new HistogramParsingException( + "Unable to encode Int64 Histogram, bucket is not recognized"); + } + + outputHistogram2.setCount(inputHistogram.getTotalCount()); + + List bucketCounts = new ArrayList<>(); + + Arrays.stream(inputHistogram.getBucketCount()) + .forEach( + val -> { + 
bucketCounts.add(val); + }); + + outputHistogram2.setBucketCounts(bucketCounts); + + ObjectMapper objectMapper = new ObjectMapper(); + String jsonString = objectMapper.writeValueAsString(outputHistogram2); + + return ByteString.copyFromUtf8(jsonString); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + static class HistogramParsingException extends RuntimeException { + public HistogramParsingException(String message) { + super(message); + } + } + + /** Decodes to {@link MonitoringInfoConstants.TypeUrns#PER_WORKER_HISTOGRAM}. */ + public static HistogramData decodeInt64Histogram(ByteString payload) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + JsonNode jsonNode = objectMapper.readTree(payload.toStringUtf8()); // parse afterwards + DataflowHistogramValue newHist = new DataflowHistogramValue(); + newHist.setCount(jsonNode.get("count").asLong()); + + List bucketCounts = new ArrayList<>(); + Iterator itr = jsonNode.get("bucketCounts").iterator(); + while (itr.hasNext()) { + Long item = itr.next().asLong(); + bucketCounts.add(item); + } + newHist.setBucketCounts(bucketCounts); + + if (jsonNode.get("bucketOptions").has("linear")) { + Linear linear = new Linear(); + JsonNode linearNode = jsonNode.get("bucketOptions").get("linear"); + linear.setNumberOfBuckets(linearNode.get("numberOfBuckets").asInt()); + linear.setWidth(linearNode.get("width").asDouble()); + linear.setStart(linearNode.get("start").asDouble()); + newHist.setBucketOptions(new BucketOptions().setLinear(linear)); + } else if (jsonNode.get("bucketOptions").has("exponential")) { + Base2Exponent base2Exp = new Base2Exponent(); + JsonNode expNode = jsonNode.get("bucketOptions").get("exponential"); + base2Exp.setNumberOfBuckets(expNode.get("numberOfBuckets").asInt()); + base2Exp.setScale(expNode.get("scale").asInt()); + newHist.setBucketOptions(new BucketOptions().setExponential(base2Exp)); + } 
else { + throw new HistogramParsingException( + "Unable to parse Int64 Histogram, bucket is not recognized"); + } + return new HistogramData(newHist); + } catch (IOException e) { + throw new RuntimeException(e); + } + } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java index 8455f154c0f8..fc8dcb49894f 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java @@ -22,7 +22,6 @@ import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.metrics.StringSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -101,11 +100,15 @@ public void add(String value) { if (this.setValue.get().stringSet().contains(value)) { return; } - update(StringSetData.create(ImmutableSet.of(value))); + add(new String[] {value}); } @Override public void add(String... 
values) { - update(StringSetData.create(ImmutableSet.copyOf(values))); + StringSetData original; + do { + original = setValue.get(); + } while (!setValue.compareAndSet(original, original.addAll(values))); + dirty.afterModification(); } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java index 466d4ad46eb6..5f9bb6392ec2 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java @@ -19,25 +19,49 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; +import java.util.Arrays; import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; +import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.metrics.StringSetResult; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * Data describing the StringSet. The {@link StringSetData} hold an immutable copy of the set from - * which it was initially created. This should retain enough detail that it can be combined with - * other {@link StringSetData}. + * Data describing the StringSet. The {@link StringSetData} hold a copy of the set from which it was + * initially created. This should retain enough detail that it can be combined with other {@link + * StringSetData}. + * + *

The underlying set is mutable for {@link #addAll} operation, otherwise a copy set will be + * generated. + * + *

The summation of all string length for a {@code StringSetData} cannot exceed 1 MB. Further + * addition of elements are dropped. */ @AutoValue public abstract class StringSetData implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(StringSetData.class); + // 1 MB + @VisibleForTesting static final long STRING_SET_SIZE_LIMIT = 1_000_000L; public abstract Set stringSet(); + public abstract long stringSize(); + /** Returns a {@link StringSetData} which is made from an immutable copy of the given set. */ public static StringSetData create(Set set) { - return new AutoValue_StringSetData(ImmutableSet.copyOf(set)); + if (set.isEmpty()) { + return empty(); + } + Set combined = ConcurrentHashMap.newKeySet(); + long stringSize = addUntilCapacity(combined, 0L, set); + return new AutoValue_StringSetData(combined, stringSize); + } + + /** Returns a {@link StringSetData} which is made from the given set in place. */ + private static StringSetData createInPlace(Set set, long stringSize) { + return new AutoValue_StringSetData(set, stringSize); } /** Return a {@link EmptyStringSetData#INSTANCE} representing an empty {@link StringSetData}. */ @@ -45,6 +69,24 @@ public static StringSetData empty() { return EmptyStringSetData.INSTANCE; } + /** + * Add strings into this {@code StringSetData} and return the result {@code StringSetData}. Reuse + * the original StringSetData's set. As a result, current StringSetData will become invalid. + * + *

>Should only be used by {@link StringSetCell#add}. + */ + public StringSetData addAll(String... strings) { + Set combined; + if (this.stringSet() instanceof ConcurrentHashMap.KeySetView) { + combined = this.stringSet(); + } else { + combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); + } + long stringSize = addUntilCapacity(combined, this.stringSize(), Arrays.asList(strings)); + return StringSetData.createInPlace(combined, stringSize); + } + /** * Combines this {@link StringSetData} with other, both original StringSetData are left intact. */ @@ -54,10 +96,10 @@ public StringSetData combine(StringSetData other) { } else if (other.stringSet().isEmpty()) { return this; } else { - ImmutableSet.Builder combined = ImmutableSet.builder(); + Set combined = ConcurrentHashMap.newKeySet(); combined.addAll(this.stringSet()); - combined.addAll(other.stringSet()); - return StringSetData.create(combined.build()); + long stringSize = addUntilCapacity(combined, this.stringSize(), other.stringSet()); + return StringSetData.createInPlace(combined, stringSize); } } @@ -65,12 +107,13 @@ public StringSetData combine(StringSetData other) { * Combines this {@link StringSetData} with others, all original StringSetData are left intact. */ public StringSetData combine(Iterable others) { - Set combined = - StreamSupport.stream(others.spliterator(), true) - .flatMap(other -> other.stringSet().stream()) - .collect(Collectors.toSet()); + Set combined = ConcurrentHashMap.newKeySet(); combined.addAll(this.stringSet()); - return StringSetData.create(combined); + long stringSize = this.stringSize(); + for (StringSetData other : others) { + stringSize = addUntilCapacity(combined, stringSize, other.stringSet()); + } + return StringSetData.createInPlace(combined, stringSize); } /** Returns a {@link StringSetResult} representing this {@link StringSetData}. 
*/ @@ -78,6 +121,31 @@ public StringSetResult extractResult() { return StringSetResult.create(stringSet()); } + /** Add strings into set until reach capacity. Return the all string size of added set. */ + private static long addUntilCapacity( + Set combined, long currentSize, Iterable others) { + if (currentSize > STRING_SET_SIZE_LIMIT) { + // already at capacity + return currentSize; + } + for (String string : others) { + if (combined.add(string)) { + currentSize += string.length(); + + // check capacity both before insert and after insert one, so the warning only emit once. + if (currentSize > STRING_SET_SIZE_LIMIT) { + LOG.warn( + "StringSet metrics reaches capacity. Further incoming elements won't be recorded." + + " Current size: {}, last element size: {}.", + currentSize, + string.length()); + break; + } + } + } + return currentSize; + } + /** Empty {@link StringSetData}, representing no values reported and is immutable. */ public static class EmptyStringSetData extends StringSetData { @@ -91,6 +159,11 @@ public Set stringSet() { return ImmutableSet.of(); } + @Override + public long stringSize() { + return 0L; + } + /** Return a {@link StringSetResult#empty()} which is immutable empty set. 
*/ @Override public StringSetResult extractResult() { diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java index 8a43eef5883d..2d7ba61dbc95 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MonitoringInfoEncodingsTest.java @@ -17,30 +17,43 @@ */ package org.apache.beam.runners.core.metrics; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.HistogramParsingException; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeDoubleCounter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeInt64Histogram; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.decodeStringSet; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeDoubleCounter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeDoubleDistribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Counter; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Distribution; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Gauge; +import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeInt64Histogram; import static org.apache.beam.runners.core.metrics.MonitoringInfoEncodings.encodeStringSet; import static org.junit.Assert.assertEquals; 
import java.util.Collections; +import org.apache.beam.sdk.testing.ExpectedLogs; +import org.apache.beam.sdk.util.HistogramData; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Instant; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Tests for {@link MonitoringInfoEncodings}. */ @RunWith(JUnit4.class) public class MonitoringInfoEncodingsTest { + @Rule + public ExpectedLogs monitoringInfoCodingsExpectedLogs = + ExpectedLogs.none(MonitoringInfoEncodings.class); + + @Rule public ExpectedException thrown = ExpectedException.none(); + @Test public void testInt64DistributionEncoding() { DistributionData data = DistributionData.create(1L, 2L, 3L, 4L); @@ -105,4 +118,36 @@ public void testDoubleCounterEncoding() { assertEquals(ByteString.copyFrom(new byte[] {0x3f, (byte) 0xf0, 0, 0, 0, 0, 0, 0}), payload); assertEquals(1.0, decodeDoubleCounter(payload), 0.001); } + + @Test + public void testHistgramInt64EncodingLinearHist() { + HistogramData.BucketType buckets = HistogramData.LinearBuckets.of(0, 5, 5); + + HistogramData inputHistogram = new HistogramData(buckets); + inputHistogram.record(5, 10, 15, 20); + ByteString payload = encodeInt64Histogram(inputHistogram); + + assertEquals(inputHistogram, decodeInt64Histogram(payload)); + } + + @Test + public void testHistgramInt64EncodingExpHist() { + HistogramData.BucketType buckets = HistogramData.ExponentialBuckets.of(1, 10); + HistogramData inputHistogram = new HistogramData(buckets); + inputHistogram.record(2, 4, 8, 16, 32); + ByteString payload = encodeInt64Histogram(inputHistogram); + assertEquals(inputHistogram, decodeInt64Histogram(payload)); + } + + @Test + public void testHistgramInt64EncodingUnsupportedBucket() { + thrown.expect(HistogramParsingException.class); + 
thrown.expectMessage("Unable to encode Int64 Histogram, bucket is not recognized"); + + HistogramData.BucketType buckets = HistogramData.UnsupportedBuckets.of(); + + HistogramData inputHistogram = new HistogramData(buckets); + inputHistogram.record(2, 4, 8, 16, 32); + encodeInt64Histogram(inputHistogram); + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java index f78ed01603fb..9497bbe43d0e 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java @@ -20,7 +20,13 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Assert; @@ -94,4 +100,42 @@ public void testReset() { assertThat(stringSetCell.getCumulative(), equalTo(StringSetData.empty())); assertThat(stringSetCell.getDirty(), equalTo(new DirtyState())); } + + @Test(timeout = 5000) + public void testStringSetCellConcurrentAddRetrieval() throws InterruptedException { + StringSetCell cell = new StringSetCell(MetricName.named("namespace", "name")); + AtomicBoolean finished = new AtomicBoolean(false); + Thread increment = + new Thread( + () -> { + for (long i = 0; !finished.get(); ++i) { + cell.add(String.valueOf(i)); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + break; + } + } + }); + increment.start(); + Instant start = Instant.now(); + try { + while (true) { + Set s = 
cell.getCumulative().stringSet(); + List snapshot = new ArrayList<>(s); + if (Instant.now().isAfter(start.plusSeconds(3)) && snapshot.size() > 0) { + finished.compareAndSet(false, true); + break; + } + } + } finally { + increment.interrupt(); + increment.join(); + } + + Set s = cell.getCumulative().stringSet(); + for (long i = 0; i < s.size(); ++i) { + assertTrue(s.contains(String.valueOf(i))); + } + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java index 665ce3743c51..534db203ff3c 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.util.Collections; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Rule; import org.junit.Test; @@ -81,6 +82,14 @@ public void testStringSetDataEmptyIsImmutable() { assertThrows(UnsupportedOperationException.class, () -> empty.stringSet().add("aa")); } + @Test + public void testStringSetDataEmptyCanAdd() { + ImmutableSet contents = ImmutableSet.of("ab", "cd"); + StringSetData stringSetData = StringSetData.empty(); + stringSetData = stringSetData.addAll(contents.toArray(new String[] {})); + assertEquals(stringSetData.stringSet(), contents); + } + @Test public void testEmptyExtract() { assertTrue(StringSetData.empty().extractResult().getStringSet().isEmpty()); @@ -94,9 +103,26 @@ public void testExtract() { } @Test - public void testExtractReturnsImmutable() { - StringSetData stringSetData = StringSetData.create(ImmutableSet.of("ab", "cd")); - // check that immutable copy is returned - assertThrows(UnsupportedOperationException.class, 
() -> stringSetData.stringSet().add("aa")); + public void testStringSetAddUntilCapacity() { + StringSetData combined = StringSetData.empty(); + @SuppressWarnings("InlineMeInliner") // Inline representation is Java11+ only + String commonPrefix = Strings.repeat("*", 1000); + long stringSize = 0; + for (int i = 0; i < 1000; ++i) { + String s = commonPrefix + i; + stringSize += s.length(); + combined = combined.addAll(s); + } + assertTrue(combined.stringSize() < stringSize); + assertTrue(combined.stringSize() > StringSetData.STRING_SET_SIZE_LIMIT); + } + + @Test + public void testStringSetAddSizeTrackedCorrectly() { + StringSetData combined = StringSetData.empty(); + combined = combined.addAll("a", "b", "c", "b"); + assertEquals(3, combined.stringSize()); + combined = combined.addAll("c", "d", "e"); + assertEquals(5, combined.stringSize()); } } diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index c357b8a04328..404b864c9c31 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -22,12 +22,12 @@ plugins { id 'org.apache.beam.module' } // Shade away runner execution utilities till because this causes ServiceLoader conflicts with // TransformPayloadTranslatorRegistrar amongst other runners. This only happens in the DirectRunner // because it is likely to appear on the classpath of another runner. 
-def dependOnProjects = [ - ":runners:core-java", - ":runners:local-java", - ":runners:java-fn-execution", - ":sdks:java:core", - ] +def dependOnProjectsAndConfigs = [ + ":runners:core-java":null, + ":runners:local-java":null, + ":runners:java-fn-execution":null, + ":sdks:java:core":"shadow", +] applyJavaNature( automaticModuleName: 'org.apache.beam.runners.direct', @@ -36,8 +36,8 @@ applyJavaNature( ], shadowClosure: { dependencies { - dependOnProjects.each { - include(project(path: it, configuration: "shadow")) + dependOnProjectsAndConfigs.each { + include(project(path: it.key, configuration: "shadow")) } } }, @@ -63,8 +63,10 @@ configurations { dependencies { shadow library.java.vendored_guava_32_1_2_jre shadow project(path: ":model:pipeline", configuration: "shadow") - dependOnProjects.each { - implementation project(it) + dependOnProjectsAndConfigs.each { + // For projects producing shadowjar, use the packaged jar as dependency to + // handle redirected packages from it + implementation project(path: it.key, configuration: it.value) } shadow library.java.vendored_grpc_1_60_1 shadow library.java.joda_time diff --git a/runners/flink/1.15/job-server/build.gradle b/runners/flink/1.15/job-server/build.gradle deleted file mode 100644 index 05ad8feb5b78..000000000000 --- a/runners/flink/1.15/job-server/build.gradle +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server' - -project.ext { - // Look for the source code in the parent module - main_source_dirs = ["$basePath/src/main/java"] - test_source_dirs = ["$basePath/src/test/java"] - main_resources_dirs = ["$basePath/src/main/resources"] - test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.15-job-server' -} - -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server.gradle" diff --git a/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java deleted file mode 100644 index 956aad428d8b..000000000000 --- a/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.types; - -import java.io.EOFException; -import java.io.IOException; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; -import org.apache.beam.runners.flink.translation.wrappers.DataOutputViewWrapper; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.core.io.VersionedIOReadableWritable; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** - * Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for Beam {@link - * org.apache.beam.sdk.coders.Coder Coders}. 
- */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class CoderTypeSerializer extends TypeSerializer { - - private static final long serialVersionUID = 7247319138941746449L; - - private final Coder coder; - - /** - * {@link SerializablePipelineOptions} deserialization will cause {@link - * org.apache.beam.sdk.io.FileSystems} registration needed for {@link - * org.apache.beam.sdk.transforms.Reshuffle} translation. - */ - private final SerializablePipelineOptions pipelineOptions; - - private final boolean fasterCopy; - - public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { - Preconditions.checkNotNull(coder); - Preconditions.checkNotNull(pipelineOptions); - this.coder = coder; - this.pipelineOptions = pipelineOptions; - - FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); - this.fasterCopy = options.getFasterCopy(); - } - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public CoderTypeSerializer duplicate() { - return new CoderTypeSerializer<>(coder, pipelineOptions); - } - - @Override - public T createInstance() { - return null; - } - - @Override - public T copy(T t) { - if (fasterCopy) { - return t; - } - try { - return CoderUtils.clone(coder, t); - } catch (CoderException e) { - throw new RuntimeException("Could not clone.", e); - } - } - - @Override - public T copy(T t, T reuse) { - return copy(t); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize(T t, DataOutputView dataOutputView) throws IOException { - DataOutputViewWrapper outputWrapper = new DataOutputViewWrapper(dataOutputView); - coder.encode(t, outputWrapper); - } - - @Override - public T deserialize(DataInputView dataInputView) throws IOException { - try { - DataInputViewWrapper inputWrapper = new DataInputViewWrapper(dataInputView); - 
return coder.decode(inputWrapper); - } catch (CoderException e) { - Throwable cause = e.getCause(); - if (cause instanceof EOFException) { - throw (EOFException) cause; - } else { - throw e; - } - } - } - - @Override - public T deserialize(T t, DataInputView dataInputView) throws IOException { - return deserialize(dataInputView); - } - - @Override - public void copy(DataInputView dataInputView, DataOutputView dataOutputView) throws IOException { - serialize(deserialize(dataInputView), dataOutputView); - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - CoderTypeSerializer that = (CoderTypeSerializer) o; - return coder.equals(that.coder); - } - - @Override - public int hashCode() { - return coder.hashCode(); - } - - @Override - public TypeSerializerSnapshot snapshotConfiguration() { - return new UnversionedTypeSerializerSnapshot<>(this); - } - - /** - * A legacy snapshot which does not care about schema compatibility. This is used only for state - * restore of state created by Beam 2.54.0 and below for Flink 1.16 and below. - */ - public static class LegacySnapshot extends TypeSerializerConfigSnapshot { - - /** Needs to be public to work with {@link VersionedIOReadableWritable}. 
*/ - public LegacySnapshot() {} - - public LegacySnapshot(CoderTypeSerializer serializer) { - setPriorSerializer(serializer); - } - - @Override - public int getVersion() { - // We always return the same version - return 1; - } - - @Override - public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( - TypeSerializer newSerializer) { - - // We assume compatibility because we don't have a way of checking schema compatibility - return TypeSerializerSchemaCompatibility.compatibleAsIs(); - } - } - - @Override - public String toString() { - return "CoderTypeSerializer{" + "coder=" + coder + '}'; - } -} diff --git a/runners/flink/1.16/build.gradle b/runners/flink/1.16/build.gradle deleted file mode 100644 index 21a222864a27..000000000000 --- a/runners/flink/1.16/build.gradle +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -project.ext { - flink_major = '1.16' - flink_version = '1.16.0' -} - -// Load the main build script which contains all build logic. 
-apply from: "../flink_runner.gradle" diff --git a/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java new file mode 100644 index 000000000000..7317788a72ee --- /dev/null +++ b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import java.util.Collection; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.BackendBuildingException; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.state.ttl.TtlTimeProvider; + +class MemoryStateBackendWrapper { + static AbstractKeyedStateBackend createKeyedStateBackend( + Environment env, + JobID jobID, + String operatorIdentifier, + TypeSerializer keySerializer, + int numberOfKeyGroups, + KeyGroupRange keyGroupRange, + TaskKvStateRegistry kvStateRegistry, + TtlTimeProvider ttlTimeProvider, + MetricGroup metricGroup, + Collection stateHandles, + CloseableRegistry cancelStreamRegistry) + throws BackendBuildingException { + + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createKeyedStateBackend( + env, + jobID, + operatorIdentifier, + keySerializer, + numberOfKeyGroups, + keyGroupRange, + kvStateRegistry, + ttlTimeProvider, + metricGroup, + stateHandles, + cancelStreamRegistry); + } + + static OperatorStateBackend createOperatorStateBackend( + Environment env, + String operatorIdentifier, + Collection stateHandles, + CloseableRegistry cancelStreamRegistry) + throws Exception { + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createOperatorStateBackend( + env, operatorIdentifier, stateHandles, 
cancelStreamRegistry); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java similarity index 100% rename from runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java rename to runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java diff --git a/runners/flink/1.15/build.gradle b/runners/flink/1.19/build.gradle similarity index 94% rename from runners/flink/1.15/build.gradle rename to runners/flink/1.19/build.gradle index 8055cf593ad0..1545da258477 100644 --- a/runners/flink/1.15/build.gradle +++ b/runners/flink/1.19/build.gradle @@ -17,8 +17,8 @@ */ project.ext { - flink_major = '1.15' - flink_version = '1.15.0' + flink_major = '1.19' + flink_version = '1.19.0' } // Load the main build script which contains all build logic. diff --git a/runners/flink/1.15/job-server-container/build.gradle b/runners/flink/1.19/job-server-container/build.gradle similarity index 100% rename from runners/flink/1.15/job-server-container/build.gradle rename to runners/flink/1.19/job-server-container/build.gradle diff --git a/runners/flink/1.16/job-server/build.gradle b/runners/flink/1.19/job-server/build.gradle similarity index 95% rename from runners/flink/1.16/job-server/build.gradle rename to runners/flink/1.19/job-server/build.gradle index 99dc00275a0c..332f04e08ceb 100644 --- a/runners/flink/1.16/job-server/build.gradle +++ b/runners/flink/1.19/job-server/build.gradle @@ -24,7 +24,7 @@ project.ext { test_source_dirs = ["$basePath/src/test/java"] main_resources_dirs = ["$basePath/src/main/resources"] test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.16-job-server' + archives_base_name = 'beam-runners-flink-1.19-job-server' } // Load the main build script which contains all build logic. 
diff --git a/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java new file mode 100644 index 000000000000..cbaa6fd3a8c4 --- /dev/null +++ b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import java.util.Collection; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.BackendBuildingException; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateBackendParametersImpl; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackendParametersImpl; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.state.ttl.TtlTimeProvider; + +class MemoryStateBackendWrapper { + static AbstractKeyedStateBackend createKeyedStateBackend( + Environment env, + JobID jobID, + String operatorIdentifier, + TypeSerializer keySerializer, + int numberOfKeyGroups, + KeyGroupRange keyGroupRange, + TaskKvStateRegistry kvStateRegistry, + TtlTimeProvider ttlTimeProvider, + MetricGroup metricGroup, + Collection stateHandles, + CloseableRegistry cancelStreamRegistry) + throws BackendBuildingException { + + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createKeyedStateBackend( + new KeyedStateBackendParametersImpl<>( + env, + jobID, + operatorIdentifier, + keySerializer, + numberOfKeyGroups, + keyGroupRange, + kvStateRegistry, + ttlTimeProvider, + metricGroup, + stateHandles, + cancelStreamRegistry)); + } + + static OperatorStateBackend createOperatorStateBackend( + Environment env, + String operatorIdentifier, + Collection stateHandles, + CloseableRegistry 
cancelStreamRegistry) + throws Exception { + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createOperatorStateBackend( + new OperatorStateBackendParametersImpl( + env, operatorIdentifier, stateHandles, cancelStreamRegistry)); + } +} diff --git a/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java new file mode 100644 index 000000000000..c03799d09535 --- /dev/null +++ b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.runtime.streamrecord.RecordAttributes; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorChain; +import org.apache.flink.streaming.runtime.tasks.RegularOperatorChain; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; + +/** {@link StreamSource} utilities, that bridge incompatibilities between Flink releases. */ +public class StreamSources { + + public static > void run( + StreamSource streamSource, + Object lockingObject, + Output> collector) + throws Exception { + streamSource.run(lockingObject, collector, createOperatorChain(streamSource)); + } + + private static OperatorChain createOperatorChain(AbstractStreamOperator operator) { + return new RegularOperatorChain<>( + operator.getContainingTask(), + StreamTask.createRecordWriterDelegate( + operator.getOperatorConfig(), new MockEnvironmentBuilder().build())); + } + + /** The emitWatermarkStatus method was added in Flink 1.14, so we need to wrap Output. */ + public interface OutputWrapper extends Output { + @Override + default void emitWatermarkStatus(WatermarkStatus watermarkStatus) {} + + /** In Flink 1.19 the {@code emitRecordAttributes} method was added. 
*/ + @Override + default void emitRecordAttributes(RecordAttributes recordAttributes) { + throw new UnsupportedOperationException("emitRecordAttributes not implemented"); + } + } +} diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index c8f492a901d3..be39d4e0b012 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -173,36 +173,19 @@ dependencies { implementation library.java.joda_time implementation library.java.args4j - // Flink 1.15 shades all remaining scala dependencies and therefor does not depend on a specific version of Scala anymore - if (flink_version.compareTo("1.15") >= 0) { - implementation "org.apache.flink:flink-clients:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). - permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" - - implementation "org.apache.flink:flink-streaming-java:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web:$flink_version" - } else { - implementation "org.apache.flink:flink-clients_2.12:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). 
- permitUnusedDeclared "org.apache.flink:flink-clients_2.12:$flink_version" - - implementation "org.apache.flink:flink-streaming-java_2.12:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java_2.12:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils_2.12:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web_2.12:$flink_version" - } + implementation "org.apache.flink:flink-clients:$flink_version" + // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation + // configuration (https://issues.apache.org/jira/browse/BEAM-11732). + permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" + + implementation "org.apache.flink:flink-streaming-java:$flink_version" + // RocksDB state backend (included in the Flink distribution) + provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" + testImplementation "org.apache.flink:flink-test-utils:$flink_version" + + miniCluster "org.apache.flink:flink-runtime-web:$flink_version" implementation "org.apache.flink:flink-core:$flink_version" implementation "org.apache.flink:flink-metrics-core:$flink_version" @@ -439,3 +422,8 @@ createPipelineOptionsTableTask('Python') // Update the pipeline options documentation before running the tests test.dependsOn(generatePipelineOptionsTableJava) test.dependsOn(generatePipelineOptionsTablePython) + +// delegate spotlessApply to :runners:flink:spotlessApply +tasks.named("spotlessApply") { + dependsOn ":runners:flink:spotlessApply" +} diff --git 
a/runners/flink/job-server-container/Dockerfile b/runners/flink/job-server-container/Dockerfile index c5a81ecf6466..cbb73512400e 100644 --- a/runners/flink/job-server-container/Dockerfile +++ b/runners/flink/job-server-container/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. ############################################################################### -FROM openjdk:8 +FROM eclipse-temurin:11 MAINTAINER "Apache Beam " RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libltdl7 diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle index 56a58df4fb09..1c610477a444 100644 --- a/runners/flink/job-server/flink_job_server.gradle +++ b/runners/flink/job-server/flink_job_server.gradle @@ -171,7 +171,6 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean checkpoi excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' - excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' excludeCategories 'org.apache.beam.sdk.testing.UsesSetState' diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 205270c22332..fe6d628953d5 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -371,6 +371,7 @@ private void collectGlobalWindowStateDescriptor( private static class FlinkValueState 
implements ValueState { private final StateNamespace namespace; + private final String namespaceKey; private final String stateId; private final ValueStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; @@ -383,6 +384,7 @@ private static class FlinkValueState implements ValueState { SerializablePipelineOptions pipelineOptions) { this.namespace = namespace; + this.namespaceKey = namespace.stringKey(); this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; @@ -394,8 +396,7 @@ private static class FlinkValueState implements ValueState { public void write(T input) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .update(input); } catch (Exception e) { throw new RuntimeException("Error updating state.", e); @@ -411,8 +412,7 @@ public ValueState readLater() { public T read() { try { return flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .value(); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -423,8 +423,7 @@ public T read() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -455,6 +454,7 @@ public int hashCode() { private static class FlinkOrderedListState implements OrderedListState { private final StateNamespace namespace; + private final String namespaceKey; private final ListStateDescriptor> flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; @@ -465,6 +465,7 @@ private static class 
FlinkOrderedListState implements OrderedListState { Coder coder, SerializablePipelineOptions pipelineOptions) { this.namespace = namespace; + this.namespaceKey = namespace.stringKey(); this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = new ListStateDescriptor<>( @@ -483,7 +484,7 @@ public void clearRange(Instant minTimestamp, Instant limitTimestamp) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); partitionedState.update(Lists.newArrayList(sortedMap.values())); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -500,7 +501,7 @@ public void add(TimestampedValue value) { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); partitionedState.add(value); } catch (Exception e) { throw new RuntimeException("Error adding to bag state.", e); @@ -516,7 +517,7 @@ public Boolean read() { Iterable> result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -542,7 +543,7 @@ private SortedMap> readAsMap() { try { ListState> partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); listValues = MoreObjects.firstNonNull(partitionedState.get(), Collections.emptyList()); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -564,8 +565,7 @@ public GroupingState, Iterable>> readLat public void clear() { try { flinkStateBackend - .getPartitionedState( - 
namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -576,6 +576,7 @@ public void clear() { private static class FlinkBagState implements BagState { private final StateNamespace namespace; + private final String namespaceKey; private final String stateId; private final ListStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; @@ -589,6 +590,7 @@ private static class FlinkBagState implements BagState { SerializablePipelineOptions pipelineOptions) { this.namespace = namespace; + this.namespaceKey = namespace.stringKey(); this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.storesVoidValues = coder instanceof VoidCoder; @@ -601,7 +603,7 @@ public void add(T input) { try { ListState partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); if (storesVoidValues) { Preconditions.checkState(input == null, "Expected to a null value but was: %s", input); // Flink does not allow storing null values @@ -625,7 +627,7 @@ public Iterable read() { try { ListState partitionedState = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); Iterable result = partitionedState.get(); if (storesVoidValues) { return () -> { @@ -662,7 +664,7 @@ public Boolean read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .get(); return result == null; } catch (Exception e) { @@ -681,8 +683,7 @@ public ReadableState readLater() { public void 
clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -715,6 +716,7 @@ private static class FlinkCombiningState implements CombiningState { private final StateNamespace namespace; + private final String namespaceKey; private final String stateId; private final Combine.CombineFn combineFn; private final ValueStateDescriptor flinkStateDescriptor; @@ -729,6 +731,7 @@ private static class FlinkCombiningState SerializablePipelineOptions pipelineOptions) { this.namespace = namespace; + this.namespaceKey = namespace.stringKey(); this.stateId = stateId; this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; @@ -748,7 +751,7 @@ public void add(InputT value) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -766,7 +769,7 @@ public void addAccum(AccumT accum) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -785,8 +788,7 @@ public AccumT getAccum() { try { AccumT accum = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .value(); return accum != null ? 
accum : combineFn.createAccumulator(); } catch (Exception e) { @@ -804,7 +806,7 @@ public OutputT read() { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); AccumT accum = state.value(); if (accum != null) { @@ -825,7 +827,7 @@ public Boolean read() { try { return flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -844,8 +846,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -878,6 +879,7 @@ private static class FlinkCombiningStateWithContext implements CombiningState { private final StateNamespace namespace; + private final String namespaceKey; private final String stateId; private final CombineWithContext.CombineFnWithContext combineFn; private final ValueStateDescriptor flinkStateDescriptor; @@ -894,6 +896,7 @@ private static class FlinkCombiningStateWithContext SerializablePipelineOptions pipelineOptions) { this.namespace = namespace; + this.namespaceKey = namespace.stringKey(); this.stateId = stateId; this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; @@ -914,7 +917,7 @@ public void add(InputT value) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); AccumT current = state.value(); if 
(current == null) { @@ -932,7 +935,7 @@ public void addAccum(AccumT accum) { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); AccumT current = state.value(); if (current == null) { @@ -951,8 +954,7 @@ public AccumT getAccum() { try { AccumT accum = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .value(); return accum != null ? accum : combineFn.createAccumulator(context); } catch (Exception e) { @@ -970,7 +972,7 @@ public OutputT read() { try { org.apache.flink.api.common.state.ValueState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); AccumT accum = state.value(); if (accum != null) { @@ -991,7 +993,7 @@ public Boolean read() { try { return flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .value() == null; } catch (Exception e) { @@ -1010,8 +1012,7 @@ public ReadableState readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1168,6 +1169,7 @@ public int hashCode() { private static class FlinkMapState implements MapState { private final StateNamespace namespace; + private final String namespaceKey; private final String stateId; private final MapStateDescriptor flinkStateDescriptor; 
private final KeyedStateBackend flinkStateBackend; @@ -1180,6 +1182,7 @@ private static class FlinkMapState implements MapState mapValueCoder, SerializablePipelineOptions pipelineOptions) { this.namespace = namespace; + this.namespaceKey = namespace.stringKey(); this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = @@ -1204,7 +1207,7 @@ public ReadableState get(final KeyT input) { ValueT value = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .get(key); return (value != null) ? value : defaultValue; } catch (Exception e) { @@ -1223,8 +1226,7 @@ public ReadableState get(final KeyT input) { public void put(KeyT key, ValueT value) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .put(key, value); } catch (Exception e) { throw new RuntimeException("Error put kv to state.", e); @@ -1235,17 +1237,12 @@ public void put(KeyT key, ValueT value) { public ReadableState computeIfAbsent( final KeyT key, Function mappingFunction) { try { - ValueT current = - flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) - .get(key); - + org.apache.flink.api.common.state.MapState state = + flinkStateBackend.getPartitionedState( + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); + ValueT current = state.get(key); if (current == null) { - flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) - .put(key, mappingFunction.apply(key)); + state.put(key, mappingFunction.apply(key)); } return ReadableStates.immediate(current); } catch (Exception e) { @@ -1257,8 +1254,7 @@ public ReadableState computeIfAbsent( public void remove(KeyT 
key) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .remove(key); } catch (Exception e) { throw new RuntimeException("Error remove map state key.", e); @@ -1274,7 +1270,7 @@ public Iterable read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .keys(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1298,7 +1294,7 @@ public Iterable read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .values(); return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1322,7 +1318,7 @@ public Iterable> read() { Iterable> result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .entries(); return result != null ? 
ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1360,8 +1356,7 @@ public ReadableState>> readLater() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); @@ -1393,6 +1388,7 @@ public int hashCode() { private static class FlinkSetState implements SetState { private final StateNamespace namespace; + private final String namespaceKey; private final String stateId; private final MapStateDescriptor flinkStateDescriptor; private final KeyedStateBackend flinkStateBackend; @@ -1404,6 +1400,7 @@ private static class FlinkSetState implements SetState { Coder coder, SerializablePipelineOptions pipelineOptions) { this.namespace = namespace; + this.namespaceKey = namespace.stringKey(); this.stateId = stateId; this.flinkStateBackend = flinkStateBackend; this.flinkStateDescriptor = @@ -1418,8 +1415,7 @@ public ReadableState contains(final T t) { try { Boolean result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .get(t); return ReadableStates.immediate(result != null && result); } catch (Exception e) { @@ -1432,7 +1428,7 @@ public ReadableState addIfAbsent(final T t) { try { org.apache.flink.api.common.state.MapState state = flinkStateBackend.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor); + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor); boolean alreadyContained = state.contains(t); if (!alreadyContained) { state.put(t, true); @@ -1447,8 +1443,7 @@ public ReadableState addIfAbsent(final T t) { public void remove(T t) { try { flinkStateBackend - 
.getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .remove(t); } catch (Exception e) { throw new RuntimeException("Error remove value to state.", e); @@ -1464,8 +1459,7 @@ public SetState readLater() { public void add(T value) { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .put(value, true); } catch (Exception e) { throw new RuntimeException("Error add value to state.", e); @@ -1481,7 +1475,7 @@ public Boolean read() { Iterable result = flinkStateBackend .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .keys(); return result == null || Iterables.isEmpty(result); } catch (Exception e) { @@ -1501,8 +1495,7 @@ public Iterable read() { try { Iterable result = flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .keys(); return result != null ? 
ImmutableList.copyOf(result) : Collections.emptyList(); } catch (Exception e) { @@ -1514,8 +1507,7 @@ public Iterable read() { public void clear() { try { flinkStateBackend - .getPartitionedState( - namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .getPartitionedState(namespaceKey, StringSerializer.INSTANCE, flinkStateDescriptor) .clear(); } catch (Exception e) { throw new RuntimeException("Error clearing state.", e); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java index 10e20a6d47d3..c679c0725051 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java @@ -25,7 +25,6 @@ import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.OperatorStateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.junit.Ignore; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,10 +40,9 @@ public class FlinkBroadcastStateInternalsTest extends StateInternalsTest { @Override protected StateInternals createStateInternals() { - MemoryStateBackend backend = new MemoryStateBackend(); try { OperatorStateBackend operatorStateBackend = - backend.createOperatorStateBackend( + MemoryStateBackendWrapper.createOperatorStateBackend( new DummyEnvironment("test", 1, 0), "", Collections.emptyList(), null); return new FlinkBroadcastStateInternals<>( 1, diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java 
b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index a2d6f5027abb..d0338ec3b0d3 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -47,7 +47,6 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.state.ttl.TtlTimeProvider; import org.hamcrest.Matchers; import org.joda.time.Instant; @@ -185,9 +184,8 @@ public void testGlobalWindowWatermarkHoldClear() throws Exception { } public static KeyedStateBackend createStateBackend() throws Exception { - MemoryStateBackend backend = new MemoryStateBackend(); AbstractKeyedStateBackend keyedStateBackend = - backend.createKeyedStateBackend( + MemoryStateBackendWrapper.createKeyedStateBackend( new DummyEnvironment("test", 1, 0), new JobID(), "test_op", diff --git a/runners/flink/src/test/resources/flink-conf.yaml b/runners/flink/src/test/resources/flink-conf.yaml index 84773535cdaf..57d659b4095c 100644 --- a/runners/flink/src/test/resources/flink-conf.yaml +++ b/runners/flink/src/test/resources/flink-conf.yaml @@ -19,3 +19,4 @@ parallelism.default: 23 taskmanager.memory.network.fraction: 0.2 taskmanager.memory.network.max: 2gb +taskmanager.memory.managed.size: 1gb diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index df2270d3b653..811a3c15f836 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -16,6 +16,8 @@ * limitations under the License. 
*/ +import static org.apache.beam.gradle.BeamModulePlugin.getSupportedJavaVersion + import groovy.json.JsonOutput plugins { id 'org.apache.beam.module' } @@ -185,7 +187,7 @@ def commonLegacyExcludeCategories = [ 'org.apache.beam.sdk.testing.UsesGaugeMetrics', 'org.apache.beam.sdk.testing.UsesMultimapState', 'org.apache.beam.sdk.testing.UsesTestStream', - 'org.apache.beam.sdk.testing.UsesParDoLifecycle', + 'org.apache.beam.sdk.testing.UsesParDoLifecycle', // doesn't support remote runner 'org.apache.beam.sdk.testing.UsesMetricsPusher', 'org.apache.beam.sdk.testing.UsesBundleFinalizer', ] @@ -273,12 +275,75 @@ def createRunnerV2ValidatesRunnerTest = { Map args -> } } +tasks.register('examplesJavaRunnerV2IntegrationTestDistroless', Test.class) { + group = "verification" + dependsOn 'buildAndPushDistrolessContainerImage' + def javaVer = project.findProperty('testJavaVersion') + def repository = "us.gcr.io/apache-beam-testing/${System.getenv('USER')}" + def tag = project.findProperty('dockerTag') + def imageURL = "${repository}/beam_${javaVer}_sdk_distroless:${tag}" + def pipelineOptions = [ + "--runner=TestDataflowRunner", + "--project=${gcpProject}", + "--region=${gcpRegion}", + "--tempRoot=${dataflowValidatesTempRoot}", + "--sdkContainerImage=${imageURL}", + "--experiments=use_unified_worker,use_runner_v2", + "--firestoreDb=${firestoreDb}", + ] + systemProperty "beamTestPipelineOptions", JsonOutput.toJson(pipelineOptions) + + include '**/*IT.class' + + maxParallelForks 4 + classpath = configurations.examplesJavaIntegrationTest + testClassesDirs = files(project(":examples:java").sourceSets.test.output.classesDirs) + useJUnit { } +} + +tasks.register('buildAndPushDistrolessContainerImage', Task.class) { + // Only Java 17 and 21 are supported. + // See https://github.com/GoogleContainerTools/distroless/tree/main/java#image-contents. 
+ def allowed = ["java17", "java21"] + doLast { + def javaVer = project.findProperty('testJavaVersion') + if (!allowed.contains(javaVer)) { + throw new GradleException("testJavaVersion must be one of ${allowed}, got: ${javaVer}") + } + if (!project.hasProperty('dockerTag')) { + throw new GradleException("dockerTag is missing but required") + } + def repository = "us.gcr.io/apache-beam-testing/${System.getenv('USER')}" + def tag = project.findProperty('dockerTag') + def imageURL = "${repository}/beam_${javaVer}_sdk_distroless:${tag}" + exec { + executable 'docker' + workingDir rootDir + args = [ + 'buildx', + 'build', + '-t', + imageURL, + '-f', + 'sdks/java/container/Dockerfile-distroless', + "--build-arg=BEAM_BASE=gcr.io/apache-beam-testing/beam-sdk/beam_${javaVer}_sdk", + "--build-arg=DISTROLESS_BASE=gcr.io/distroless/${javaVer}-debian12", + '.' + ] + } + exec { + executable 'docker' + args = ['push', imageURL] + } + } +} + // Push docker images to a container registry for use within tests. // NB: Tasks which consume docker images from the registry should depend on this // task directly ('dependsOn buildAndPushDockerJavaContainer'). This ensures the correct // task ordering such that the registry doesn't get cleaned up prior to task completion. 
def buildAndPushDockerJavaContainer = tasks.register("buildAndPushDockerJavaContainer") { - def javaVer = "java8" + def javaVer = getSupportedJavaVersion() if(project.hasProperty('testJavaVersion')) { javaVer = "java${project.getProperty('testJavaVersion')}" } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index abe7d0d364d3..ce99958c57fd 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -41,10 +41,12 @@ import com.google.auto.value.AutoValue; import java.io.BufferedWriter; import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -179,6 +181,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.HashCode; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.Hashing; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Files; import org.joda.time.DateTimeUtils; import org.joda.time.DateTimeZone; @@ -257,6 +260,92 @@ public class DataflowRunner extends PipelineRunner { /** Dataflow service endpoints are expected to match this pattern. */ static final String ENDPOINT_REGEXP = "https://[\\S]*googleapis\\.com[/]?"; + /** + * Replaces GCS file paths with local file paths by downloading the GCS files locally. 
This is + * useful when files need to be accessed locally before being staged to Dataflow. + * + * @param filesToStage List of file paths that may contain GCS paths (gs://) and local paths + * @return List of local file paths where any GCS paths have been downloaded locally + * @throws RuntimeException if there are errors copying GCS files locally + */ + public static List replaceGcsFilesWithLocalFiles(List filesToStage) { + List processedFiles = new ArrayList<>(); + + for (String fileToStage : filesToStage) { + String localPath; + if (fileToStage.contains("=")) { + // Handle files with staging name specified + String[] components = fileToStage.split("=", 2); + String stagingName = components[0]; + String filePath = components[1]; + + if (filePath.startsWith("gs://")) { + try { + // Create temp file with exact same name as GCS file + String gcsFileName = filePath.substring(filePath.lastIndexOf('/') + 1); + File tempDir = Files.createTempDir(); + tempDir.deleteOnExit(); + File tempFile = new File(tempDir, gcsFileName); + tempFile.deleteOnExit(); + + LOG.info( + "Downloading GCS file {} to local temp file {}", + filePath, + tempFile.getAbsolutePath()); + + // Copy GCS file to local temp file + ResourceId source = FileSystems.matchNewResource(filePath, false); + try (ReadableByteChannel reader = FileSystems.open(source); + FileOutputStream writer = new FileOutputStream(tempFile)) { + ByteStreams.copy(Channels.newInputStream(reader), writer); + } + + localPath = stagingName + "=" + tempFile.getAbsolutePath(); + LOG.info("Replaced GCS path {} with local path {}", fileToStage, localPath); + } catch (IOException e) { + throw new RuntimeException("Failed to copy GCS file locally: " + filePath, e); + } + } else { + localPath = fileToStage; + } + } else { + // Handle files without staging name + if (fileToStage.startsWith("gs://")) { + try { + // Create temp file with exact same name as GCS file + String gcsFileName = fileToStage.substring(fileToStage.lastIndexOf('/') + 1); 
+ File tempDir = Files.createTempDir(); + tempDir.deleteOnExit(); + File tempFile = new File(tempDir, gcsFileName); + tempFile.deleteOnExit(); + + LOG.info( + "Downloading GCS file {} to local temp file {}", + fileToStage, + tempFile.getAbsolutePath()); + + // Copy GCS file to local temp file + ResourceId source = FileSystems.matchNewResource(fileToStage, false); + try (ReadableByteChannel reader = FileSystems.open(source); + FileOutputStream writer = new FileOutputStream(tempFile)) { + ByteStreams.copy(Channels.newInputStream(reader), writer); + } + + localPath = tempFile.getAbsolutePath(); + LOG.info("Replaced GCS path {} with local path {}", fileToStage, localPath); + } catch (IOException e) { + throw new RuntimeException("Failed to copy GCS file locally: " + fileToStage, e); + } + } else { + localPath = fileToStage; + } + } + processedFiles.add(localPath); + } + + return processedFiles; + } + /** * Construct a runner from the provided options. * @@ -312,6 +401,9 @@ && isServiceEndpoint(dataflowOptions.getDataflowEndpoint())) { } if (dataflowOptions.getFilesToStage() != null) { + // Replace GCS file paths with local file paths + dataflowOptions.setFilesToStage( + replaceGcsFilesWithLocalFiles(dataflowOptions.getFilesToStage())); // The user specifically requested these files, so fail now if they do not exist. 
// (automatically detected classpath elements are permitted to not exist, so later // staging will not fail on nonexistent files) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java index 6a0208f1447f..61c38dde2b42 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.DefaultValueFactory; import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.Hidden; import org.apache.beam.sdk.options.PipelineOptions; import org.joda.time.Duration; @@ -219,10 +220,8 @@ public interface DataflowStreamingPipelineOptions extends PipelineOptions { void setWindmillServiceStreamMaxBackoffMillis(int value); - @Description( - "If true, Dataflow streaming pipeline will be running in direct path mode." - + " VMs must have IPv6 enabled for this to work.") - @Default.Boolean(false) + @Description("Enables direct path mode for streaming engine.") + @Default.InstanceFactory(EnableWindmillServiceDirectPathFactory.class) boolean getIsWindmillServiceDirectPathEnabled(); void setIsWindmillServiceDirectPathEnabled(boolean isWindmillServiceDirectPathEnabled); @@ -300,4 +299,12 @@ public Integer create(PipelineOptions options) { return streamingOptions.isEnableStreamingEngine() ? Integer.MAX_VALUE : 1; } } + + /** IsWindmillServiceDirectPathEnabled defaults to false unless the enable_windmill_service_direct_path experiment is set. 
*/ + class EnableWindmillServiceDirectPathFactory implements DefaultValueFactory { + @Override + public Boolean create(PipelineOptions options) { + return ExperimentalOptions.hasExperiment(options, "enable_windmill_service_direct_path"); + } + } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 01ceac9da585..106b15de6e4d 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -77,6 +77,7 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.io.Serializable; +import java.io.Writer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -1155,6 +1156,19 @@ public void testNonExistentStagingLocation() throws IOException { assertValidJob(jobCaptor.getValue()); } + @Test + public void testReplaceGcsFilesWithLocalFilesEmptyList() { + List filesToStage = Collections.emptyList(); + List processedFiles = DataflowRunner.replaceGcsFilesWithLocalFiles(filesToStage); + assertTrue(processedFiles.isEmpty()); + } + + @Test(expected = RuntimeException.class) + public void testReplaceGcsFilesWithLocalFilesIOError() { + List filesToStage = Collections.singletonList("gs://non-existent-bucket/file.jar"); + DataflowRunner.replaceGcsFilesWithLocalFiles(filesToStage); + } + @Test public void testNonExistentProfileLocation() throws IOException { DataflowPipelineOptions options = buildPipelineOptions(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java index 94c894608a47..a28a5e989c88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java @@ -82,7 +82,9 @@ public static T initializeGlobalStateAn @SuppressWarnings("Slf4jIllegalPassedClass") public static void initializeLogging(Class workerHarnessClass) { - /* Set up exception handling tied to the workerHarnessClass. */ + // Set up exception handling for raw Threads tied to the workerHarnessClass. + // Does NOT handle exceptions thrown by threads created by + // ScheduledExecutors/ScheduledExecutorServices. Thread.setDefaultUncaughtExceptionHandler( new WorkerUncaughtExceptionHandler(LoggerFactory.getLogger(workerHarnessClass))); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java index 30e920119120..91baefa0be4c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java @@ -38,9 +38,13 @@ /** * Converts metric updates to {@link PerStepNamespaceMetrics} protos. Currently we only support - * converting metrics from {@link BigQuerySinkMetrics} with this converter. + * converting metrics from {@link BigQuerySinkMetrics} and from {@link KafkaSinkMetrics} with this + * converter. 
*/ public class MetricsToPerStepNamespaceMetricsConverter { + // Avoids introducing a mandatory kafka-io dependency to the Dataflow worker. + // Keep in sync with org.apache.beam.sdk.io.kafka.KafkaSinkMetrics.METRICS_NAMESPACE. + public static String KAFKA_SINK_METRICS_NAMESPACE = "KafkaSink"; private static Optional getParsedMetricName( MetricName metricName, @@ -65,7 +69,10 @@ private static Optional convertCounterToMetricValue( MetricName metricName, Long value, Map parsedPerWorkerMetricsCache) { - if (value == 0 || !metricName.getNamespace().equals(BigQuerySinkMetrics.METRICS_NAMESPACE)) { + + if (value == 0 + || (!metricName.getNamespace().equals(BigQuerySinkMetrics.METRICS_NAMESPACE) + && !metricName.getNamespace().equals(KAFKA_SINK_METRICS_NAMESPACE))) { return Optional.empty(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java index a413c2c03dbe..558848f488a7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java @@ -77,6 +77,7 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class SimpleParDoFn implements ParDoFn { + // TODO: Remove once Distributions has shipped. @VisibleForTesting static final String OUTPUTS_PER_ELEMENT_EXPERIMENT = "outputs_per_element_counter"; @@ -174,6 +175,7 @@ private boolean hasExperiment(String experiment) { /** Simple state tracker to calculate PerElementOutputCount counter. 
*/ private interface OutputsPerElementTracker { + void onOutput(); void onProcessElement(); @@ -182,6 +184,7 @@ private interface OutputsPerElementTracker { } private class OutputsPerElementTrackerImpl implements OutputsPerElementTracker { + private long outputsPerElement; private final Counter counter; @@ -214,6 +217,7 @@ private void reset() { /** No-op {@link OutputsPerElementTracker} implementation used when the counter is disabled. */ private static class NoopOutputsPerElementTracker implements OutputsPerElementTracker { + private NoopOutputsPerElementTracker() {} public static final OutputsPerElementTracker INSTANCE = new NoopOutputsPerElementTracker(); @@ -516,10 +520,14 @@ private void registerStateCleanup( private Instant earliestAllowableCleanupTime( BoundedWindow window, WindowingStrategy windowingStrategy) { - return window - .maxTimestamp() - .plus(windowingStrategy.getAllowedLateness()) - .plus(Duration.millis(1L)); + Instant cleanupTime = + window + .maxTimestamp() + .plus(windowingStrategy.getAllowedLateness()) + .plus(Duration.millis(1L)); + return cleanupTime.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE) + ? 
BoundedWindow.TIMESTAMP_MAX_VALUE + : cleanupTime; } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index ecdba404151e..0112ab4af80a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -17,11 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.remoteChannel; -import com.google.api.services.dataflow.model.CounterUpdate; import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; +import java.io.PrintWriter; import java.util.List; import java.util.Map; import java.util.Optional; @@ -33,24 +33,28 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import javax.annotation.Nullable; import org.apache.beam.runners.core.metrics.MetricsLogger; import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; -import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; import org.apache.beam.runners.dataflow.worker.status.DebugCapture; import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import 
org.apache.beam.runners.dataflow.worker.streaming.ComputationStateCache; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor; import org.apache.beam.runners.dataflow.worker.streaming.config.ComputationConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.FixedGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingApplianceComputationConfigFetcher; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingEngineComputationConfigFetcher; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; +import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandleImpl; +import org.apache.beam.runners.dataflow.worker.streaming.harness.FanOutStreamingEngineWorkerHarness; import org.apache.beam.runners.dataflow.worker.streaming.harness.SingleSourceWorkerHarness; import org.apache.beam.runners.dataflow.worker.streaming.harness.SingleSourceWorkerHarness.GetWorkSender; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; @@ -59,12 +63,16 @@ import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingWorkerStatusReporter; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; +import org.apache.beam.runners.dataflow.worker.windmill.ApplianceWindmillClient; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; import 
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; +import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter; @@ -77,8 +85,16 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCache; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingRemoteStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.IsolationChannel; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributors; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.StreamingApplianceFailureTracker; @@ -88,6 +104,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ApplianceHeartbeatSender; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.StreamPoolHeartbeatSender; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.sdk.fn.IdGenerators; import org.apache.beam.sdk.fn.JvmInitializers; @@ -96,18 +113,25 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Implements a Streaming Dataflow worker. */ +/** + * For internal use only. + * + *

Implements a Streaming Dataflow worker. + */ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) +@Internal public final class StreamingDataflowWorker { /** @@ -140,7 +164,9 @@ public final class StreamingDataflowWorker { private static final int DEFAULT_STATUS_PORT = 8081; private static final Random CLIENT_ID_GENERATOR = new Random(); private static final String CHANNELZ_PATH = "/channelz"; - public static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL = + private static final String BEAM_FN_API_EXPERIMENT = "beam_fn_api"; + private static final String ENABLE_IPV6_EXPERIMENT = "enable_private_ipv6_google_access"; + private static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL_EXPERIMENT = "streaming_engine_use_job_settings_for_heartbeat_pool"; private final WindmillStateCache stateCache; @@ -155,9 +181,8 @@ public final class StreamingDataflowWorker { private final ReaderCache readerCache; private final DataflowExecutionStateSampler sampler = DataflowExecutionStateSampler.instance(); private final ActiveWorkRefresher activeWorkRefresher; - private final WorkCommitter workCommitter; private final StreamingWorkerStatusReporter workerStatusReporter; - private final StreamingCounters streamingCounters; + private final int numCommitThreads; private StreamingDataflowWorker( WindmillServerStub windmillServer, @@ -170,17 +195,17 @@ private StreamingDataflowWorker( DataflowWorkerHarnessOptions options, HotKeyLogger hotKeyLogger, Supplier clock, - StreamingWorkerStatusReporter workerStatusReporter, + StreamingWorkerStatusReporterFactory streamingWorkerStatusReporterFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCounters streamingCounters, MemoryMonitor memoryMonitor, GrpcWindmillStreamFactory windmillStreamFactory, - Function executorSupplier, - ConcurrentMap stageInfoMap) { + ScheduledExecutorService activeWorkRefreshExecutorFn, + ConcurrentMap 
stageInfoMap, + @Nullable GrpcDispatcherClient dispatcherClient) { // Register standard file systems. FileSystems.setDefaultPipelineOptions(options); - this.configFetcher = configFetcher; this.computationStateCache = computationStateCache; this.stateCache = windmillStateCache; @@ -189,34 +214,13 @@ private StreamingDataflowWorker( Duration.standardSeconds(options.getReaderCacheTimeoutSec()), Executors.newCachedThreadPool()); this.options = options; - - boolean windmillServiceEnabled = options.isEnableStreamingEngine(); - - int numCommitThreads = 1; - if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() > 0) { - numCommitThreads = options.getWindmillServiceCommitThreads(); - } - - this.workCommitter = - windmillServiceEnabled - ? StreamingEngineWorkCommitter.builder() - .setCommitWorkStreamFactory( - WindmillStreamPool.create( - numCommitThreads, - COMMIT_STREAM_TIMEOUT, - windmillServer::commitWorkStream) - ::getCloseableStream) - .setNumCommitSenders(numCommitThreads) - .setOnCommitComplete(this::onCompleteCommit) - .build() - : StreamingApplianceWorkCommitter.create( - windmillServer::commitWork, this::onCompleteCommit); - this.workUnitExecutor = workUnitExecutor; - - this.workerStatusReporter = workerStatusReporter; - this.streamingCounters = streamingCounters; this.memoryMonitor = BackgroundMemoryMonitor.create(memoryMonitor); + this.numCommitThreads = + options.isEnableStreamingEngine() + ? 
Math.max(options.getWindmillServiceCommitThreads(), 1) + : 1; + StreamingWorkScheduler streamingWorkScheduler = StreamingWorkScheduler.create( options, @@ -233,110 +237,200 @@ private StreamingDataflowWorker( ID_GENERATOR, configFetcher.getGlobalConfigHandle(), stageInfoMap); - ThrottlingGetDataMetricTracker getDataMetricTracker = new ThrottlingGetDataMetricTracker(memoryMonitor); - WorkerStatusPages workerStatusPages = - WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor); - StreamingWorkerStatusPages.Builder statusPagesBuilder = StreamingWorkerStatusPages.builder(); - int stuckCommitDurationMillis; - GetDataClient getDataClient; - HeartbeatSender heartbeatSender; - if (windmillServiceEnabled) { - WindmillStreamPool getDataStreamPool = - WindmillStreamPool.create( - Math.max(1, options.getWindmillGetDataStreamCount()), - GET_DATA_STREAM_TIMEOUT, - windmillServer::getDataStream); - getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); - // Experiment gates the logic till backend changes are rollback safe - if (!DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL) - || options.getUseSeparateWindmillHeartbeatStreams() != null) { + // Status page members. Different implementations on whether the harness is streaming engine + // direct path, streaming engine cloud path, or streaming appliance. 
+ @Nullable ChannelzServlet channelzServlet = null; + Consumer getDataStatusProvider; + Supplier currentActiveCommitBytesProvider; + if (isDirectPathPipeline(options)) { + WeightedSemaphore maxCommitByteSemaphore = Commits.maxCommitByteSemaphore(); + FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkerHarness = + FanOutStreamingEngineWorkerHarness.create( + createJobHeader(options, clientId), + GetWorkBudget.builder() + .setItems(chooseMaxBundlesOutstanding(options)) + .setBytes(MAX_GET_WORK_FETCH_BYTES) + .build(), + windmillStreamFactory, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> + computationStateCache + .get(processingContext.computationId()) + .ifPresent( + computationState -> { + memoryMonitor.waitForResources("GetWork"); + streamingWorkScheduler.scheduleWork( + computationState, + workItem, + watermarks, + processingContext, + getWorkStreamLatencies); + }), + createFanOutStubFactory(options), + GetWorkBudgetDistributors.distributeEvenly(), + Preconditions.checkNotNull(dispatcherClient), + commitWorkStream -> + StreamingEngineWorkCommitter.builder() + // Share the commitByteSemaphore across all created workCommitters. + .setCommitByteSemaphore(maxCommitByteSemaphore) + .setBackendWorkerToken(commitWorkStream.backendWorkerToken()) + .setOnCommitComplete(this::onCompleteCommit) + .setNumCommitSenders(Math.max(options.getWindmillServiceCommitThreads(), 1)) + .setCommitWorkStreamFactory( + () -> CloseableStream.create(commitWorkStream, () -> {})) + .build(), + getDataMetricTracker); + getDataStatusProvider = getDataMetricTracker::printHtml; + currentActiveCommitBytesProvider = + fanOutStreamingEngineWorkerHarness::currentActiveCommitBytes; + channelzServlet = + createChannelzServlet( + options, fanOutStreamingEngineWorkerHarness::currentWindmillEndpoints); + this.streamingWorkerHarness = fanOutStreamingEngineWorkerHarness; + } else { + // Non-direct path pipelines. 
+ Windmill.GetWorkRequest request = + Windmill.GetWorkRequest.newBuilder() + .setClientId(clientId) + .setMaxItems(chooseMaxBundlesOutstanding(options)) + .setMaxBytes(MAX_GET_WORK_FETCH_BYTES) + .build(); + GetDataClient getDataClient; + HeartbeatSender heartbeatSender; + WorkCommitter workCommitter; + GetWorkSender getWorkSender; + if (options.isEnableStreamingEngine()) { + WindmillStreamPool getDataStreamPool = + WindmillStreamPool.create( + Math.max(1, options.getWindmillGetDataStreamCount()), + GET_DATA_STREAM_TIMEOUT, + windmillServer::getDataStream); + getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); heartbeatSender = - StreamPoolHeartbeatSender.Create( - Boolean.TRUE.equals(options.getUseSeparateWindmillHeartbeatStreams()) - ? separateHeartbeatPool(windmillServer) - : getDataStreamPool); - + createStreamingEngineHeartbeatSender( + options, windmillServer, getDataStreamPool, configFetcher.getGlobalConfigHandle()); + channelzServlet = + createChannelzServlet(options, windmillServer::getWindmillServiceEndpoints); + workCommitter = + StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory( + WindmillStreamPool.create( + numCommitThreads, + COMMIT_STREAM_TIMEOUT, + windmillServer::commitWorkStream) + ::getCloseableStream) + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) + .setNumCommitSenders(numCommitThreads) + .setOnCommitComplete(this::onCompleteCommit) + .build(); + getWorkSender = + GetWorkSender.forStreamingEngine( + receiver -> windmillServer.getWorkStream(request, receiver)); } else { - heartbeatSender = - StreamPoolHeartbeatSender.Create( - separateHeartbeatPool(windmillServer), - getDataStreamPool, - configFetcher.getGlobalConfigHandle()); + getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker); + heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData); + workCommitter = + StreamingApplianceWorkCommitter.create( + windmillServer::commitWork, 
this::onCompleteCommit); + getWorkSender = GetWorkSender.forAppliance(() -> windmillServer.getWork(request)); } - stuckCommitDurationMillis = - options.getStuckCommitDurationMillis() > 0 ? options.getStuckCommitDurationMillis() : 0; - statusPagesBuilder - .setDebugCapture( - new DebugCapture.Manager(options, workerStatusPages.getDebugCapturePages())) - .setChannelzServlet( - new ChannelzServlet( - CHANNELZ_PATH, options, windmillServer::getWindmillServiceEndpoints)) - .setWindmillStreamFactory(windmillStreamFactory); - } else { - getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker); - heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData); - stuckCommitDurationMillis = 0; + getDataStatusProvider = getDataClient::printHtml; + currentActiveCommitBytesProvider = workCommitter::currentActiveCommitBytes; + + this.streamingWorkerHarness = + SingleSourceWorkerHarness.builder() + .setStreamingWorkScheduler(streamingWorkScheduler) + .setWorkCommitter(workCommitter) + .setGetDataClient(getDataClient) + .setComputationStateFetcher(this.computationStateCache::get) + .setWaitForResources(() -> memoryMonitor.waitForResources("GetWork")) + .setHeartbeatSender(heartbeatSender) + .setThrottledTimeTracker(windmillServer::getAndResetThrottleTime) + .setGetWorkSender(getWorkSender) + .build(); } + this.workerStatusReporter = + streamingWorkerStatusReporterFactory.createStatusReporter(streamingWorkerHarness); this.activeWorkRefresher = new ActiveWorkRefresher( clock, options.getActiveWorkRefreshPeriodMillis(), - stuckCommitDurationMillis, + options.isEnableStreamingEngine() + ? 
Math.max(options.getStuckCommitDurationMillis(), 0) + : 0, computationStateCache::getAllPresentComputations, sampler, - executorSupplier.apply("RefreshWork"), + activeWorkRefreshExecutorFn, getDataMetricTracker::trackHeartbeats); this.statusPages = - statusPagesBuilder + createStatusPageBuilder(options, windmillStreamFactory, memoryMonitor) .setClock(clock) .setClientId(clientId) .setIsRunning(running) - .setStatusPages(workerStatusPages) .setStateCache(stateCache) .setComputationStateCache(this.computationStateCache) - .setCurrentActiveCommitBytes(workCommitter::currentActiveCommitBytes) - .setGetDataStatusProvider(getDataClient::printHtml) .setWorkUnitExecutor(workUnitExecutor) .setGlobalConfigHandle(configFetcher.getGlobalConfigHandle()) + .setChannelzServlet(channelzServlet) + .setGetDataStatusProvider(getDataStatusProvider) + .setCurrentActiveCommitBytes(currentActiveCommitBytesProvider) .build(); - Windmill.GetWorkRequest request = - Windmill.GetWorkRequest.newBuilder() - .setClientId(clientId) - .setMaxItems(chooseMaximumBundlesOutstanding()) - .setMaxBytes(MAX_GET_WORK_FETCH_BYTES) - .build(); - - this.streamingWorkerHarness = - SingleSourceWorkerHarness.builder() - .setStreamingWorkScheduler(streamingWorkScheduler) - .setWorkCommitter(workCommitter) - .setGetDataClient(getDataClient) - .setComputationStateFetcher(this.computationStateCache::get) - .setWaitForResources(() -> memoryMonitor.waitForResources("GetWork")) - .setHeartbeatSender(heartbeatSender) - .setGetWorkSender( - windmillServiceEnabled - ? 
GetWorkSender.forStreamingEngine( - receiver -> windmillServer.getWorkStream(request, receiver)) - : GetWorkSender.forAppliance(() -> windmillServer.getWork(request))) - .build(); - - LOG.debug("windmillServiceEnabled: {}", windmillServiceEnabled); + LOG.debug("isDirectPathEnabled: {}", options.getIsWindmillServiceDirectPathEnabled()); + LOG.debug("windmillServiceEnabled: {}", options.isEnableStreamingEngine()); LOG.debug("WindmillServiceEndpoint: {}", options.getWindmillServiceEndpoint()); LOG.debug("WindmillServicePort: {}", options.getWindmillServicePort()); LOG.debug("LocalWindmillHostport: {}", options.getLocalWindmillHostport()); } - private static WindmillStreamPool separateHeartbeatPool( - WindmillServerStub windmillServer) { - return WindmillStreamPool.create(1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream); + private static StreamingWorkerStatusPages.Builder createStatusPageBuilder( + DataflowWorkerHarnessOptions options, + GrpcWindmillStreamFactory windmillStreamFactory, + MemoryMonitor memoryMonitor) { + WorkerStatusPages workerStatusPages = + WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor); + + StreamingWorkerStatusPages.Builder streamingStatusPages = + StreamingWorkerStatusPages.builder().setStatusPages(workerStatusPages); + + return options.isEnableStreamingEngine() + ? 
streamingStatusPages + .setDebugCapture( + new DebugCapture.Manager(options, workerStatusPages.getDebugCapturePages())) + .setWindmillStreamFactory(windmillStreamFactory) + : streamingStatusPages; + } + + private static ChannelzServlet createChannelzServlet( + DataflowWorkerHarnessOptions options, + Supplier> windmillEndpointProvider) { + return new ChannelzServlet(CHANNELZ_PATH, options, windmillEndpointProvider); + } + + private static HeartbeatSender createStreamingEngineHeartbeatSender( + DataflowWorkerHarnessOptions options, + WindmillServerStub windmillClient, + WindmillStreamPool getDataStreamPool, + StreamingGlobalConfigHandle globalConfigHandle) { + // Experiment gates the logic till backend changes are rollback safe + if (!DataflowRunner.hasExperiment( + options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL_EXPERIMENT) + || options.getUseSeparateWindmillHeartbeatStreams() != null) { + return StreamPoolHeartbeatSender.create( + Boolean.TRUE.equals(options.getUseSeparateWindmillHeartbeatStreams()) + ? 
WindmillStreamPool.create(1, GET_DATA_STREAM_TIMEOUT, windmillClient::getDataStream) + : getDataStreamPool); + + } else { + return StreamPoolHeartbeatSender.create( + WindmillStreamPool.create(1, GET_DATA_STREAM_TIMEOUT, windmillClient::getDataStream), + getDataStreamPool, + globalConfigHandle); + } } public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions options) { @@ -351,10 +445,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o .setSizeMb(options.getWorkerCacheMb()) .setSupportMapViaMultimap(options.isEnableStreamingEngine()) .build(); - Function executorSupplier = - threadName -> - Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); + GrpcWindmillStreamFactory.Builder windmillStreamFactoryBuilder = createGrpcwindmillStreamFactoryBuilder(options, clientId); @@ -392,17 +483,21 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o failureTracker, () -> Optional.ofNullable(memoryMonitor.tryToDumpHeap()), clock); - StreamingWorkerStatusReporter workerStatusReporter = - StreamingWorkerStatusReporter.create( - dataflowServiceClient, - windmillServer::getAndResetThrottleTime, - stageInfo::values, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - options.getWindmillHarnessUpdateReportingPeriod().getMillis(), - options.getPerWorkerMetricsUpdateReportingPeriodMillis()); + StreamingWorkerStatusReporterFactory workerStatusReporterFactory = + throttleTimeSupplier -> + StreamingWorkerStatusReporter.builder() + .setDataflowServiceClient(dataflowServiceClient) + .setWindmillQuotaThrottleTime(throttleTimeSupplier) + .setAllStageInfo(stageInfo::values) + .setFailureTracker(failureTracker) + .setStreamingCounters(streamingCounters) + .setMemoryMonitor(memoryMonitor) + .setWorkExecutor(workExecutor) + .setWindmillHarnessUpdateReportingPeriodMillis( + options.getWindmillHarnessUpdateReportingPeriod().getMillis()) + 
.setPerWorkerMetricsUpdateReportingPeriodMillis( + options.getPerWorkerMetricsUpdateReportingPeriodMillis()) + .build(); return new StreamingDataflowWorker( windmillServer, @@ -415,14 +510,16 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o options, new HotKeyLogger(), clock, - workerStatusReporter, + workerStatusReporterFactory, failureTracker, workFailureProcessor, streamingCounters, memoryMonitor, configFetcherComputationStateCacheAndWindmillClient.windmillStreamFactory(), - executorSupplier, - stageInfo); + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat("RefreshWork").build()), + stageInfo, + configFetcherComputationStateCacheAndWindmillClient.windmillDispatcherClient()); } /** @@ -437,53 +534,121 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o WorkUnitClient dataflowServiceClient, GrpcWindmillStreamFactory.Builder windmillStreamFactoryBuilder, Function computationStateCacheFactory) { - ComputationConfig.Fetcher configFetcher; - WindmillServerStub windmillServer; - ComputationStateCache computationStateCache; - GrpcWindmillStreamFactory windmillStreamFactory; if (options.isEnableStreamingEngine()) { GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); - configFetcher = + ComputationConfig.Fetcher configFetcher = StreamingEngineComputationConfigFetcher.create( options.getGlobalConfigRefreshPeriod().getMillis(), dataflowServiceClient); configFetcher.getGlobalConfigHandle().registerConfigObserver(dispatcherClient::onJobConfig); - computationStateCache = computationStateCacheFactory.apply(configFetcher); - windmillStreamFactory = + ComputationStateCache computationStateCache = + computationStateCacheFactory.apply(configFetcher); + GrpcWindmillStreamFactory windmillStreamFactory = windmillStreamFactoryBuilder .setProcessHeartbeatResponses( new 
WorkHeartbeatResponseProcessor(computationStateCache::get)) .setHealthCheckIntervalMillis( options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) .build(); - windmillServer = GrpcWindmillServer.create(options, windmillStreamFactory, dispatcherClient); - } else { - if (options.getWindmillServiceEndpoint() != null - || options.getLocalWindmillHostport().startsWith("grpc:")) { - windmillStreamFactory = - windmillStreamFactoryBuilder - .setHealthCheckIntervalMillis( - options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) - .build(); - windmillServer = - GrpcWindmillServer.create( - options, - windmillStreamFactory, - GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options))); - } else { - windmillStreamFactory = windmillStreamFactoryBuilder.build(); - windmillServer = new JniWindmillApplianceServer(options.getLocalWindmillHostport()); + return ConfigFetcherComputationStateCacheAndWindmillClient.builder() + .setWindmillDispatcherClient(dispatcherClient) + .setConfigFetcher(configFetcher) + .setComputationStateCache(computationStateCache) + .setWindmillStreamFactory(windmillStreamFactory) + .setWindmillServer( + GrpcWindmillServer.create(options, windmillStreamFactory, dispatcherClient)) + .build(); + } + + // Build with local Windmill client. 
+ if (options.getWindmillServiceEndpoint() != null + || options.getLocalWindmillHostport().startsWith("grpc:")) { + GrpcDispatcherClient dispatcherClient = + GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); + GrpcWindmillStreamFactory windmillStreamFactory = + windmillStreamFactoryBuilder + .setHealthCheckIntervalMillis( + options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) + .build(); + GrpcWindmillServer windmillServer = + GrpcWindmillServer.create(options, windmillStreamFactory, dispatcherClient); + ComputationConfig.Fetcher configFetcher = + createApplianceComputationConfigFetcher(windmillServer); + return ConfigFetcherComputationStateCacheAndWindmillClient.builder() + .setWindmillDispatcherClient(dispatcherClient) + .setWindmillServer(windmillServer) + .setWindmillStreamFactory(windmillStreamFactory) + .setConfigFetcher(configFetcher) + .setComputationStateCache(computationStateCacheFactory.apply(configFetcher)) + .build(); + } + + WindmillServerStub windmillServer = + new JniWindmillApplianceServer(options.getLocalWindmillHostport()); + ComputationConfig.Fetcher configFetcher = + createApplianceComputationConfigFetcher(windmillServer); + return ConfigFetcherComputationStateCacheAndWindmillClient.builder() + .setWindmillStreamFactory(windmillStreamFactoryBuilder.build()) + .setWindmillServer(windmillServer) + .setConfigFetcher(configFetcher) + .setComputationStateCache(computationStateCacheFactory.apply(configFetcher)) + .build(); + } + + private static StreamingApplianceComputationConfigFetcher createApplianceComputationConfigFetcher( + ApplianceWindmillClient windmillClient) { + return new StreamingApplianceComputationConfigFetcher( + windmillClient::getConfig, + new FixedGlobalConfigHandle(StreamingGlobalConfig.builder().build())); + } + + private static boolean isDirectPathPipeline(DataflowWorkerHarnessOptions options) { + if (options.isEnableStreamingEngine() && options.getIsWindmillServiceDirectPathEnabled()) { 
+ boolean isIpV6Enabled = + Optional.ofNullable(options.getDataflowServiceOptions()) + .map(serviceOptions -> serviceOptions.contains(ENABLE_IPV6_EXPERIMENT)) + .orElse(false); + + if (isIpV6Enabled) { + return true; } - configFetcher = - new StreamingApplianceComputationConfigFetcher( - windmillServer::getConfig, - new FixedGlobalConfigHandle(StreamingGlobalConfig.builder().build())); - computationStateCache = computationStateCacheFactory.apply(configFetcher); + LOG.warn( + "DirectPath is currently only supported with IPv6 networking stack. This requires setting " + + "\"enable_private_ipv6_google_access\" in experimental pipeline options. " + + "For information on how to set experimental pipeline options see " + + "https://cloud.google.com/dataflow/docs/guides/setting-pipeline-options#experimental. " + + "Defaulting to CloudPath."); } - return ConfigFetcherComputationStateCacheAndWindmillClient.create( - configFetcher, computationStateCache, windmillServer, windmillStreamFactory); + return false; + } + + private static void validateWorkerOptions(DataflowWorkerHarnessOptions options) { + Preconditions.checkArgument( + options.isStreaming(), + "%s instantiated with options indicating batch use", + StreamingDataflowWorker.class.getName()); + + Preconditions.checkArgument( + !DataflowRunner.hasExperiment(options, BEAM_FN_API_EXPERIMENT), + "%s cannot be main() class with beam_fn_api enabled", + StreamingDataflowWorker.class.getSimpleName()); + } + + private static ChannelCachingStubFactory createFanOutStubFactory( + DataflowWorkerHarnessOptions workerOptions) { + return ChannelCachingRemoteStubFactory.create( + workerOptions.getGcpCredential(), + ChannelCache.create( + serviceAddress -> + // IsolationChannel will create and manage separate RPC channels to the same + // serviceAddress. 
+ IsolationChannel.create( + () -> + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec())))); } @VisibleForTesting @@ -499,7 +664,9 @@ static StreamingDataflowWorker forTesting( Supplier clock, Function executorSupplier, StreamingGlobalConfigHandleImpl globalConfigHandle, - int localRetryTimeoutMs) { + int localRetryTimeoutMs, + StreamingCounters streamingCounters, + WindmillStubFactoryFactory stubFactory) { ConcurrentMap stageInfo = new ConcurrentHashMap<>(); BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); WindmillStateCache stateCache = @@ -542,7 +709,6 @@ static StreamingDataflowWorker forTesting( stateNameMap, stateCache.forComputation(mapTask.getStageName()))); MemoryMonitor memoryMonitor = MemoryMonitor.fromOptions(options); - StreamingCounters streamingCounters = StreamingCounters.create(); FailureTracker failureTracker = options.isEnableStreamingEngine() ? StreamingEngineFailureTracker.create( @@ -558,19 +724,23 @@ static StreamingDataflowWorker forTesting( () -> Optional.ofNullable(memoryMonitor.tryToDumpHeap()), clock, localRetryTimeoutMs); - StreamingWorkerStatusReporter workerStatusReporter = - StreamingWorkerStatusReporter.forTesting( - publishCounters, - workUnitClient, - windmillServer::getAndResetThrottleTime, - stageInfo::values, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - executorSupplier, - options.getWindmillHarnessUpdateReportingPeriod().getMillis(), - options.getPerWorkerMetricsUpdateReportingPeriodMillis()); + StreamingWorkerStatusReporterFactory workerStatusReporterFactory = + throttleTimeSupplier -> + StreamingWorkerStatusReporter.builder() + .setPublishCounters(publishCounters) + .setDataflowServiceClient(workUnitClient) + .setWindmillQuotaThrottleTime(throttleTimeSupplier) + .setAllStageInfo(stageInfo::values) + .setFailureTracker(failureTracker) + .setStreamingCounters(streamingCounters) + .setMemoryMonitor(memoryMonitor) + 
.setWorkExecutor(workExecutor) + .setExecutorFactory(executorSupplier) + .setWindmillHarnessUpdateReportingPeriodMillis( + options.getWindmillHarnessUpdateReportingPeriod().getMillis()) + .setPerWorkerMetricsUpdateReportingPeriodMillis( + options.getPerWorkerMetricsUpdateReportingPeriodMillis()) + .build(); GrpcWindmillStreamFactory.Builder windmillStreamFactory = createGrpcwindmillStreamFactoryBuilder(options, 1) @@ -588,7 +758,7 @@ static StreamingDataflowWorker forTesting( options, hotKeyLogger, clock, - workerStatusReporter, + workerStatusReporterFactory, failureTracker, workFailureProcessor, streamingCounters, @@ -599,8 +769,9 @@ static StreamingDataflowWorker forTesting( options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) .build() : windmillStreamFactory.build(), - executorSupplier, - stageInfo); + executorSupplier.apply("RefreshWork"), + stageInfo, + GrpcDispatcherClient.create(options, stubFactory)); } private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactoryBuilder( @@ -609,13 +780,7 @@ private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactory !options.isEnableStreamingEngine() && options.getLocalWindmillHostport() != null ? 
GrpcWindmillServer.LOCALHOST_MAX_BACKOFF : Duration.millis(options.getWindmillServiceStreamMaxBackoffMillis()); - return GrpcWindmillStreamFactory.of( - JobHeader.newBuilder() - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .setClientId(clientId) - .build()) + return GrpcWindmillStreamFactory.of(createJobHeader(options, clientId)) .setWindmillMessagesBetweenIsReadyChecks(options.getWindmillMessagesBetweenIsReadyChecks()) .setMaxBackOffSupplier(() -> maxBackoff) .setLogEveryNStreamFailures(options.getWindmillServiceStreamingLogEveryNStreamFailures()) @@ -626,6 +791,15 @@ private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactory options, "streaming_engine_disable_new_heartbeat_requests")); } + private static JobHeader createJobHeader(DataflowWorkerHarnessOptions options, long clientId) { + return JobHeader.newBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .setClientId(clientId) + .build(); + } + private static BoundedQueueExecutor createWorkUnitExecutor(DataflowWorkerHarnessOptions options) { return new BoundedQueueExecutor( chooseMaxThreads(options), @@ -644,15 +818,7 @@ public static void main(String[] args) throws Exception { DataflowWorkerHarnessHelper.initializeGlobalStateAndPipelineOptions( StreamingDataflowWorker.class, DataflowWorkerHarnessOptions.class); DataflowWorkerHarnessHelper.configureLogging(options); - checkArgument( - options.isStreaming(), - "%s instantiated with options indicating batch use", - StreamingDataflowWorker.class.getName()); - - checkArgument( - !DataflowRunner.hasExperiment(options, "beam_fn_api"), - "%s cannot be main() class with beam_fn_api enabled", - StreamingDataflowWorker.class.getSimpleName()); + validateWorkerOptions(options); CoderTranslation.verifyModelCodersRegistered(); @@ -705,21 +871,6 @@ void reportPeriodicWorkerUpdatesForTest() { 
workerStatusReporter.reportPeriodicWorkerUpdates(); } - private int chooseMaximumNumberOfThreads() { - if (options.getNumberOfWorkerHarnessThreads() != 0) { - return options.getNumberOfWorkerHarnessThreads(); - } - return MAX_PROCESSING_THREADS; - } - - private int chooseMaximumBundlesOutstanding() { - int maxBundles = options.getMaxBundlesFromWindmillOutstanding(); - if (maxBundles > 0) { - return maxBundles; - } - return chooseMaximumNumberOfThreads() + 100; - } - @VisibleForTesting public boolean workExecutorIsEmpty() { return workUnitExecutor.executorQueueIsEmpty(); @@ -727,7 +878,7 @@ public boolean workExecutorIsEmpty() { @VisibleForTesting int numCommitThreads() { - return workCommitter.parallelism(); + return numCommitThreads; } @VisibleForTesting @@ -740,7 +891,6 @@ ComputationStateCache getComputationStateCache() { return computationStateCache; } - @SuppressWarnings("FutureReturnValueIgnored") public void start() { running.set(true); configFetcher.start(); @@ -791,27 +941,17 @@ private void onCompleteCommit(CompleteCommit completeCommit) { completeCommit.shardedKey(), completeCommit.workId())); } - @VisibleForTesting - public Iterable buildCounters() { - return Iterables.concat( - streamingCounters - .pendingDeltaCounters() - .extractModifiedDeltaUpdates(DataflowCounterUpdateExtractor.INSTANCE), - streamingCounters - .pendingCumulativeCounters() - .extractUpdates(false, DataflowCounterUpdateExtractor.INSTANCE)); + @FunctionalInterface + private interface StreamingWorkerStatusReporterFactory { + StreamingWorkerStatusReporter createStatusReporter(ThrottledTimeTracker throttledTimeTracker); } @AutoValue abstract static class ConfigFetcherComputationStateCacheAndWindmillClient { - private static ConfigFetcherComputationStateCacheAndWindmillClient create( - ComputationConfig.Fetcher configFetcher, - ComputationStateCache computationStateCache, - WindmillServerStub windmillServer, - GrpcWindmillStreamFactory windmillStreamFactory) { - return new 
AutoValue_StreamingDataflowWorker_ConfigFetcherComputationStateCacheAndWindmillClient( - configFetcher, computationStateCache, windmillServer, windmillStreamFactory); + private static Builder builder() { + return new AutoValue_StreamingDataflowWorker_ConfigFetcherComputationStateCacheAndWindmillClient + .Builder(); } abstract ComputationConfig.Fetcher configFetcher(); @@ -821,6 +961,23 @@ private static ConfigFetcherComputationStateCacheAndWindmillClient create( abstract WindmillServerStub windmillServer(); abstract GrpcWindmillStreamFactory windmillStreamFactory(); + + abstract @Nullable GrpcDispatcherClient windmillDispatcherClient(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setConfigFetcher(ComputationConfig.Fetcher value); + + abstract Builder setComputationStateCache(ComputationStateCache value); + + abstract Builder setWindmillServer(WindmillServerStub value); + + abstract Builder setWindmillStreamFactory(GrpcWindmillStreamFactory value); + + abstract Builder setWindmillDispatcherClient(GrpcDispatcherClient value); + + abstract ConfigFetcherComputationStateCacheAndWindmillClient build(); + } } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index ec5122a8732a..a12a5075c5ee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -26,8 +26,12 @@ public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } - public WorkItemCancelledException(Throwable e) { - super(e); + public WorkItemCancelledException(String message, Throwable cause) { 
+ super(message, cause); + } + + public WorkItemCancelledException(Throwable cause) { + super(cause); } /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. */ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java index 5a8e87d23ab9..b4ec170099d5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java @@ -28,16 +28,16 @@ * This uncaught exception handler logs the {@link Throwable} to the logger, {@link System#err} and * exits the application with status code 1. */ -class WorkerUncaughtExceptionHandler implements UncaughtExceptionHandler { +public final class WorkerUncaughtExceptionHandler implements UncaughtExceptionHandler { + @VisibleForTesting public static final int JVM_TERMINATED_STATUS_CODE = 1; private final JvmRuntime runtime; private final Logger logger; - WorkerUncaughtExceptionHandler(Logger logger) { + public WorkerUncaughtExceptionHandler(Logger logger) { this(JvmRuntime.INSTANCE, logger); } - @VisibleForTesting - WorkerUncaughtExceptionHandler(JvmRuntime runtime, Logger logger) { + public WorkerUncaughtExceptionHandler(JvmRuntime runtime, Logger logger) { this.runtime = runtime; this.logger = logger; } @@ -59,7 +59,7 @@ public void uncaughtException(Thread thread, Throwable e) { t.printStackTrace(originalStdErr); } } finally { - runtime.halt(1); + runtime.halt(JVM_TERMINATED_STATUS_CODE); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index c80c3a882e52..aec52cd7d9a6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -240,18 +240,20 @@ synchronized Optional completeWorkAndGetNextWorkForKey( @Nullable Queue workQueue = activeWork.get(shardedKey); if (workQueue == null) { // Work may have been completed due to clearing of stuck commits. - LOG.warn("Unable to complete inactive work for key {} and token {}.", shardedKey, workId); + LOG.warn( + "Unable to complete inactive work for key={} and token={}. Work queue for key does not exist.", + shardedKey, + workId); return Optional.empty(); } + removeCompletedWorkFromQueue(workQueue, shardedKey, workId); return getNextWork(workQueue, shardedKey); } private synchronized void removeCompletedWorkFromQueue( Queue workQueue, ShardedKey shardedKey, WorkId workId) { - // avoid Preconditions.checkState here to prevent eagerly evaluating the - // format string parameters for the error message. - ExecutableWork completedWork = workQueue.peek(); + @Nullable ExecutableWork completedWork = workQueue.peek(); if (completedWork == null) { // Work may have been completed due to clearing of stuck commits. LOG.warn("Active key {} without work, expected token {}", shardedKey, workId); @@ -337,8 +339,18 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) { writer.println( ""); + // Columns. writer.println( - ""); + "" + + "" + + "" + + "" + + "" + + "" + + "" + + "" + + "" + + ""); // Use StringBuilder because we are appending in loop. 
StringBuilder activeWorkStatus = new StringBuilder(); int commitsPendingCount = 0; @@ -364,6 +376,10 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) { activeWorkStatus.append(activeWork.getState()); activeWorkStatus.append("\n"); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java index 199ad26aed00..4b4acb73f4a7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java @@ -147,8 +147,10 @@ public Optional get(String computationId) { | ComputationStateNotFoundException e) { if (e.getCause() instanceof ComputationStateNotFoundException || e instanceof ComputationStateNotFoundException) { - LOG.error( - "Trying to fetch unknown computation={}, known computations are {}.", + LOG.warn( + "Computation {} is currently unknown, " + + "known computations are {}. 
" + + "This is transient and will get retried.", computationId, ImmutableSet.copyOf(computationCache.asMap().keySet())); } else { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java index a18ca8cfd6dc..d9fe95f3421b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming; +import static org.apache.beam.runners.dataflow.worker.MetricsToPerStepNamespaceMetricsConverter.KAFKA_SINK_METRICS_NAMESPACE; import static org.apache.beam.sdk.metrics.Metrics.THROTTLE_TIME_COUNTER_NAME; import com.google.api.services.dataflow.model.CounterStructuredName; @@ -118,7 +119,8 @@ public List extractPerWorkerMetricValues() { private void translateKnownPerWorkerCounters(List metrics) { for (PerStepNamespaceMetrics perStepnamespaceMetrics : metrics) { if (!BigQuerySinkMetrics.METRICS_NAMESPACE.equals( - perStepnamespaceMetrics.getMetricsNamespace())) { + perStepnamespaceMetrics.getMetricsNamespace()) + && !KAFKA_SINK_METRICS_NAMESPACE.equals(perStepnamespaceMetrics.getMetricsNamespace())) { continue; } for (MetricValue metric : perStepnamespaceMetrics.getMetricValues()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java index f2893f3e7191..5f039be7b00f 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java @@ -18,33 +18,24 @@ package org.apache.beam.runners.dataflow.worker.streaming; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.function.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; -/** Bounded set of queues, with a maximum total weight. */ +/** Queue bounded by a {@link WeightedSemaphore}. */ public final class WeightedBoundedQueue { private final LinkedBlockingQueue queue; - private final int maxWeight; - private final Semaphore limit; - private final Function weigher; + private final WeightedSemaphore weightedSemaphore; private WeightedBoundedQueue( - LinkedBlockingQueue linkedBlockingQueue, - int maxWeight, - Semaphore limit, - Function weigher) { + LinkedBlockingQueue linkedBlockingQueue, WeightedSemaphore weightedSemaphore) { this.queue = linkedBlockingQueue; - this.maxWeight = maxWeight; - this.limit = limit; - this.weigher = weigher; + this.weightedSemaphore = weightedSemaphore; } - public static WeightedBoundedQueue create(int maxWeight, Function weigherFn) { - return new WeightedBoundedQueue<>( - new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn); + public static WeightedBoundedQueue create(WeightedSemaphore weightedSemaphore) { + return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weightedSemaphore); } /** @@ -52,15 +43,15 @@ public static WeightedBoundedQueue create(int maxWeight, Function { + private final int maxWeight; + private final Semaphore limit; + private final Function weigher; + + private WeightedSemaphore(int 
maxWeight, Semaphore limit, Function weigher) { + this.maxWeight = maxWeight; + this.limit = limit; + this.weigher = weigher; + } + + public static WeightedSemaphore create(int maxWeight, Function weigherFn) { + return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn); + } + + public void acquireUninterruptibly(V value) { + limit.acquireUninterruptibly(computePermits(value)); + } + + public void release(V value) { + limit.release(computePermits(value)); + } + + private int computePermits(V value) { + return Math.min(weigher.apply(value), maxWeight); + } + + public int currentWeight() { + return maxWeight - limit.availablePermits(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index e77823602eda..6f97cbca9a80 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -56,7 +56,7 @@ /** * Represents the state of an attempt to process a {@link WorkItem} by executing user code. * - * @implNote Not thread safe, should not be executed or accessed by more than 1 thread at a time. + * @implNote Not thread safe, should not be modified by more than 1 thread at a time. 
*/ @NotThreadSafe @Internal @@ -70,8 +70,9 @@ public final class Work implements RefreshableWork { private final Map totalDurationPerState; private final WorkId id; private final String latencyTrackingId; - private TimedState currentState; + private volatile TimedState currentState; private volatile boolean isFailed; + private volatile String processingThreadName = ""; private Work( WorkItem workItem, @@ -110,7 +111,18 @@ public static ProcessingContext createProcessingContext( GetDataClient getDataClient, Consumer workCommitter, HeartbeatSender heartbeatSender) { - return ProcessingContext.create(computationId, getDataClient, workCommitter, heartbeatSender); + return ProcessingContext.create( + computationId, getDataClient, workCommitter, heartbeatSender, /* backendWorkerToken= */ ""); + } + + public static ProcessingContext createProcessingContext( + String computationId, + GetDataClient getDataClient, + Consumer workCommitter, + HeartbeatSender heartbeatSender, + String backendWorkerToken) { + return ProcessingContext.create( + computationId, getDataClient, workCommitter, heartbeatSender, backendWorkerToken); } private static LatencyAttribution.Builder createLatencyAttributionWithActiveLatencyBreakdown( @@ -167,6 +179,10 @@ public GlobalData fetchSideInput(GlobalDataRequest request) { return processingContext.getDataClient().getSideInputData(request); } + public String backendWorkerToken() { + return processingContext.backendWorkerToken(); + } + public Watermarks watermarks() { return watermarks; } @@ -188,6 +204,14 @@ public void setState(State state) { this.currentState = TimedState.create(state, now); } + public String getProcessingThreadName() { + return processingThreadName; + } + + public void setProcessingThreadName(String processingThreadName) { + this.processingThreadName = processingThreadName; + } + @Override public void setFailed() { this.isFailed = true; @@ -342,9 +366,10 @@ private static ProcessingContext create( String computationId, 
GetDataClient getDataClient, Consumer workCommitter, - HeartbeatSender heartbeatSender) { + HeartbeatSender heartbeatSender, + String backendWorkerToken) { return new AutoValue_Work_ProcessingContext( - computationId, getDataClient, heartbeatSender, workCommitter); + computationId, getDataClient, heartbeatSender, workCommitter, backendWorkerToken); } /** Computation that the {@link Work} belongs to. */ @@ -361,6 +386,8 @@ private static ProcessingContext create( */ public abstract Consumer workCommitter(); + public abstract String backendWorkerToken(); + private Optional fetchKeyedState(KeyedGetDataRequest request) { return Optional.ofNullable(getDataClient().getStateData(computationId(), request)); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 3556b7ce2919..21aaa23d3f85 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,20 +20,26 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; -import java.util.Collection; -import java.util.List; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.util.HashSet; import java.util.Map.Entry; +import java.util.NoSuchElementException; import java.util.Optional; -import java.util.Queue; -import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; 
+import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -54,18 +60,16 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; -import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.MoreFutures; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,32 +84,39 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorkerHarness { private static final Logger LOG = LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class); - private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread"; - private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread"; + private static final String WORKER_METADATA_CONSUMER_THREAD_NAME = + "WindmillWorkerMetadataConsumerThread"; + private static final String STREAM_MANAGER_THREAD_NAME = "WindmillStreamManager-%d"; private final JobHeader jobHeader; private final GrpcWindmillStreamFactory streamFactory; private final WorkItemScheduler workItemScheduler; private final ChannelCachingStubFactory channelCachingStubFactory; private final GrpcDispatcherClient dispatcherClient; - private final AtomicBoolean isBudgetRefreshPaused; - private final GetWorkBudgetRefresher getWorkBudgetRefresher; - private final AtomicReference lastBudgetRefresh; + private final GetWorkBudgetDistributor getWorkBudgetDistributor; + private final GetWorkBudget totalGetWorkBudget; private final ThrottleTimer getWorkerMetadataThrottleTimer; - private final ExecutorService newWorkerMetadataPublisher; - private final ExecutorService newWorkerMetadataConsumer; - private final long clientId; - private final Supplier getWorkerMetadataStream; - private final Queue newWindmillEndpoints; private 
final Function workCommitterFactory; private final ThrottlingGetDataMetricTracker getDataMetricTracker; + private final ExecutorService windmillStreamManager; + private final ExecutorService workerMetadataConsumer; + private final Object metadataLock = new Object(); /** Writes are guarded by synchronization, reads are lock free. */ - private final AtomicReference connections; + private final AtomicReference backends; - private volatile boolean started; + @GuardedBy("this") + private long activeMetadataVersion; + + @GuardedBy("metadataLock") + private long pendingMetadataVersion; + + @GuardedBy("this") + private boolean started; + + @GuardedBy("this") + private @Nullable GetWorkerMetadataStream getWorkerMetadataStream = null; - @SuppressWarnings("FutureReturnValueIgnored") private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, @@ -114,52 +125,28 @@ private FanOutStreamingEngineWorkerHarness( ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, - ThrottlingGetDataMetricTracker getDataMetricTracker) { + ThrottlingGetDataMetricTracker getDataMetricTracker, + ExecutorService workerMetadataConsumer) { this.jobHeader = jobHeader; this.getDataMetricTracker = getDataMetricTracker; this.started = false; this.streamFactory = streamFactory; this.workItemScheduler = workItemScheduler; - this.connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); + this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY); this.channelCachingStubFactory = channelCachingStubFactory; this.dispatcherClient = dispatcherClient; - this.isBudgetRefreshPaused = new AtomicBoolean(false); this.getWorkerMetadataThrottleTimer = new ThrottleTimer(); - this.newWorkerMetadataPublisher = - singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD); - this.newWorkerMetadataConsumer = - 
singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD); - this.clientId = clientId; - this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH); - this.newWindmillEndpoints = Queues.synchronizedQueue(EvictingQueue.create(1)); - this.getWorkBudgetRefresher = - new GetWorkBudgetRefresher( - isBudgetRefreshPaused::get, - () -> { - getWorkBudgetDistributor.distributeBudget( - connections.get().windmillStreams().values(), totalGetWorkBudget); - lastBudgetRefresh.set(Instant.now()); - }); - this.getWorkerMetadataStream = - Suppliers.memoize( - () -> - streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStubBlocking(), - getWorkerMetadataThrottleTimer, - endpoints -> - // Run this on a separate thread than the grpc stream thread. - newWorkerMetadataPublisher.submit( - () -> newWindmillEndpoints.add(endpoints)))); + this.windmillStreamManager = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build()); + this.workerMetadataConsumer = workerMetadataConsumer; + this.getWorkBudgetDistributor = getWorkBudgetDistributor; + this.totalGetWorkBudget = totalGetWorkBudget; + this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; } - private static ExecutorService singleThreadedExecutorServiceOf(String threadName) { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); - } - /** * Creates an instance of {@link FanOutStreamingEngineWorkerHarness} in a non-started state. 
* @@ -183,9 +170,12 @@ public static FanOutStreamingEngineWorkerHarness create( channelCachingStubFactory, getWorkBudgetDistributor, dispatcherClient, - /* clientId= */ new Random().nextLong(), workCommitterFactory, - getDataMetricTracker); + getDataMetricTracker, + Executors.newSingleThreadExecutor( + new ThreadFactoryBuilder() + .setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME) + .build())); } @VisibleForTesting @@ -197,7 +187,6 @@ static FanOutStreamingEngineWorkerHarness forTesting( ChannelCachingStubFactory stubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkProvider = @@ -209,201 +198,238 @@ static FanOutStreamingEngineWorkerHarness forTesting( stubFactory, getWorkBudgetDistributor, dispatcherClient, - clientId, workCommitterFactory, - getDataMetricTracker); + getDataMetricTracker, + // Run the workerMetadataConsumer on the direct calling thread to remove waiting and + // make unit tests more deterministic as we do not have to worry about network IO being + // blocked by the consumeWorkerMetadata() task. Test suites run in different + // environments and non-determinism has lead to past flakiness. See + // https://github.com/apache/beam/issues/28957. + MoreExecutors.newDirectExecutorService()); fanOutStreamingEngineWorkProvider.start(); return fanOutStreamingEngineWorkProvider; } - @SuppressWarnings("ReturnValueIgnored") @Override public synchronized void start() { - Preconditions.checkState(!started, "StreamingEngineClient cannot start twice."); - // Starts the stream, this value is memoized. 
- getWorkerMetadataStream.get(); - startWorkerMetadataConsumer(); - getWorkBudgetRefresher.start(); + Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); + getWorkerMetadataStream = + streamFactory.createGetWorkerMetadataStream( + dispatcherClient::getWindmillMetadataServiceStubBlocking, + getWorkerMetadataThrottleTimer, + this::consumeWorkerMetadata); + getWorkerMetadataStream.start(); started = true; } public ImmutableSet currentWindmillEndpoints() { - return connections.get().windmillConnections().keySet().stream() + return backends.get().windmillStreams().keySet().stream() .map(Endpoint::directEndpoint) .filter(Optional::isPresent) .map(Optional::get) - .filter( - windmillServiceAddress -> - windmillServiceAddress.getKind() != WindmillServiceAddress.Kind.IPV6) - .map( - windmillServiceAddress -> - windmillServiceAddress.getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS - ? windmillServiceAddress.gcpServiceAddress() - : windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress()) + .map(WindmillServiceAddress::getServiceAddress) .collect(toImmutableSet()); } /** - * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link - * GetDataStream} pointing to dispatcher. + * Fetches {@link GetDataStream} mapped to globalDataKey if or throws {@link + * NoSuchElementException} if one is not found. 
*/ private GetDataStream getGlobalDataStream(String globalDataKey) { - return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey)) - .map(Supplier::get) - .orElseGet( - () -> - streamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), new ThrottleTimer())); - } - - @SuppressWarnings("FutureReturnValueIgnored") - private void startWorkerMetadataConsumer() { - newWorkerMetadataConsumer.submit( - () -> { - while (true) { - Optional.ofNullable(newWindmillEndpoints.poll()) - .ifPresent(this::consumeWindmillWorkerEndpoints); - } - }); + return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey)) + .map(GlobalDataStreamSender::stream) + .orElseThrow( + () -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey)); } @VisibleForTesting @Override public synchronized void shutdown() { - Preconditions.checkState(started, "StreamingEngineClient never started."); - getWorkerMetadataStream.get().halfClose(); - getWorkBudgetRefresher.stop(); - newWorkerMetadataPublisher.shutdownNow(); - newWorkerMetadataConsumer.shutdownNow(); + Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); + Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); + workerMetadataConsumer.shutdownNow(); + // Close all the streams blocking until this completes to not leak resources. 
+ closeStreamsNotIn(WindmillEndpoints.none()).join(); channelCachingStubFactory.shutdown(); + + try { + Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", e); + } + + windmillStreamManager.shutdown(); + boolean isStreamManagerShutdown = false; + try { + isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", e); + } + if (!isStreamManagerShutdown) { + windmillStreamManager.shutdownNow(); + } + } + + private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) { + synchronized (metadataLock) { + // Only process versions greater than what we currently have to prevent double processing of + // metadata. workerMetadataConsumer is single-threaded so we maintain ordering. + if (windmillEndpoints.version() > pendingMetadataVersion) { + pendingMetadataVersion = windmillEndpoints.version(); + workerMetadataConsumer.execute(() -> consumeWindmillWorkerEndpoints(windmillEndpoints)); + } + } } - /** - * {@link java.util.function.Consumer} used to update {@link #connections} on - * new backend worker metadata. 
- */ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) { - isBudgetRefreshPaused.set(true); - LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints); - ImmutableMap newWindmillConnections = - createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints()); - - StreamingEngineConnectionState newConnectionsState = - StreamingEngineConnectionState.builder() - .setWindmillConnections(newWindmillConnections) - .setWindmillStreams( - closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values())) + // Since this is run on a single threaded executor, multiple versions of the metadata maybe + // queued up while a previous version of the windmillEndpoints were being consumed. Only consume + // the endpoints if they are the most current version. + synchronized (metadataLock) { + if (newWindmillEndpoints.version() < pendingMetadataVersion) { + return; + } + } + + LOG.debug( + "Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}", + newWindmillEndpoints, + activeMetadataVersion, + newWindmillEndpoints.version()); + closeStreamsNotIn(newWindmillEndpoints); + ImmutableMap newStreams = + createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join(); + StreamingEngineBackends newBackends = + StreamingEngineBackends.builder() + .setWindmillStreams(newStreams) .setGlobalDataStreams( createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints())) .build(); + backends.set(newBackends); + getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget); + activeMetadataVersion = newWindmillEndpoints.version(); + } + + /** Close the streams that are no longer valid asynchronously. 
*/ + @CanIgnoreReturnValue + private CompletableFuture closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { + StreamingEngineBackends currentBackends = backends.get(); + Stream> closeStreamFutures = + currentBackends.windmillStreams().entrySet().stream() + .filter( + connectionAndStream -> + !newWindmillEndpoints + .windmillEndpoints() + .contains(connectionAndStream.getKey())) + .map( + entry -> + CompletableFuture.runAsync( + () -> closeStreamSender(entry.getKey(), entry.getValue()), + windmillStreamManager)); + + Set newGlobalDataEndpoints = + new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); + Stream> closeGlobalDataStreamFutures = + currentBackends.globalDataStreams().values().stream() + .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) + .map( + sender -> + CompletableFuture.runAsync( + () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager)); - LOG.info( - "Setting new connections: {}. Previous connections: {}.", - newConnectionsState, - connections.get()); - connections.set(newConnectionsState); - isBudgetRefreshPaused.set(false); - getWorkBudgetRefresher.requestBudgetRefresh(); + return CompletableFuture.allOf( + Streams.concat(closeStreamFutures, closeGlobalDataStreamFutures) + .toArray(CompletableFuture[]::new)); + } + + private void closeStreamSender(Endpoint endpoint, StreamSender sender) { + LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender); + try { + sender.close(); + endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove); + LOG.debug("Successfully closed streams to {}", endpoint); + } catch (Exception e) { + LOG.error("Error closing streams to endpoint={}, sender={}", endpoint, sender); + } + } + + private synchronized CompletableFuture> + createAndStartNewStreams(ImmutableSet newWindmillEndpoints) { + ImmutableMap currentStreams = backends.get().windmillStreams(); + return MoreFutures.allAsList( + newWindmillEndpoints.stream() + .map(endpoint -> 
getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams)) + .collect(Collectors.toList())) + .thenApply( + backends -> backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight))) + .toCompletableFuture(); + } + + private CompletionStage> + getOrCreateWindmillStreamSenderFuture( + Endpoint endpoint, ImmutableMap currentStreams) { + return Optional.ofNullable(currentStreams.get(endpoint)) + .map(backend -> CompletableFuture.completedFuture(Pair.of(endpoint, backend))) + .orElseGet( + () -> + MoreFutures.supplyAsync( + () -> Pair.of(endpoint, createAndStartWindmillStreamSender(endpoint)), + windmillStreamManager) + .toCompletableFuture()); } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. */ - public long getAndResetThrottleTimes() { - return connections.get().windmillStreams().values().stream() + @Override + public long getAndResetThrottleTime() { + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getAndResetThrottleTime) .reduce(0L, Long::sum) + getWorkerMetadataThrottleTimer.getAndResetThrottleTime(); } public long currentActiveCommitBytes() { - return connections.get().windmillStreams().values().stream() + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getCurrentActiveCommitBytes) .reduce(0L, Long::sum); } @VisibleForTesting - StreamingEngineConnectionState getCurrentConnections() { - return connections.get(); - } - - private synchronized ImmutableMap createNewWindmillConnections( - List newWindmillEndpoints) { - ImmutableMap currentConnections = - connections.get().windmillConnections(); - return newWindmillEndpoints.stream() - .collect( - toImmutableMap( - Function.identity(), - endpoint -> - // Reuse existing stubs if they exist. Optional.orElseGet only calls the - // supplier if the value is not present, preventing constructing expensive - // objects. 
- Optional.ofNullable(currentConnections.get(endpoint)) - .orElseGet( - () -> WindmillConnection.from(endpoint, this::createWindmillStub)))); - } - - private synchronized ImmutableMap - closeStaleStreamsAndCreateNewStreams(Collection newWindmillConnections) { - ImmutableMap currentStreams = - connections.get().windmillStreams(); - - // Close the streams that are no longer valid. - currentStreams.entrySet().stream() - .filter( - connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey())) - .forEach( - entry -> { - entry.getValue().closeAllStreams(); - entry.getKey().directEndpoint().ifPresent(channelCachingStubFactory::remove); - }); - - return newWindmillConnections.stream() - .collect( - toImmutableMap( - Function.identity(), - newConnection -> - Optional.ofNullable(currentStreams.get(newConnection)) - .orElseGet(() -> createAndStartWindmillStreamSenderFor(newConnection)))); + StreamingEngineBackends currentBackends() { + return backends.get(); } - private ImmutableMap> createNewGlobalDataStreams( + private ImmutableMap createNewGlobalDataStreams( ImmutableMap newGlobalDataEndpoints) { - ImmutableMap> currentGlobalDataStreams = - connections.get().globalDataStreams(); + ImmutableMap currentGlobalDataStreams = + backends.get().globalDataStreams(); return newGlobalDataEndpoints.entrySet().stream() .collect( toImmutableMap( Entry::getKey, keyedEndpoint -> - existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams))); + getOrCreateGlobalDataSteam(keyedEndpoint, currentGlobalDataStreams))); } - private Supplier existingOrNewGetDataStreamFor( + private GlobalDataStreamSender getOrCreateGlobalDataSteam( Entry keyedEndpoint, - ImmutableMap> currentGlobalDataStreams) { - return Preconditions.checkNotNull( - currentGlobalDataStreams.getOrDefault( - keyedEndpoint.getKey(), + ImmutableMap currentGlobalDataStreams) { + return Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey())) + .orElseGet( () -> - 
streamFactory.createGetDataStream( - newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer()))); - } - - private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) { - return Optional.ofNullable(connections.get().windmillConnections().get(endpoint)) - .map(WindmillConnection::stub) - .orElseGet(() -> createWindmillStub(endpoint)); + new GlobalDataStreamSender( + streamFactory.createGetDataStream( + createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), + keyedEndpoint.getValue())); } - private WindmillStreamSender createAndStartWindmillStreamSenderFor( - WindmillConnection connection) { - // Initially create each stream with no budget. The budget will be eventually assigned by the - // GetWorkBudgetDistributor. + private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( - connection, + WindmillConnection.from(endpoint, this::createWindmillStub), GetWorkRequest.newBuilder() - .setClientId(clientId) + .setClientId(jobHeader.getClientId()) .setJobId(jobHeader.getJobId()) .setProjectId(jobHeader.getProjectId()) .setWorkerId(jobHeader.getWorkerId()) @@ -415,7 +441,7 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor( StreamGetDataClient.create( getDataStream, this::getGlobalDataStream, getDataMetricTracker), workCommitterFactory); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); return windmillStreamSender; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java new file mode 100644 index 000000000000..d590e69c17d0 --- /dev/null +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; + +@Internal +@ThreadSafe +final class GlobalDataStreamSender implements StreamSender { + private final Endpoint endpoint; + private final GetDataStream delegate; + private volatile boolean started; + + GlobalDataStreamSender(GetDataStream delegate, Endpoint endpoint) { + this.delegate = delegate; + this.started = false; + this.endpoint = endpoint; + } + + GetDataStream stream() { + if (!started) { + // Starting the stream possibly perform IO. Start the stream lazily since not all pipeline + // implementations need to fetch global/side input data. 
+ startStream(); + } + + return delegate; + } + + private synchronized void startStream() { + // Check started again after we acquire the lock. + if (!started) { + delegate.start(); + started = true; + } + } + + @Override + public void close() { + delegate.shutdown(); + } + + Endpoint endpoint() { + return endpoint; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java index bc93e6d89c41..65203288e169 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java @@ -33,10 +33,11 @@ import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.RpcException; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillRpcException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -66,6 
+67,7 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { private final Function> computationStateFetcher; private final ExecutorService workProviderExecutor; private final GetWorkSender getWorkSender; + private final ThrottledTimeTracker throttledTimeTracker; SingleSourceWorkerHarness( WorkCommitter workCommitter, @@ -74,7 +76,8 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { StreamingWorkScheduler streamingWorkScheduler, Runnable waitForResources, Function> computationStateFetcher, - GetWorkSender getWorkSender) { + GetWorkSender getWorkSender, + ThrottledTimeTracker throttledTimeTracker) { this.workCommitter = workCommitter; this.getDataClient = getDataClient; this.heartbeatSender = heartbeatSender; @@ -82,7 +85,7 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { this.waitForResources = waitForResources; this.computationStateFetcher = computationStateFetcher; this.workProviderExecutor = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MIN_PRIORITY) @@ -90,6 +93,7 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { .build()); this.isRunning = new AtomicBoolean(false); this.getWorkSender = getWorkSender; + this.throttledTimeTracker = throttledTimeTracker; } public static SingleSourceWorkerHarness.Builder builder() { @@ -144,6 +148,11 @@ public void shutdown() { workCommitter.stop(); } + @Override + public long getAndResetThrottleTime() { + return throttledTimeTracker.getAndResetThrottleTime(); + } + private void streamingEngineDispatchLoop( Function getWorkStreamFactory) { while (isRunning.get()) { @@ -199,7 +208,7 @@ private void applianceDispatchLoop(Supplier getWorkFn) if (workResponse.getWorkCount() > 0) { break; } - } catch (RpcException e) { + } catch (WindmillRpcException e) { LOG.warn("GetWork failed, retrying:", e); } 
sleepUninterruptibly(backoff, TimeUnit.MILLISECONDS); @@ -254,6 +263,8 @@ Builder setComputationStateFetcher( Builder setGetWorkSender(GetWorkSender getWorkSender); + Builder setThrottledTimeTracker(ThrottledTimeTracker throttledTimeTracker); + SingleSourceWorkerHarness build(); } diff --git a/sdks/java/io/kafka/kafka-111/build.gradle b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java similarity index 72% rename from sdks/java/io/kafka/kafka-111/build.gradle rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java index c2b0c8f82827..40a63571620f 100644 --- a/sdks/java/io/kafka/kafka-111/build.gradle +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java @@ -4,21 +4,19 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ -project.ext { - delimited="1.1.1" - undelimited="111" - sdfCompatible=false -} +package org.apache.beam.runners.dataflow.worker.streaming.harness; -apply from: "../kafka-integration-test.gradle" \ No newline at end of file +interface StreamSender { + void close(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java similarity index 55% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java index 3c85ee6abe1f..14290b486830 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java @@ -18,47 +18,37 @@ package org.apache.beam.runners.dataflow.worker.streaming.harness; import com.google.auto.value.AutoValue; -import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; /** - * Represents the current state of connections to Streaming Engine. 
Connections are updated when - * backend workers assigned to the key ranges being processed by this user worker change during + * Represents the current state of connections to the Streaming Engine backend. Backends are updated + * when backend workers assigned to the key ranges being processed by this user worker change during * pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other * backend updates. */ @AutoValue -abstract class StreamingEngineConnectionState { - static final StreamingEngineConnectionState EMPTY = builder().build(); +abstract class StreamingEngineBackends { + static final StreamingEngineBackends EMPTY = builder().build(); static Builder builder() { - return new AutoValue_StreamingEngineConnectionState.Builder() - .setWindmillConnections(ImmutableMap.of()) + return new AutoValue_StreamingEngineBackends.Builder() .setWindmillStreams(ImmutableMap.of()) .setGlobalDataStreams(ImmutableMap.of()); } - abstract ImmutableMap windmillConnections(); - - abstract ImmutableMap windmillStreams(); + abstract ImmutableMap windmillStreams(); /** Mapping of GlobalDataIds and the direct GetDataStreams used fetch them. 
*/ - abstract ImmutableMap> globalDataStreams(); + abstract ImmutableMap globalDataStreams(); @AutoValue.Builder abstract static class Builder { - public abstract Builder setWindmillConnections( - ImmutableMap value); - - public abstract Builder setWindmillStreams( - ImmutableMap value); + public abstract Builder setWindmillStreams(ImmutableMap value); public abstract Builder setGlobalDataStreams( - ImmutableMap> value); + ImmutableMap value); - public abstract StreamingEngineConnectionState build(); + public abstract StreamingEngineBackends build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java index c1b4570e2260..731a5a4b1b51 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java @@ -17,11 +17,12 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import org.apache.beam.sdk.annotations.Internal; /** Provides an interface to start streaming worker processing. 
*/ @Internal -public interface StreamingWorkerHarness { +public interface StreamingWorkerHarness extends ThrottledTimeTracker { void start(); void shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java index 6981312eff1d..ddfc6809231a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java @@ -258,7 +258,7 @@ public interface Builder { Builder setDebugCapture(DebugCapture.Manager debugCapture); - Builder setChannelzServlet(ChannelzServlet channelzServlet); + Builder setChannelzServlet(@Nullable ChannelzServlet channelzServlet); Builder setStateCache(WindmillStateCache stateCache); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java index ba77d8e1ce26..3557f0d193c5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java @@ -27,6 +27,7 @@ import com.google.api.services.dataflow.model.WorkItemStatus; import com.google.api.services.dataflow.model.WorkerMessage; import com.google.api.services.dataflow.model.WorkerMessageResponse; +import 
com.google.auto.value.AutoBuilder; import java.io.IOException; import java.math.RoundingMode; import java.util.ArrayList; @@ -51,6 +52,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -78,7 +80,7 @@ public final class StreamingWorkerStatusReporter { private final int initialMaxThreadCount; private final int initialMaxBundlesOutstanding; private final WorkUnitClient dataflowServiceClient; - private final Supplier windmillQuotaThrottleTime; + private final ThrottledTimeTracker windmillQuotaThrottleTime; private final Supplier> allStageInfo; private final FailureTracker failureTracker; private final StreamingCounters streamingCounters; @@ -97,10 +99,10 @@ public final class StreamingWorkerStatusReporter { // Used to track the number of WorkerMessages that have been sent without PerWorkerMetrics. 
private final AtomicLong workerMessagesIndex; - private StreamingWorkerStatusReporter( + StreamingWorkerStatusReporter( boolean publishCounters, WorkUnitClient dataflowServiceClient, - Supplier windmillQuotaThrottleTime, + ThrottledTimeTracker windmillQuotaThrottleTime, Supplier> allStageInfo, FailureTracker failureTracker, StreamingCounters streamingCounters, @@ -131,57 +133,13 @@ private StreamingWorkerStatusReporter( this.workerMessagesIndex = new AtomicLong(); } - public static StreamingWorkerStatusReporter create( - WorkUnitClient workUnitClient, - Supplier windmillQuotaThrottleTime, - Supplier> allStageInfo, - FailureTracker failureTracker, - StreamingCounters streamingCounters, - MemoryMonitor memoryMonitor, - BoundedQueueExecutor workExecutor, - long windmillHarnessUpdateReportingPeriodMillis, - long perWorkerMetricsUpdateReportingPeriodMillis) { - return new StreamingWorkerStatusReporter( - /* publishCounters= */ true, - workUnitClient, - windmillQuotaThrottleTime, - allStageInfo, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - threadName -> - Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()), - windmillHarnessUpdateReportingPeriodMillis, - perWorkerMetricsUpdateReportingPeriodMillis); - } - - @VisibleForTesting - public static StreamingWorkerStatusReporter forTesting( - boolean publishCounters, - WorkUnitClient workUnitClient, - Supplier windmillQuotaThrottleTime, - Supplier> allStageInfo, - FailureTracker failureTracker, - StreamingCounters streamingCounters, - MemoryMonitor memoryMonitor, - BoundedQueueExecutor workExecutor, - Function executorFactory, - long windmillHarnessUpdateReportingPeriodMillis, - long perWorkerMetricsUpdateReportingPeriodMillis) { - return new StreamingWorkerStatusReporter( - publishCounters, - workUnitClient, - windmillQuotaThrottleTime, - allStageInfo, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - executorFactory, - 
windmillHarnessUpdateReportingPeriodMillis, - perWorkerMetricsUpdateReportingPeriodMillis); + public static Builder builder() { + return new AutoBuilder_StreamingWorkerStatusReporter_Builder() + .setPublishCounters(true) + .setExecutorFactory( + threadName -> + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat(threadName).build())); } /** @@ -228,6 +186,22 @@ private static void shutdownExecutor(ScheduledExecutorService executor) { } } + // Calculates the PerWorkerMetrics reporting frequency, ensuring alignment with the + // WorkerMessages RPC schedule. The desired reporting period + // (perWorkerMetricsUpdateReportingPeriodMillis) is adjusted to the nearest multiple + // of the RPC interval (windmillHarnessUpdateReportingPeriodMillis). + private static long getPerWorkerMetricsUpdateFrequency( + long windmillHarnessUpdateReportingPeriodMillis, + long perWorkerMetricsUpdateReportingPeriodMillis) { + if (windmillHarnessUpdateReportingPeriodMillis == 0) { + return 0; + } + return LongMath.divide( + perWorkerMetricsUpdateReportingPeriodMillis, + windmillHarnessUpdateReportingPeriodMillis, + RoundingMode.CEILING); + } + @SuppressWarnings("FutureReturnValueIgnored") public void start() { reportHarnessStartup(); @@ -276,27 +250,13 @@ private void reportHarnessStartup() { } } - // Calculates the PerWorkerMetrics reporting frequency, ensuring alignment with the - // WorkerMessages RPC schedule. The desired reporting period - // (perWorkerMetricsUpdateReportingPeriodMillis) is adjusted to the nearest multiple - // of the RPC interval (windmillHarnessUpdateReportingPeriodMillis). 
- private static long getPerWorkerMetricsUpdateFrequency( - long windmillHarnessUpdateReportingPeriodMillis, - long perWorkerMetricsUpdateReportingPeriodMillis) { - if (windmillHarnessUpdateReportingPeriodMillis == 0) { - return 0; - } - return LongMath.divide( - perWorkerMetricsUpdateReportingPeriodMillis, - windmillHarnessUpdateReportingPeriodMillis, - RoundingMode.CEILING); - } - /** Sends counter updates to Dataflow backend. */ private void sendWorkerUpdatesToDataflowService( CounterSet deltaCounters, CounterSet cumulativeCounters) throws IOException { // Throttle time is tracked by the windmillServer but is reported to DFE here. - streamingCounters.windmillQuotaThrottling().addValue(windmillQuotaThrottleTime.get()); + streamingCounters + .windmillQuotaThrottling() + .addValue(windmillQuotaThrottleTime.getAndResetThrottleTime()); if (memoryMonitor.isThrashing()) { streamingCounters.memoryThrashing().addValue(1); } @@ -496,4 +456,33 @@ private void updateThreadMetrics() { .maxOutstandingBundles() .addValue((long) workExecutor.maximumElementsOutstanding()); } + + @AutoBuilder + public interface Builder { + Builder setPublishCounters(boolean publishCounters); + + Builder setDataflowServiceClient(WorkUnitClient dataflowServiceClient); + + Builder setWindmillQuotaThrottleTime(ThrottledTimeTracker windmillQuotaThrottledTimeTracker); + + Builder setAllStageInfo(Supplier> allStageInfo); + + Builder setFailureTracker(FailureTracker failureTracker); + + Builder setStreamingCounters(StreamingCounters streamingCounters); + + Builder setMemoryMonitor(MemoryMonitor memoryMonitor); + + Builder setWorkExecutor(BoundedQueueExecutor workExecutor); + + Builder setExecutorFactory(Function executorFactory); + + Builder setWindmillHarnessUpdateReportingPeriodMillis( + long windmillHarnessUpdateReportingPeriodMillis); + + Builder setPerWorkerMetricsUpdateReportingPeriodMillis( + long perWorkerMetricsUpdateReportingPeriodMillis); + + StreamingWorkerStatusReporter build(); + } } diff 
--git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 45aa403ee71b..2a2f49dff846 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,10 +17,14 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; @@ -36,20 +40,13 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetSpender; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.FixedStreamHeartbeatSender; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; /** * Owns and maintains a set of streams used to communicate with a specific Windmill worker. 
- * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is - * called, a stream that is already started is returned. - * - *

Holds references to {@link - * Supplier} because - * initializing the streams automatically start them, and we want to do so lazily here once the - * {@link GetWorkBudget} is set. * *

Once started, the underlying streams are "alive" until they are manually closed via {@link - * #closeAllStreams()}. + * #close()}. * *

If closed, it means that the backend endpoint is no longer in the worker set. Once closed, * these instances are not reused. @@ -59,14 +56,16 @@ */ @Internal @ThreadSafe -final class WindmillStreamSender implements GetWorkBudgetSpender { +final class WindmillStreamSender implements GetWorkBudgetSpender, StreamSender { + private static final String STREAM_STARTER_THREAD_NAME = "StartWindmillStreamThread-%d"; private final AtomicBoolean started; private final AtomicReference getWorkBudget; - private final Supplier getWorkStream; - private final Supplier getDataStream; - private final Supplier commitWorkStream; - private final Supplier workCommitter; + private final GetWorkStream getWorkStream; + private final GetDataStream getDataStream; + private final CommitWorkStream commitWorkStream; + private final WorkCommitter workCommitter; private final StreamingEngineThrottleTimers streamingEngineThrottleTimers; + private final ExecutorService streamStarter; private WindmillStreamSender( WindmillConnection connection, @@ -80,33 +79,28 @@ private WindmillStreamSender( this.getWorkBudget = getWorkBudget; this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create(); - // All streams are memoized/cached since they are expensive to create and some implementations - // perform side effects on construction (i.e. sending initial requests to the stream server to - // initiate the streaming RPC connection). Stream instances connect/reconnect internally, so we - // can reuse the same instance through the entire lifecycle of WindmillStreamSender. + // Stream instances connect/reconnect internally, so we can reuse the same instance through the + // entire lifecycle of WindmillStreamSender. 
this.getDataStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createGetDataStream( - connection.stub(), streamingEngineThrottleTimers.getDataThrottleTimer())); + streamingEngineStreamFactory.createDirectGetDataStream( + connection, streamingEngineThrottleTimers.getDataThrottleTimer()); this.commitWorkStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createCommitWorkStream( - connection.stub(), streamingEngineThrottleTimers.commitWorkThrottleTimer())); - this.workCommitter = - Suppliers.memoize(() -> workCommitterFactory.apply(commitWorkStream.get())); + streamingEngineStreamFactory.createDirectCommitWorkStream( + connection, streamingEngineThrottleTimers.commitWorkThrottleTimer()); + this.workCommitter = workCommitterFactory.apply(commitWorkStream); this.getWorkStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createDirectGetWorkStream( - connection, - withRequestBudget(getWorkRequest, getWorkBudget.get()), - streamingEngineThrottleTimers.getWorkThrottleTimer(), - () -> FixedStreamHeartbeatSender.create(getDataStream.get()), - () -> getDataClientFactory.apply(getDataStream.get()), - workCommitter, - workItemScheduler)); + streamingEngineStreamFactory.createDirectGetWorkStream( + connection, + withRequestBudget(getWorkRequest, getWorkBudget.get()), + streamingEngineThrottleTimers.getWorkThrottleTimer(), + FixedStreamHeartbeatSender.create(getDataStream), + getDataClientFactory.apply(getDataStream), + workCommitter, + workItemScheduler); + // 3 threads, 1 for each stream type (GetWork, GetData, CommitWork). 
+ this.streamStarter = + Executors.newFixedThreadPool( + 3, new ThreadFactoryBuilder().setNameFormat(STREAM_STARTER_THREAD_NAME).build()); } static WindmillStreamSender create( @@ -131,39 +125,37 @@ private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkB return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build(); } - @SuppressWarnings("ReturnValueIgnored") - void startStreams() { - getWorkStream.get(); - getDataStream.get(); - commitWorkStream.get(); - workCommitter.get().start(); - // *stream.get() is all memoized in a threadsafe manner. - started.set(true); - } + synchronized void start() { + if (!started.get()) { + checkState(!streamStarter.isShutdown(), "WindmillStreamSender has already been shutdown."); - void closeAllStreams() { - // Supplier.get() starts the stream which is an expensive operation as it initiates the - // streaming RPCs by possibly making calls over the network. Do not close the streams unless - // they have already been started. - if (started.get()) { - getWorkStream.get().shutdown(); - getDataStream.get().shutdown(); - workCommitter.get().stop(); - commitWorkStream.get().shutdown(); + // Start these 3 streams in parallel since they each may perform blocking IO. 
+ CompletableFuture.allOf( + CompletableFuture.runAsync(getWorkStream::start, streamStarter), + CompletableFuture.runAsync(getDataStream::start, streamStarter), + CompletableFuture.runAsync(commitWorkStream::start, streamStarter)) + .join(); + workCommitter.start(); + started.set(true); } } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta)); - if (started.get()) { - getWorkStream.get().adjustBudget(itemsDelta, bytesDelta); - } + public synchronized void close() { + streamStarter.shutdownNow(); + getWorkStream.shutdown(); + getDataStream.shutdown(); + workCommitter.stop(); + commitWorkStream.shutdown(); } @Override - public GetWorkBudget remainingBudget() { - return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get(); + public void setBudget(long items, long bytes) { + GetWorkBudget budget = GetWorkBudget.builder().setItems(items).setBytes(bytes).build(); + getWorkBudget.set(budget); + if (started.get()) { + getWorkStream.setBudget(budget); + } } long getAndResetThrottleTime() { @@ -171,6 +163,6 @@ long getAndResetThrottleTime() { } long getCurrentActiveCommitBytes() { - return started.get() ? 
workCommitter.get().currentActiveCommitBytes() : 0; + return workCommitter.currentActiveCommitBytes(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index 5e3f293f7d5b..9286be84ceaa 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -22,8 +22,6 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.GuardedBy; -import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard; @@ -223,18 +221,10 @@ private void executeMonitorHeld(Runnable work, long workBytes) { try { executor.execute( () -> { - String threadName = Thread.currentThread().getName(); try { - if (work instanceof ExecutableWork) { - String workToken = - debugFormattedWorkToken( - ((ExecutableWork) work).work().getWorkItem().getWorkToken()); - Thread.currentThread().setName(threadName + ":" + workToken); - } work.run(); } finally { decrementCounters(workBytes); - Thread.currentThread().setName(threadName); } }); } catch (RuntimeException e) { @@ -244,11 +234,6 @@ private void executeMonitorHeld(Runnable work, long workBytes) { } } - @VisibleForTesting - public static String debugFormattedWorkToken(long workToken) { - return String.format("%016x", workToken); - } - private void decrementCounters(long 
workBytes) { monitor.enter(); --elementsOutstanding; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index d7ed83def43e..dd7fdd45ab08 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.auto.value.AutoValue; import java.net.Inet6Address; @@ -27,8 +27,8 @@ import java.util.Map; import java.util.Optional; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,6 +40,16 @@ @AutoValue public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + private static final WindmillEndpoints NO_ENDPOINTS = + WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + 
.setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); + + public static WindmillEndpoints none() { + return NO_ENDPOINTS; + } public static WindmillEndpoints from( Windmill.WorkerMetadataResponse workerMetadataResponseProto) { @@ -53,14 +63,15 @@ public static WindmillEndpoints from( endpoint.getValue(), workerMetadataResponseProto.getExternalEndpoint()))); - ImmutableList windmillServers = + ImmutableSet windmillServers = workerMetadataResponseProto.getWorkEndpointsList().stream() .map( endpointProto -> Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint())) - .collect(toImmutableList()); + .collect(toImmutableSet()); return WindmillEndpoints.builder() + .setVersion(workerMetadataResponseProto.getMetadataVersion()) .setGlobalDataEndpoints(globalDataServers) .setWindmillEndpoints(windmillServers) .build(); @@ -77,17 +88,23 @@ private static Optional parseDirectEndpoint( .map(address -> AuthenticatedGcpServiceAddress.create(authenticatingService, address)) .map(WindmillServiceAddress::create); - return directEndpointIpV6Address.isPresent() - ? directEndpointIpV6Address - : tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint()) - .map(WindmillServiceAddress::create); + Optional windmillServiceAddress = + directEndpointIpV6Address.isPresent() + ? 
directEndpointIpV6Address + : tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint()) + .map(WindmillServiceAddress::create); + + if (!windmillServiceAddress.isPresent()) { + LOG.warn("Endpoint {} could not be parsed into a WindmillServiceAddress.", endpointProto); + } + + return windmillServiceAddress; } private static Optional tryParseEndpointIntoHostAndPort(String directEndpoint) { try { return Optional.of(HostAndPort.fromString(directEndpoint)); } catch (IllegalArgumentException e) { - LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint); return Optional.empty(); } } @@ -102,19 +119,12 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( try { directEndpointAddress = Inet6Address.getByName(endpointProto.getDirectEndpoint()); } catch (UnknownHostException e) { - LOG.warn( - "Error occurred trying to parse direct_endpoint={} into IPv6 address. Exception={}", - endpointProto.getDirectEndpoint(), - e.toString()); return Optional.empty(); } // Inet6Address.getByAddress returns either an IPv4 or an IPv6 address depending on the format // of the direct_endpoint string. if (!(directEndpointAddress instanceof Inet6Address)) { - LOG.warn( - "{} is not an IPv6 address. Direct endpoints are expected to be in IPv6 format.", - endpointProto.getDirectEndpoint()); return Optional.empty(); } @@ -123,6 +133,9 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); } + /** Version of the endpoints which increases with every modification. */ + public abstract long version(); + /** * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key * is a global data tag and the value is the endpoint where the data associated with the global @@ -138,7 +151,7 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( * Windmill servers. 
Returns a list of endpoints used to communicate with the corresponding * Windmill servers. */ - public abstract ImmutableList windmillEndpoints(); + public abstract ImmutableSet windmillEndpoints(); /** * Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with @@ -204,13 +217,15 @@ public abstract static class Builder { @AutoValue.Builder public abstract static class Builder { + public abstract Builder setVersion(long version); + public abstract Builder setGlobalDataEndpoints( ImmutableMap globalDataServers); public abstract Builder setWindmillEndpoints( - ImmutableList windmillServers); + ImmutableSet windmillServers); - abstract ImmutableList.Builder windmillEndpointsBuilder(); + abstract ImmutableSet.Builder windmillEndpointsBuilder(); public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) { windmillEndpointsBuilder().add(endpoint); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java index cd753cb8ec91..2ae97087fec7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java @@ -30,10 +30,16 @@ public abstract class WindmillServerStub @Override public void appendSummaryHtml(PrintWriter writer) {} - /** Generic Exception type for implementors to use to represent errors while making RPCs. */ - public static final class RpcException extends RuntimeException { - public RpcException(Throwable cause) { + /** + * Generic Exception type for implementors to use to represent errors while making Windmill RPCs. 
+ */ + public static final class WindmillRpcException extends RuntimeException { + public WindmillRpcException(Throwable cause) { super(cause); } + + public WindmillRpcException(String message, Throwable cause) { + super(message, cause); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java index 90f93b072673..0b895652efe2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java @@ -19,38 +19,36 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; -import java.net.Inet6Address; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; /** Used to create channels to communicate with Streaming Engine via gRpc. 
*/ @AutoOneOf(WindmillServiceAddress.Kind.class) public abstract class WindmillServiceAddress { - public static WindmillServiceAddress create(Inet6Address ipv6Address) { - return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address); - } public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) { return AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress); } - public abstract Kind getKind(); + public static WindmillServiceAddress create( + AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { + return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( + authenticatedGcpServiceAddress); + } - public abstract Inet6Address ipv6(); + public abstract Kind getKind(); public abstract HostAndPort gcpServiceAddress(); public abstract AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress(); - public static WindmillServiceAddress create( - AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { - return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( - authenticatedGcpServiceAddress); + public final HostAndPort getServiceAddress() { + return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS + ? gcpServiceAddress() + : authenticatedGcpServiceAddress().gcpServiceAddress(); } public enum Kind { - IPV6, GCP_SERVICE_ADDRESS, - // TODO(m-trieu): Use for direct connections when ALTS is enabled. 
AUTHENTICATED_GCP_SERVICE_ADDRESS } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 58aecfc71e00..8b48459eba94 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -17,30 +17,27 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.io.IOException; import java.io.PrintWriter; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.api.client.util.Sleeper; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Status; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.StatusRuntimeException; import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Base class for persistent streams connecting to Windmill. @@ -49,46 +46,55 @@ * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on a * broken stream. * - *

Subclasses should override onResponse to handle responses from the server, and onNewStream to - * perform any work that must be done when a new stream is created, such as sending headers or - * retrying requests. + *

Subclasses should override {@link #onResponse(ResponseT)} to handle responses from the server, + * and {@link #onNewStream()} to perform any work that must be done when a new stream is created, + * such as sending headers or retrying requests. * - *

send and startStream should not be called from onResponse; use executor() instead. + *

{@link #trySend(RequestT)} and {@link #startStream()} should not be called from {@link + * #onResponse(ResponseT)}; use {@link #executeSafely(Runnable)} instead. * *

Synchronization on this is used to synchronize the gRpc stream state and internal data * structures. Since grpc channel operations may block, synchronization on this stream may also * block. This is generally not a problem since streams are used in a single-threaded manner. * However, some accessors used for status page and other debugging need to take care not to require * synchronizing on this. + * + *

{@link #start()} and {@link #shutdown()} are called once in the lifetime of the stream. Once + * {@link #shutdown()}, a stream is considered invalid and cannot be restarted/reused. */ public abstract class AbstractWindmillStream implements WindmillStream { - public static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce // per-chunk overhead, and small enough that we can still perform granular flow-control. protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; - private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStream.class); - protected final AtomicBoolean clientClosed; - private final AtomicBoolean isShutdown; - private final AtomicLong lastSendTimeMs; - private final Executor executor; + // Indicates that the logical stream has been half-closed and is waiting for clean server + // shutdown. + private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); + private static final String NEVER_RECEIVED_RESPONSE_LOG_STRING = "never received response"; + private static final String NOT_SHUTDOWN = "not shutdown"; + protected final Sleeper sleeper; + + private final Logger logger; + private final ExecutorService executor; private final BackOff backoff; - private final AtomicLong startTimeMs; - private final AtomicLong lastResponseTimeMs; - private final AtomicInteger errorCount; - private final AtomicReference lastError; - private final AtomicReference lastErrorTime; - private final AtomicLong sleepUntil; private final CountDownLatch finishLatch; private final Set> streamRegistry; private final int logEveryNStreamFailures; - private final Supplier> requestObserverSupplier; - // Indicates if the current stream in requestObserver is closed by calling close() method - private final AtomicBoolean streamClosed; private final String backendWorkerToken; - private @Nullable StreamObserver requestObserver; + private final
ResettableThrowingStreamObserver requestObserver; + private final StreamDebugMetrics debugMetrics; + + @GuardedBy("this") + protected boolean clientClosed; + + @GuardedBy("this") + protected boolean isShutdown; + + @GuardedBy("this") + private boolean started; protected AbstractWindmillStream( + Logger logger, String debugStreamType, Function, StreamObserver> clientFactory, BackOff backoff, @@ -106,21 +112,20 @@ protected AbstractWindmillStream( this.backoff = backoff; this.streamRegistry = streamRegistry; this.logEveryNStreamFailures = logEveryNStreamFailures; - this.clientClosed = new AtomicBoolean(); - this.streamClosed = new AtomicBoolean(); - this.startTimeMs = new AtomicLong(); - this.lastSendTimeMs = new AtomicLong(); - this.lastResponseTimeMs = new AtomicLong(); - this.errorCount = new AtomicInteger(); - this.lastError = new AtomicReference<>(); - this.lastErrorTime = new AtomicReference<>(); - this.sleepUntil = new AtomicLong(); + this.clientClosed = false; + this.isShutdown = false; + this.started = false; this.finishLatch = new CountDownLatch(1); - this.isShutdown = new AtomicBoolean(false); - this.requestObserverSupplier = - () -> - streamObserverFactory.from( - clientFactory, new AbstractWindmillStream.ResponseObserver()); + this.logger = logger; + this.requestObserver = + new ResettableThrowingStreamObserver<>( + () -> + streamObserverFactory.from( + clientFactory, + new AbstractWindmillStream.ResponseObserver()), + logger); + this.sleeper = Sleeper.DEFAULT; + this.debugMetrics = StreamDebugMetrics.create(); } private static String createThreadName(String streamType, String backendWorkerToken) { @@ -129,18 +134,11 @@ private static String createThreadName(String streamType, String backendWorkerTo : String.format("%s-WindmillStream-thread", streamType); } - private static long debugDuration(long nowMs, long startMs) { - if (startMs <= 0) { - return -1; - } - return Math.max(0, nowMs - startMs); - } - /** Called on each response from the server. 
*/ protected abstract void onResponse(ResponseT response); /** Called when a new underlying stream to the server has been opened. */ - protected abstract void onNewStream(); + protected abstract void onNewStream() throws WindmillStreamShutdownException; /** Returns whether there are any pending requests that should be retried on a stream break. */ protected abstract boolean hasPendingRequests(); @@ -152,114 +150,161 @@ private static long debugDuration(long nowMs, long startMs) { */ protected abstract void startThrottleTimer(); - /** Reflects that {@link #shutdown()} was explicitly called. */ - protected boolean isShutdown() { - return isShutdown.get(); - } - - private StreamObserver requestObserver() { - if (requestObserver == null) { - throw new NullPointerException( - "requestObserver cannot be null. Missing a call to startStream() to initialize."); + /** Try to send a request to the server. Returns true if the request was successfully sent. */ + @CanIgnoreReturnValue + protected final synchronized boolean trySend(RequestT request) + throws WindmillStreamShutdownException { + debugMetrics.recordSend(); + try { + requestObserver.onNext(request); + return true; + } catch (ResettableThrowingStreamObserver.StreamClosedException e) { + // Stream was broken, requests may be retried when stream is reopened. } - return requestObserver; + return false; } - /** Send a request to the server. */ - protected final void send(RequestT request) { - lastSendTimeMs.set(Instant.now().getMillis()); + @Override + public final void start() { + boolean shouldStartStream = false; synchronized (this) { - if (streamClosed.get()) { - throw new IllegalStateException("Send called on a client closed stream."); + if (!isShutdown && !started) { + started = true; + shouldStartStream = true; } + } - requestObserver().onNext(request); + if (shouldStartStream) { + startStream(); } } /** Starts the underlying stream. 
*/ - protected final void startStream() { + private void startStream() { // Add the stream to the registry after it has been fully constructed. streamRegistry.add(this); while (true) { try { synchronized (this) { - startTimeMs.set(Instant.now().getMillis()); - lastResponseTimeMs.set(0); - streamClosed.set(false); - // lazily initialize the requestObserver. Gets reset whenever the stream is reopened. - requestObserver = requestObserverSupplier.get(); + debugMetrics.recordStart(); + requestObserver.reset(); onNewStream(); - if (clientClosed.get()) { + if (clientClosed) { halfClose(); } return; } + } catch (WindmillStreamShutdownException e) { + // shutdown() is responsible for cleaning up pending requests. + logger.debug("Stream was shutdown while creating new stream.", e); + break; } catch (Exception e) { - LOG.error("Failed to create new stream, retrying: ", e); + logger.error("Failed to create new stream, retrying: ", e); try { long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException | IOException i) { + debugMetrics.recordSleep(sleep); + sleeper.sleep(sleep); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + logger.info( + "Interrupted during {} creation backoff. The stream will not be created.", + getClass()); + // Shutdown the stream to clean up any dangling resources and pending requests. + shutdown(); + break; + } catch (IOException ioe) { // Keep trying to create the stream. } } } + + // We were never able to start the stream, remove it from the stream registry. Otherwise, it is + // removed when closed. + streamRegistry.remove(this); } - protected final Executor executor() { - return executor; + /** + * Execute the runnable using the {@link #executor} handling the executor being in a shutdown + * state. 
+ */ + protected final void executeSafely(Runnable runnable) { + try { + executor.execute(runnable); + } catch (RejectedExecutionException e) { + logger.debug("{}-{} has been shutdown.", getClass(), backendWorkerToken); + } } public final synchronized void maybeSendHealthCheck(Instant lastSendThreshold) { - if (lastSendTimeMs.get() < lastSendThreshold.getMillis() && !clientClosed.get()) { + if (!clientClosed && debugMetrics.getLastSendTimeMs() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); - } catch (RuntimeException e) { - LOG.debug("Received exception sending health check.", e); + } catch (Exception e) { + logger.debug("Received exception sending health check.", e); } } } - protected abstract void sendHealthCheck(); + protected abstract void sendHealthCheck() throws WindmillStreamShutdownException; - // Care is taken that synchronization on this is unnecessary for all status page information. - // Blocking sends are made beneath this stream object's lock which could block status page - // rendering. + /** + * @implNote Care is taken that synchronization on this is unnecessary for all status page + * information. Blocking sends are made beneath this stream object's lock which could block + * status page rendering. 
+ */ public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); - if (errorCount.get() > 0) { - writer.format( - ", %d errors, last error [ %s ] at [%s]", - errorCount.get(), lastError.get(), lastErrorTime.get()); - } - if (clientClosed.get()) { + StreamDebugMetrics.Snapshot summaryMetrics = debugMetrics.getSummaryMetrics(); + summaryMetrics + .restartMetrics() + .ifPresent( + metrics -> + writer.format( + ", %d restarts, last restart reason [ %s ] at [%s], %d errors", + metrics.restartCount(), + metrics.lastRestartReason(), + metrics.lastRestartTime().orElse(null), + metrics.errorCount())); + + if (summaryMetrics.isClientClosed()) { writer.write(", client closed"); } - long nowMs = Instant.now().getMillis(); - long sleepLeft = sleepUntil.get() - nowMs; - if (sleepLeft > 0) { - writer.format(", %dms backoff remaining", sleepLeft); + + if (summaryMetrics.sleepLeft() > 0) { + writer.format(", %dms backoff remaining", summaryMetrics.sleepLeft()); } + writer.format( - ", current stream is %dms old, last send %dms, last response %dms, closed: %s", - debugDuration(nowMs, startTimeMs.get()), - debugDuration(nowMs, lastSendTimeMs.get()), - debugDuration(nowMs, lastResponseTimeMs.get()), - streamClosed.get()); + ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " + + "shutdown time: %s", + summaryMetrics.streamAge(), + summaryMetrics.timeSinceLastSend(), + summaryMetrics.timeSinceLastResponse(), + requestObserver.isClosed(), + summaryMetrics.shutdownTime().map(DateTime::toString).orElse(NOT_SHUTDOWN)); } - // Don't require synchronization on stream, see the appendSummaryHtml comment. + /** + * @implNote Don't require synchronization on stream, see the {@link + * #appendSummaryHtml(PrintWriter)} comment. + */ protected abstract void appendSpecificHtml(PrintWriter writer); @Override public final synchronized void halfClose() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. 
- clientClosed.set(true); - requestObserver().onCompleted(); - streamClosed.set(true); + debugMetrics.recordHalfClose(); + clientClosed = true; + try { + requestObserver.onCompleted(); + } catch (ResettableThrowingStreamObserver.StreamClosedException e) { + logger.warn("Stream was previously closed."); + } catch (WindmillStreamShutdownException e) { + logger.warn("Stream was previously shutdown."); + } catch (IllegalStateException e) { + logger.warn("Unexpected error when trying to close stream", e); + } } @Override @@ -269,7 +314,7 @@ public final boolean awaitTermination(int time, TimeUnit unit) throws Interrupte @Override public final Instant startTime() { - return new Instant(startTimeMs.get()); + return new Instant(debugMetrics.getStartTimeMs()); } @Override @@ -278,22 +323,31 @@ public String backendWorkerToken() { } @Override - public void shutdown() { - if (isShutdown.compareAndSet(false, true)) { - requestObserver() - .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + public final void shutdown() { + // Don't lock on "this" before poisoning the request observer since otherwise the observer may + // be blocking in send(). + requestObserver.poison(); + synchronized (this) { + if (!isShutdown) { + isShutdown = true; + debugMetrics.recordShutdown(); + shutdownInternal(); + } } } - private void setLastError(String error) { - lastError.set(error); - lastErrorTime.set(DateTime.now()); - } + protected abstract void shutdownInternal(); - public static class WindmillStreamShutdownException extends RuntimeException { - public WindmillStreamShutdownException(String message) { - super(message); + /** Returns true if the stream was torn down and should not be restarted internally. 
*/ + private synchronized boolean maybeTearDownStream() { + if (isShutdown || (clientClosed && !hasPendingRequests())) { + streamRegistry.remove(AbstractWindmillStream.this); + finishLatch.countDown(); + executor.shutdownNow(); + return true; } + + return false; } private class ResponseObserver implements StreamObserver { @@ -305,77 +359,83 @@ public void onNext(ResponseT response) { } catch (IOException e) { // Ignore. } - lastResponseTimeMs.set(Instant.now().getMillis()); + debugMetrics.recordResponse(); onResponse(response); } @Override public void onError(Throwable t) { - onStreamFinished(t); + if (maybeTearDownStream()) { + return; + } + + Status errorStatus = Status.fromThrowable(t); + recordStreamStatus(errorStatus); + + // If the stream was stopped due to a resource exhausted error then we are throttled. + if (errorStatus.getCode() == Status.Code.RESOURCE_EXHAUSTED) { + startThrottleTimer(); + } + + try { + long sleep = backoff.nextBackOffMillis(); + debugMetrics.recordSleep(sleep); + sleeper.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } catch (IOException e) { + // Ignore. + } + + executeSafely(AbstractWindmillStream.this::startStream); } @Override public void onCompleted() { - onStreamFinished(null); + if (maybeTearDownStream()) { + return; + } + recordStreamStatus(OK_STATUS); + executeSafely(AbstractWindmillStream.this::startStream); } - private void onStreamFinished(@Nullable Throwable t) { - synchronized (this) { - if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) { - streamRegistry.remove(AbstractWindmillStream.this); - finishLatch.countDown(); - return; - } - } - if (t != null) { - Status status = null; - if (t instanceof StatusRuntimeException) { - status = ((StatusRuntimeException) t).getStatus(); - } - String statusError = status == null ? 
"" : status.toString(); - setLastError(statusError); - if (errorCount.getAndIncrement() % logEveryNStreamFailures == 0) { + private void recordStreamStatus(Status status) { + int currentRestartCount = debugMetrics.incrementAndGetRestarts(); + if (status.isOk()) { + String restartReason = + "Stream completed successfully but did not complete requested operations, " + + "recreating"; + logger.warn(restartReason); + debugMetrics.recordRestartReason(restartReason); + } else { + int currentErrorCount = debugMetrics.incrementAndGetErrors(); + debugMetrics.recordRestartReason(status.toString()); + Throwable t = status.getCause(); + if (t instanceof StreamObserverCancelledException) { + logger.error( + "StreamObserver was unexpectedly cancelled for stream={}, worker={}. stacktrace={}", + getClass(), + backendWorkerToken, + t.getStackTrace(), + t); + } else if (currentRestartCount % logEveryNStreamFailures == 0) { + // Don't log every restart since it will get noisy, and many errors transient. long nowMillis = Instant.now().getMillis(); - String responseDebug; - if (lastResponseTimeMs.get() == 0) { - responseDebug = "never received response"; - } else { - responseDebug = - "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; - } - LOG.debug( - "{} streaming Windmill RPC errors for {}, last was: {} with status {}." - + " created {}ms ago, {}. This is normal with autoscaling.", + logger.debug( + "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}" + + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.", AbstractWindmillStream.this.getClass(), - errorCount.get(), + currentRestartCount, + currentErrorCount, t, - statusError, - nowMillis - startTimeMs.get(), - responseDebug); - } - // If the stream was stopped due to a resource exhausted error then we are throttled. 
- if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) { - startThrottleTimer(); - } - - try { - long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (IOException e) { - // Ignore. + status, + nowMillis - debugMetrics.getStartTimeMs(), + debugMetrics + .responseDebugString(nowMillis) + .orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING)); } - } else { - errorCount.incrementAndGet(); - String error = - "Stream completed successfully but did not complete requested operations, " - + "recreating"; - LOG.warn(error); - setLastError(error); } - executor.execute(AbstractWindmillStream.this::startStream); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java new file mode 100644 index 000000000000..1db6d8de791d --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.slf4j.Logger; + +/** + * Request observer that allows resetting its internal delegate using the given {@link + * #streamObserverFactory}. + * + * @implNote {@link StreamObserver}s generated by {@link #streamObserverFactory} are expected to be + * {@link ThreadSafe}. Has same methods declared in {@link StreamObserver}, but they throw + * {@link StreamClosedException} and {@link WindmillStreamShutdownException}, which must be + * handled by callers. + */ +@ThreadSafe +@Internal +final class ResettableThrowingStreamObserver { + private final Supplier> streamObserverFactory; + private final Logger logger; + + @GuardedBy("this") + private @Nullable TerminatingStreamObserver delegateStreamObserver; + + @GuardedBy("this") + private boolean isPoisoned = false; + + /** + * Indicates that the current delegate is closed via {@link #poison()} or {@link #onCompleted()}.
+ * If not poisoned, a call to {@link #reset()} is required to perform future operations on the + * StreamObserver. + */ + @GuardedBy("this") + private boolean isCurrentStreamClosed = true; + + ResettableThrowingStreamObserver( + Supplier> streamObserverFactory, Logger logger) { + this.streamObserverFactory = streamObserverFactory; + this.logger = logger; + this.delegateStreamObserver = null; + } + + private synchronized StreamObserver delegate() + throws WindmillStreamShutdownException, StreamClosedException { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Stream is already shutdown."); + } + + if (isCurrentStreamClosed) { + throw new StreamClosedException( + "Current stream is closed, requires reset() for future stream operations."); + } + + return Preconditions.checkNotNull(delegateStreamObserver, "requestObserver cannot be null."); + } + + /** Creates a new delegate to use for future {@link StreamObserver} methods. */ + synchronized void reset() throws WindmillStreamShutdownException { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Stream is already shutdown."); + } + + delegateStreamObserver = streamObserverFactory.get(); + isCurrentStreamClosed = false; + } + + /** + * Indicates that the request observer should no longer be used. Attempts to perform operations on + * the request observer will throw an {@link WindmillStreamShutdownException}. + */ + synchronized void poison() { + if (!isPoisoned) { + isPoisoned = true; + if (delegateStreamObserver != null) { + delegateStreamObserver.terminate( + new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + delegateStreamObserver = null; + isCurrentStreamClosed = true; + } + } + } + + public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownException { + // Make sure onNext and onError below to be called on the same StreamObserver instance. 
+ StreamObserver delegate = delegate(); + try { + // Do NOT lock while sending message over the stream as this will block other StreamObserver + // operations. + delegate.onNext(t); + } catch (StreamObserverCancelledException cancellationException) { + synchronized (this) { + if (isPoisoned) { + logger.debug("Stream was shutdown during send.", cancellationException); + return; + } + } + + try { + delegate.onError(cancellationException); + } catch (IllegalStateException onErrorException) { + // If the delegate above was already terminated via onError or onComplete from another + // thread. + logger.warn( + "StreamObserver was already cancelled {} due to error.", + onErrorException, + cancellationException); + } catch (RuntimeException onErrorException) { + logger.warn( + "Encountered unexpected error {} when cancelling due to error.", + onErrorException, + cancellationException); + } + } + } + + public synchronized void onError(Throwable throwable) + throws StreamClosedException, WindmillStreamShutdownException { + delegate().onError(throwable); + isCurrentStreamClosed = true; + } + + public synchronized void onCompleted() + throws StreamClosedException, WindmillStreamShutdownException { + delegate().onCompleted(); + isCurrentStreamClosed = true; + } + + synchronized boolean isClosed() { + return isCurrentStreamClosed; + } + + /** + * Indicates that the current stream was closed and the {@link StreamObserver} has finished via + * {@link StreamObserver#onCompleted()}. 
The stream may perform future operations only after a call to {@link ResettableThrowingStreamObserver#reset()}. + */ + static final class StreamClosedException extends Exception { + StreamClosedException(String s) { + super(s); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java new file mode 100644 index 000000000000..4cda12a85ea2 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import com.google.auto.value.AutoValue; +import java.util.Optional; +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.joda.time.DateTime; +import org.joda.time.Instant; + +/** Records stream metrics for debugging.
*/ +@ThreadSafe +final class StreamDebugMetrics { + private final Supplier clock; + + @GuardedBy("this") + private int errorCount = 0; + + @GuardedBy("this") + private int restartCount = 0; + + @GuardedBy("this") + private long sleepUntil = 0; + + @GuardedBy("this") + private String lastRestartReason = ""; + + @GuardedBy("this") + private @Nullable DateTime lastRestartTime = null; + + @GuardedBy("this") + private long lastResponseTimeMs = 0; + + @GuardedBy("this") + private long lastSendTimeMs = 0; + + @GuardedBy("this") + private long startTimeMs = 0; + + @GuardedBy("this") + private @Nullable DateTime shutdownTime = null; + + @GuardedBy("this") + private boolean clientClosed = false; + + private StreamDebugMetrics(Supplier clock) { + this.clock = clock; + } + + static StreamDebugMetrics create() { + return new StreamDebugMetrics(Instant::now); + } + + @VisibleForTesting + static StreamDebugMetrics forTesting(Supplier fakeClock) { + return new StreamDebugMetrics(fakeClock); + } + + private static long debugDuration(long nowMs, long startMs) { + return startMs <= 0 ? 
-1 : Math.max(0, nowMs - startMs); + } + + private long nowMs() { + return clock.get().getMillis(); + } + + synchronized void recordSend() { + lastSendTimeMs = nowMs(); + } + + synchronized void recordStart() { + startTimeMs = nowMs(); + lastResponseTimeMs = 0; + } + + synchronized void recordResponse() { + lastResponseTimeMs = nowMs(); + } + + synchronized void recordRestartReason(String error) { + lastRestartReason = error; + lastRestartTime = clock.get().toDateTime(); + } + + synchronized long getStartTimeMs() { + return startTimeMs; + } + + synchronized long getLastSendTimeMs() { + return lastSendTimeMs; + } + + synchronized void recordSleep(long sleepMs) { + sleepUntil = nowMs() + sleepMs; + } + + synchronized int incrementAndGetRestarts() { + return restartCount++; + } + + synchronized int incrementAndGetErrors() { + return errorCount++; + } + + synchronized void recordShutdown() { + shutdownTime = clock.get().toDateTime(); + } + + synchronized void recordHalfClose() { + clientClosed = true; + } + + synchronized Optional responseDebugString(long nowMillis) { + return lastResponseTimeMs == 0 + ? 
Optional.empty() + : Optional.of("received response " + (nowMillis - lastResponseTimeMs) + "ms ago"); + } + + private synchronized Optional getRestartMetrics() { + if (restartCount > 0) { + return Optional.of( + RestartMetrics.create(restartCount, lastRestartReason, lastRestartTime, errorCount)); + } + + return Optional.empty(); + } + + synchronized Snapshot getSummaryMetrics() { + long nowMs = clock.get().getMillis(); + return Snapshot.create( + debugDuration(nowMs, startTimeMs), + debugDuration(nowMs, lastSendTimeMs), + debugDuration(nowMs, lastResponseTimeMs), + getRestartMetrics(), + sleepUntil - nowMs(), + shutdownTime, + clientClosed); + } + + @AutoValue + abstract static class Snapshot { + private static Snapshot create( + long streamAge, + long timeSinceLastSend, + long timeSinceLastResponse, + Optional restartMetrics, + long sleepLeft, + @Nullable DateTime shutdownTime, + boolean isClientClosed) { + return new AutoValue_StreamDebugMetrics_Snapshot( + streamAge, + timeSinceLastSend, + timeSinceLastResponse, + restartMetrics, + sleepLeft, + Optional.ofNullable(shutdownTime), + isClientClosed); + } + + abstract long streamAge(); + + abstract long timeSinceLastSend(); + + abstract long timeSinceLastResponse(); + + abstract Optional restartMetrics(); + + abstract long sleepLeft(); + + abstract Optional shutdownTime(); + + abstract boolean isClientClosed(); + } + + @AutoValue + abstract static class RestartMetrics { + private static RestartMetrics create( + int restartCount, + String restartReason, + @Nullable DateTime lastRestartTime, + int errorCount) { + return new AutoValue_StreamDebugMetrics_RestartMetrics( + restartCount, restartReason, Optional.ofNullable(lastRestartTime), errorCount); + } + + abstract int restartCount(); + + abstract String lastRestartReason(); + + abstract Optional lastRestartTime(); + + abstract int errorCount(); + } +} diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 31bd4e146a78..51bc03e8e0e7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -34,6 +34,12 @@ @ThreadSafe public interface WindmillStream { + /** + * Start the stream, opening a connection to the backend server. A call to start() is required for + * any further interactions on the stream. + */ + void start(); + /** An identifier for the backend worker where the stream is sending/receiving RPCs. */ String backendWorkerToken(); @@ -47,8 +53,9 @@ public interface WindmillStream { Instant startTime(); /** - * Shutdown the stream. There should be no further interactions with the stream once this has been - * called. + * Shuts down the stream. No further interactions should be made with the stream, and the stream + * will no longer try to connect internally. Any pending retries or in-flight requests will be + * cancelled and all responses dropped and considered invalid. */ void shutdown(); @@ -56,10 +63,11 @@ public interface WindmillStream { @ThreadSafe interface GetWorkStream extends WindmillStream { /** Adjusts the {@link GetWorkBudget} for the stream. */ - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(GetWorkBudget newBudget); - /** Returns the remaining in-flight {@link GetWorkBudget}. */ - GetWorkBudget remainingBudget(); + default void setBudget(long newItems, long newBytes) { + setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); + } } /** Interface for streaming GetDataRequests to Windmill. 
*/ @@ -67,13 +75,16 @@ interface GetWorkStream extends WindmillStream { interface GetDataStream extends WindmillStream { /** Issues a keyed GetData fetch, blocking until the result is ready. */ Windmill.KeyedGetDataResponse requestKeyedData( - String computation, Windmill.KeyedGetDataRequest request); + String computation, Windmill.KeyedGetDataRequest request) + throws WindmillStreamShutdownException; /** Issues a global GetData fetch, blocking until the result is ready. */ - Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request); + Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request) + throws WindmillStreamShutdownException; /** Tells windmill processing is ongoing for the given keys. */ - void refreshActiveWork(Map> heartbeats); + void refreshActiveWork(Map> heartbeats) + throws WindmillStreamShutdownException; void onHeartbeatResponse(List responses); } @@ -85,7 +96,7 @@ interface CommitWorkStream extends WindmillStream { * Returns a builder that can be used for sending requests. Each builder is not thread-safe but * different builders for the same stream may be used simultaneously. 
*/ - CommitWorkStream.RequestBatcher batcher(); + RequestBatcher batcher(); @NotThreadSafe interface RequestBatcher extends Closeable { diff --git a/sdks/python/container/py38/build.gradle b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java similarity index 54% rename from sdks/python/container/py38/build.gradle rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java index 304895a83718..566c15c58036 100644 --- a/sdks/python/container/py38/build.gradle +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -4,25 +4,26 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.beam.runners.dataflow.worker.windmill.client; -plugins { - id 'base' - id 'org.apache.beam.module' +/** + * Thrown when operations are requested on a {@link WindmillStream} that has been shut down. Future + * operations on the stream are not allowed and will throw a {@link + * WindmillStreamShutdownException}. 
+ */ +public final class WindmillStreamShutdownException extends Exception { + public WindmillStreamShutdownException(String message) { + super(message); + } } -applyDockerNature() -applyPythonNature() - -pythonVersion = '3.8' - -apply from: "../common.gradle" diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java new file mode 100644 index 000000000000..498e90f78e29 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** Utility class for commits. 
*/ +@Internal +public final class Commits { + + /** Max bytes of commits queued on the user worker. */ + @VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB + + private Commits() {} + + public static WeightedSemaphore maxCommitByteSemaphore() { + return WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index d092ebf53fc1..20b95b0661d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -42,7 +42,6 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class); private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private final Consumer commitWorkFn; private final WeightedBoundedQueue commitQueue; @@ -53,11 +52,9 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private StreamingApplianceWorkCommitter( Consumer commitWorkFn, Consumer onCommitComplete) { this.commitWorkFn = commitWorkFn; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore()); this.commitWorkers = - Executors.newSingleThreadScheduledExecutor( + 
Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MAX_PRIORITY) @@ -73,10 +70,9 @@ public static StreamingApplianceWorkCommitter create( } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { if (!commitWorkers.isShutdown()) { - commitWorkers.submit(this::commitLoop); + commitWorkers.execute(this::commitLoop); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index bf1007bc4bfb..85fa1d67c6c3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -28,6 +28,7 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; @@ -46,7 +47,6 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); private static final int TARGET_COMMIT_BATCH_KEYS = 5; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private static final String NO_BACKEND_WORKER_TOKEN = ""; private final Supplier> commitWorkStreamFactory; @@ 
-61,11 +61,10 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { Supplier> commitWorkStreamFactory, int numCommitSenders, Consumer onCommitComplete, - String backendWorkerToken) { + String backendWorkerToken, + WeightedSemaphore commitByteSemaphore) { this.commitWorkStreamFactory = commitWorkStreamFactory; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore); this.commitSenders = Executors.newFixedThreadPool( numCommitSenders, @@ -90,12 +89,11 @@ public static Builder builder() { } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { Preconditions.checkState( isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start()."); for (int i = 0; i < numCommitSenders; i++) { - commitSenders.submit(this::streamingCommitLoop); + commitSenders.execute(this::streamingCommitLoop); } } @@ -166,6 +164,8 @@ private void streamingCommitLoop() { return; } } + + // take() blocks until a value is available in the commitQueue. 
Preconditions.checkNotNull(initialCommit); if (initialCommit.work().isFailed()) { @@ -258,6 +258,8 @@ public interface Builder { Builder setCommitWorkStreamFactory( Supplier> commitWorkStreamFactory); + Builder setCommitByteSemaphore(WeightedSemaphore commitByteSemaphore); + Builder setNumCommitSenders(int numCommitSenders); Builder setOnCommitComplete(Consumer onCommitComplete); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java index c8e058e7e230..ab12946ad18b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -21,8 +21,8 @@ import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.sdk.annotations.Internal; /** {@link GetDataClient} that fetches data directly from a specific {@link GetDataStream}. 
*/ @@ -61,7 +61,7 @@ public Windmill.KeyedGetDataResponse getStateData( String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException { try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { return getDataStream.requestKeyedData(computationId, request); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { throw new WorkItemCancelledException(request.getShardingKey()); } catch (Exception e) { throw new GetDataException( @@ -86,7 +86,7 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) sideInputGetDataStreamFactory.apply(request.getDataId().getTag()); try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return sideInputGetDataStream.requestGlobalData(request); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { throw new WorkItemCancelledException(e); } catch (Exception e) { throw new GetDataException( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java index 98545a429461..b15f73645dee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java @@ -134,6 +134,12 @@ public void close() throws IOException { stream.close(); } + static class InvalidInputStreamStateException extends IllegalStateException { + public InvalidInputStreamStateException() { + super("Got poison pill or timeout but stream is not done."); + } + } + 
@SuppressWarnings("NullableProblems") private class InputStreamEnumeration implements Enumeration { // The first stream is eagerly read on SequenceInputStream creation. For this reason @@ -159,7 +165,7 @@ public boolean hasMoreElements() { if (complete.get()) { return false; } - throw new IllegalStateException("Got poison pill or timeout but stream is not done."); + throw new InvalidInputStreamStateException(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new CancellationException(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 053843a8af25..2dd069b9c443 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -19,14 +19,19 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import com.google.auto.value.AutoValue; import java.io.PrintWriter; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; @@ -35,22 +40,25 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcCommitWorkStream +final class GrpcCommitWorkStream extends AbstractWindmillStream implements CommitWorkStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStream.class); private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE; - private final Map pending; + private final ConcurrentMap pending; private final AtomicLong idGenerator; private final JobHeader jobHeader; private final ThrottleTimer commitWorkThrottleTimer; @@ -69,6 +77,7 @@ private GrpcCommitWorkStream( AtomicLong idGenerator, int streamingRpcBatchLimit) { super( + LOG, "CommitWorkStream", startCommitWorkRpcFn, backoff, @@ -83,7 +92,7 @@ private GrpcCommitWorkStream( this.streamingRpcBatchLimit = streamingRpcBatchLimit; } - public static GrpcCommitWorkStream create( + static GrpcCommitWorkStream create( String backendWorkerToken, Function, StreamObserver> startCommitWorkRpcFn, 
@@ -95,20 +104,17 @@ public static GrpcCommitWorkStream create( JobHeader jobHeader, AtomicLong idGenerator, int streamingRpcBatchLimit) { - GrpcCommitWorkStream commitWorkStream = - new GrpcCommitWorkStream( - backendWorkerToken, - startCommitWorkRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - commitWorkThrottleTimer, - jobHeader, - idGenerator, - streamingRpcBatchLimit); - commitWorkStream.startStream(); - return commitWorkStream; + return new GrpcCommitWorkStream( + backendWorkerToken, + startCommitWorkRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit); } @Override @@ -117,8 +123,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - protected synchronized void onNewStream() { - send(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); + protected synchronized void onNewStream() throws WindmillStreamShutdownException { + trySend(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); try (Batcher resendBatcher = new Batcher()) { for (Map.Entry entry : pending.entrySet()) { if (!resendBatcher.canAccept(entry.getValue().getBytes())) { @@ -144,11 +150,11 @@ protected boolean hasPendingRequests() { } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws WindmillStreamShutdownException { if (hasPendingRequests()) { StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder(); builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID); - send(builder.build()); + trySend(builder.build()); } } @@ -156,29 +162,49 @@ public void sendHealthCheck() { protected void onResponse(StreamingCommitResponse response) { commitWorkThrottleTimer.stop(); - RuntimeException finalException = null; + CommitCompletionFailureHandler failureHandler = new CommitCompletionFailureHandler(); for (int i = 0; i < 
response.getRequestIdCount(); ++i) { long requestId = response.getRequestId(i); if (requestId == HEARTBEAT_REQUEST_ID) { continue; } - PendingRequest done = pending.remove(requestId); - if (done == null) { - LOG.error("Got unknown commit request ID: {}", requestId); + + // From windmill.proto: Indices must line up with the request_id field, but trailing OKs may + // be omitted. + CommitStatus commitStatus = + i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK; + + @Nullable PendingRequest pendingRequest = pending.remove(requestId); + if (pendingRequest == null) { + synchronized (this) { + if (!isShutdown) { + // Missing responses are expected after shutdown() because it removes them. + LOG.error("Got unknown commit request ID: {}", requestId); + } + } } else { try { - done.onDone.accept( - (i < response.getStatusCount()) ? response.getStatus(i) : CommitStatus.OK); + pendingRequest.completeWithStatus(commitStatus); } catch (RuntimeException e) { // Catch possible exceptions to ensure that an exception for one commit does not prevent - // other commits from being processed. + // other commits from being processed. Aggregate all the failures to throw after + // processing the response if they exist. 
LOG.warn("Exception while processing commit response.", e); - finalException = e; + failureHandler.addError(commitStatus, e); } } } - if (finalException != null) { - throw finalException; + + failureHandler.throwIfNonEmpty(); + } + + @Override + protected void shutdownInternal() { + Iterator pendingRequests = pending.values().iterator(); + while (pendingRequests.hasNext()) { + PendingRequest pendingRequest = pendingRequests.next(); + pendingRequest.abort(); + pendingRequests.remove(); } } @@ -187,13 +213,15 @@ protected void startThrottleTimer() { commitWorkThrottleTimer.start(); } - private void flushInternal(Map requests) { + private void flushInternal(Map requests) + throws WindmillStreamShutdownException { if (requests.isEmpty()) { return; } + if (requests.size() == 1) { Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().request.getSerializedSize() + if (elem.getValue().request().getSerializedSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { issueMultiChunkRequest(elem.getKey(), elem.getValue()); } else { @@ -204,100 +232,171 @@ private void flushInternal(Map requests) { } } - private void issueSingleRequest(final long id, PendingRequest pendingRequest) { + private void issueSingleRequest(long id, PendingRequest pendingRequest) + throws WindmillStreamShutdownException { StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); requestBuilder .addCommitChunkBuilder() - .setComputationId(pendingRequest.computation) + .setComputationId(pendingRequest.computationId()) .setRequestId(id) - .setShardingKey(pendingRequest.request.getShardingKey()) - .setSerializedWorkItemCommit(pendingRequest.request.toByteString()); + .setShardingKey(pendingRequest.shardingKey()) + .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { - pending.put(id, pendingRequest); - try { - send(chunk); - } catch (IllegalStateException e) 
{ - // Stream was broken, request will be retried when stream is reopened. + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; } + trySend(chunk); } } - private void issueBatchedRequest(Map requests) { + private void issueBatchedRequest(Map requests) + throws WindmillStreamShutdownException { StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); String lastComputation = null; for (Map.Entry entry : requests.entrySet()) { PendingRequest request = entry.getValue(); StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.computation)) { - chunkBuilder.setComputationId(request.computation); - lastComputation = request.computation; + if (lastComputation == null || !lastComputation.equals(request.computationId())) { + chunkBuilder.setComputationId(request.computationId()); + lastComputation = request.computationId(); } - chunkBuilder.setRequestId(entry.getKey()); - chunkBuilder.setShardingKey(request.request.getShardingKey()); - chunkBuilder.setSerializedWorkItemCommit(request.request.toByteString()); + chunkBuilder + .setRequestId(entry.getKey()) + .setShardingKey(request.shardingKey()) + .setSerializedWorkItemCommit(request.serializedCommit()); } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { - pending.putAll(requests); - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
+ if (!prepareForSend(requests)) { + requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); + return; } + trySend(request); } } - private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { - checkNotNull(pendingRequest.computation); - final ByteString serializedCommit = pendingRequest.request.toByteString(); - + private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) + throws WindmillStreamShutdownException { + checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); + ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { - pending.put(id, pendingRequest); + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + for (int i = 0; i < serializedCommit.size(); i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - StreamingCommitRequestChunk.Builder chunkBuilder = StreamingCommitRequestChunk.newBuilder() .setRequestId(id) .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computation) - .setShardingKey(pendingRequest.request.getShardingKey()); + .setComputationId(pendingRequest.computationId()) + .setShardingKey(pendingRequest.shardingKey()); int remaining = serializedCommit.size() - end; if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); } - StreamingCommitWorkRequest requestChunk = StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. + + if (!trySend(requestChunk)) { + // The stream broke, don't try to send the rest of the chunks here. break; } } } } - private static class PendingRequest { + /** Returns true if prepare for send succeeded. 
*/ + private synchronized boolean prepareForSend(long id, PendingRequest request) { + if (!isShutdown) { + pending.put(id, request); + return true; + } + return false; + } + + /** Returns true if prepare for send succeeded. */ + private synchronized boolean prepareForSend(Map requests) { + if (!isShutdown) { + pending.putAll(requests); + return true; + } + return false; + } + + @AutoValue + abstract static class PendingRequest { + + private static PendingRequest create( + String computationId, WorkItemCommitRequest request, Consumer onDone) { + return new AutoValue_GrpcCommitWorkStream_PendingRequest(computationId, request, onDone); + } + + abstract String computationId(); + + abstract WorkItemCommitRequest request(); + + abstract Consumer onDone(); + + private long getBytes() { + return (long) request().getSerializedSize() + computationId().length(); + } + + private ByteString serializedCommit() { + return request().toByteString(); + } + + private void completeWithStatus(CommitStatus commitStatus) { + onDone().accept(commitStatus); + } + + private long shardingKey() { + return request().getShardingKey(); + } + + private void abort() { + completeWithStatus(CommitStatus.ABORTED); + } + } + + private static class CommitCompletionException extends RuntimeException { + private CommitCompletionException(String message) { + super(message); + } + } + + private static class CommitCompletionFailureHandler { + private static final int MAX_PRINTABLE_ERRORS = 10; + private final Map>, Integer> errorCounter; + private final EvictingQueue detailedErrors; - private final String computation; - private final WorkItemCommitRequest request; - private final Consumer onDone; + private CommitCompletionFailureHandler() { + this.errorCounter = new HashMap<>(); + this.detailedErrors = EvictingQueue.create(MAX_PRINTABLE_ERRORS); + } - PendingRequest( - String computation, WorkItemCommitRequest request, Consumer onDone) { - this.computation = computation; - this.request = request; - 
this.onDone = onDone; + private void addError(CommitStatus commitStatus, Throwable error) { + errorCounter.compute( + Pair.of(commitStatus, error.getClass()), + (ignored, current) -> current == null ? 1 : current + 1); + detailedErrors.add(error); } - long getBytes() { - return (long) request.getSerializedSize() + computation.length(); + private void throwIfNonEmpty() { + if (!errorCounter.isEmpty()) { + String errorMessage = + String.format( + "Exception while processing commit response. ErrorCounter: %s; Details: %s", + errorCounter, detailedErrors); + throw new CommitCompletionException(errorMessage); + } } } @@ -317,7 +416,8 @@ public boolean commitWorkItem( if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { return false; } - PendingRequest request = new PendingRequest(computation, commitRequest, onDone); + + PendingRequest request = PendingRequest.create(computation, commitRequest, onDone); add(idGenerator.incrementAndGet(), request); return true; } @@ -325,13 +425,18 @@ public boolean commitWorkItem( /** Flushes any pending work items to the wire. 
*/ @Override public void flush() { - flushInternal(queue); - queuedBytes = 0; - queue.clear(); + try { + flushInternal(queue); + } catch (WindmillStreamShutdownException e) { + queue.forEach((ignored, request) -> request.abort()); + } finally { + queuedBytes = 0; + queue.clear(); + } } void add(long id, PendingRequest request) { - assert (canAccept(request.getBytes())); + Preconditions.checkState(canAccept(request.getBytes())); queuedBytes += request.getBytes(); queue.put(id, request); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 45d010d7cfac..27f457900e6c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -23,7 +23,8 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.concurrent.GuardedBy; +import net.jcip.annotations.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -33,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GetWorkResponseChunkAssembler.AssembledWorkItem; @@ -44,8 +46,8 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of {@link GetWorkStream} that passes along a specific {@link @@ -55,9 +57,10 @@ * these direct streams are used to facilitate these RPC calls to specific backend workers. */ @Internal -public final class GrpcDirectGetWorkStream +final class GrpcDirectGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class); private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST = StreamingGetWorkRequest.newBuilder() .setRequestExtension( @@ -67,15 +70,14 @@ public final class GrpcDirectGetWorkStream .build()) .build(); - private final AtomicReference inFlightBudget; - private final AtomicReference nextBudgetAdjustment; - private final AtomicReference pendingResponseBudget; - private final GetWorkRequest request; + private final GetWorkBudgetTracker budgetTracker; + private final GetWorkRequest requestHeader; private final WorkItemScheduler workItemScheduler; private final ThrottleTimer getWorkThrottleTimer; - private final Supplier heartbeatSender; - private final Supplier workCommitter; - private final Supplier getDataClient; + private final HeartbeatSender heartbeatSender; + private final WorkCommitter workCommitter; + private final GetDataClient getDataClient; + 
private final AtomicReference lastRequest; /** * Map of stream IDs to their buffers. Used to aggregate streaming gRPC response chunks as they @@ -92,17 +94,18 @@ private GrpcDirectGetWorkStream( StreamObserver, StreamObserver> startGetWorkRpcFn, - GetWorkRequest request, + GetWorkRequest requestHeader, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { super( + LOG, "GetWorkStream", startGetWorkRpcFn, backoff, @@ -110,19 +113,23 @@ private GrpcDirectGetWorkStream( streamRegistry, logEveryNStreamFailures, backendWorkerToken); - this.request = request; + this.requestHeader = requestHeader; this.getWorkThrottleTimer = getWorkThrottleTimer; this.workItemScheduler = workItemScheduler; this.workItemAssemblers = new ConcurrentHashMap<>(); - this.heartbeatSender = Suppliers.memoize(heartbeatSender::get); - this.workCommitter = Suppliers.memoize(workCommitter::get); - this.getDataClient = Suppliers.memoize(getDataClient::get); - this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); - this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); - this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + this.heartbeatSender = heartbeatSender; + this.workCommitter = workCommitter; + this.getDataClient = getDataClient; + this.lastRequest = new AtomicReference<>(); + this.budgetTracker = + new GetWorkBudgetTracker( + GetWorkBudget.builder() + .setItems(requestHeader.getMaxItems()) + .setBytes(requestHeader.getMaxBytes()) + .build()); } - public static GrpcDirectGetWorkStream create( + static GrpcDirectGetWorkStream create( String backendWorkerToken, Function< StreamObserver, @@ -134,26 +141,23 @@ public static 
GrpcDirectGetWorkStream create( Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { - GrpcDirectGetWorkStream getWorkStream = - new GrpcDirectGetWorkStream( - backendWorkerToken, - startGetWorkRpcFn, - request, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getWorkThrottleTimer, - heartbeatSender, - getDataClient, - workCommitter, - workItemScheduler); - getWorkStream.startStream(); - return getWorkStream; + return new GrpcDirectGetWorkStream( + backendWorkerToken, + startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + heartbeatSender, + getDataClient, + workCommitter, + workItemScheduler); } private static Watermarks createWatermarks( @@ -165,46 +169,50 @@ private static Watermarks createWatermarks( .build(); } - private void sendRequestExtension(GetWorkBudget adjustment) { - inFlightBudget.getAndUpdate(budget -> budget.apply(adjustment)); - StreamingGetWorkRequest extension = - StreamingGetWorkRequest.newBuilder() - .setRequestExtension( - Windmill.StreamingGetWorkRequestExtension.newBuilder() - .setMaxItems(adjustment.items()) - .setMaxBytes(adjustment.bytes())) - .build(); - - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + /** + * @implNote Do not lock/synchronize here due to this running on grpc serial executor for message + * which can deadlock since we send on the stream beneath the synchronization. {@link + * AbstractWindmillStream#trySend(Object)} is synchronized so the sends are already guarded. 
+ */ + private void maybeSendRequestExtension(GetWorkBudget extension) { + if (extension.items() > 0 || extension.bytes() > 0) { + executeSafely( + () -> { + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(extension.items()) + .setMaxBytes(extension.bytes())) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(extension); + try { + trySend(request); + } catch (WindmillStreamShutdownException e) { + // Stream was closed. + } + }); + } } @Override - protected synchronized void onNewStream() { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { workItemAssemblers.clear(); - // Add the current in-flight budget to the next adjustment. Only positive values are allowed - // here - // with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative values. - GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get()); - inFlightBudget.set(budgetAdjustment); - send( + budgetTracker.reset(); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); + StreamingGetWorkRequest request = StreamingGetWorkRequest.newBuilder() .setRequest( - request + requestHeader .toBuilder() - .setMaxBytes(budgetAdjustment.bytes()) - .setMaxItems(budgetAdjustment.items())) - .build()); - - // We just sent the budget, reset it. - nextBudgetAdjustment.set(GetWorkBudget.noBudget()); + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build()) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(initialGetWorkBudget); + trySend(request); } @Override @@ -216,15 +224,19 @@ protected boolean hasPendingRequests() { public void appendSpecificHtml(PrintWriter writer) { // Number of buffers is same as distinct workers that sent work on this stream. 
writer.format( - "GetWorkStream: %d buffers, %s inflight budget allowed.", - workItemAssemblers.size(), inFlightBudget.get()); + "GetWorkStream: %d buffers, " + "last sent request: %s; ", + workItemAssemblers.size(), lastRequest.get()); + writer.print(budgetTracker.debugString()); } @Override - public void sendHealthCheck() { - send(HEALTH_CHECK_REQUEST); + public void sendHealthCheck() throws WindmillStreamShutdownException { + trySend(HEALTH_CHECK_REQUEST); } + @Override + protected void shutdownInternal() {} + @Override protected void onResponse(StreamingGetWorkResponseChunk chunk) { getWorkThrottleTimer.stop(); @@ -235,26 +247,22 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) { } private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { - // Record the fact that there are now fewer outstanding messages and bytes on the stream. - inFlightBudget.updateAndGet(budget -> budget.subtract(1, assembledWorkItem.bufferedSize())); WorkItem workItem = assembledWorkItem.workItem(); GetWorkResponseChunkAssembler.ComputationMetadata metadata = assembledWorkItem.computationMetadata(); - pendingResponseBudget.getAndUpdate(budget -> budget.apply(1, workItem.getSerializedSize())); - try { - workItemScheduler.scheduleWork( - workItem, - createWatermarks(workItem, Preconditions.checkNotNull(metadata)), - createProcessingContext(Preconditions.checkNotNull(metadata.computationId())), - assembledWorkItem.latencyAttributions()); - } finally { - pendingResponseBudget.getAndUpdate(budget -> budget.apply(-1, -workItem.getSerializedSize())); - } + workItemScheduler.scheduleWork( + workItem, + createWatermarks(workItem, metadata), + createProcessingContext(metadata.computationId()), + assembledWorkItem.latencyAttributions()); + budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize()); + GetWorkBudget extension = budgetTracker.computeBudgetExtension(); + maybeSendRequestExtension(extension); } private Work.ProcessingContext 
createProcessingContext(String computationId) { return Work.createProcessingContext( - computationId, getDataClient.get(), workCommitter.get()::commit, heartbeatSender.get()); + computationId, getDataClient, workCommitter::commit, heartbeatSender, backendWorkerToken()); } @Override @@ -263,25 +271,102 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - GetWorkBudget adjustment = - nextBudgetAdjustment - // Get the current value, and reset the nextBudgetAdjustment. This will be set again - // when adjustBudget is called. - .getAndUpdate(unused -> GetWorkBudget.noBudget()) - .apply(itemsDelta, bytesDelta); - sendRequestExtension(adjustment); + public void setBudget(GetWorkBudget newBudget) { + GetWorkBudget extension = budgetTracker.consumeAndComputeBudgetUpdate(newBudget); + maybeSendRequestExtension(extension); } - @Override - public GetWorkBudget remainingBudget() { - // Snapshot the current budgets. - GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); - GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); - GetWorkBudget currentInflightBudget = inFlightBudget.get(); - - return currentPendingResponseBudget - .apply(currentNextBudgetAdjustment) - .apply(currentInflightBudget); + /** + * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request + * extensions. 
+ */ + @ThreadSafe + private static final class GetWorkBudgetTracker { + + @GuardedBy("GetWorkBudgetTracker.this") + private GetWorkBudget maxGetWorkBudget; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsReceived = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesReceived = 0; + + private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) { + this.maxGetWorkBudget = maxGetWorkBudget; + } + + private synchronized void reset() { + itemsRequested = 0; + bytesRequested = 0; + itemsReceived = 0; + bytesReceived = 0; + } + + private synchronized String debugString() { + return String.format( + "max budget: %s; " + + "in-flight budget: %s; " + + "total budget requested: %s; " + + "total budget received: %s.", + maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget()); + } + + /** Consumes the new budget and computes an extension based on the new budget. */ + private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) { + maxGetWorkBudget = newBudget; + return computeBudgetExtension(); + } + + private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) { + itemsRequested += budgetRequested.items(); + bytesRequested += budgetRequested.bytes(); + } + + private synchronized void recordBudgetReceived(long returnedBudget) { + itemsReceived++; + bytesReceived += returnedBudget; + } + + /** + * If the outstanding items or bytes limit has gotten too low, top both off with a + * GetWorkExtension. The goal is to keep the limits relatively close to their maximum values + * without sending too many extension requests. + */ + private synchronized GetWorkBudget computeBudgetExtension() { + // Expected items and bytes can go negative here, since WorkItems returned might be larger + // than the initially requested budget. 
+ long inFlightItems = itemsRequested - itemsReceived; + long inFlightBytes = bytesRequested - bytesReceived; + + // Don't send negative budget extensions. + long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes); + long requestItems = Math.max(0, maxGetWorkBudget.items() - inFlightItems); + + return (inFlightItems > requestItems / 2 && inFlightBytes > requestBytes / 2) + ? GetWorkBudget.noBudget() + : GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build(); + } + + private synchronized GetWorkBudget inFlightBudget() { + return GetWorkBudget.builder() + .setItems(itemsRequested - itemsReceived) + .setBytes(bytesRequested - bytesReceived) + .build(); + } + + private synchronized GetWorkBudget totalRequestedBudget() { + return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build(); + } + + private synchronized GetWorkBudget totalReceivedBudget() { + return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index f96464150d4a..234888831779 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -30,7 +30,6 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import 
org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; @@ -53,8 +52,6 @@ public class GrpcDispatcherClient { private static final Logger LOG = LoggerFactory.getLogger(GrpcDispatcherClient.class); - static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS = - "streaming_engine_use_job_settings_for_isolated_channels"; private final CountDownLatch onInitializedEndpoints; /** @@ -80,18 +77,12 @@ private GrpcDispatcherClient( DispatcherStubs initialDispatcherStubs, Random rand) { this.windmillStubFactoryFactory = windmillStubFactoryFactory; - if (DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)) { - if (options.getUseWindmillIsolatedChannels() != null) { - this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); - this.reactToIsolatedChannelsJobSetting = false; - } else { - this.useIsolatedChannels.set(false); - this.reactToIsolatedChannelsJobSetting = true; - } - } else { - this.useIsolatedChannels.set(Boolean.TRUE.equals(options.getUseWindmillIsolatedChannels())); + if (options.getUseWindmillIsolatedChannels() != null) { + this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); this.reactToIsolatedChannelsJobSetting = false; + } else { + this.useIsolatedChannels.set(false); + this.reactToIsolatedChannelsJobSetting = true; } this.windmillStubFactory.set( windmillStubFactoryFactory.makeWindmillStubFactory(useIsolatedChannels.get())); @@ -159,6 +150,8 @@ public CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStubBl } } + LOG.info("Windmill Service endpoint initialized after {} seconds.", secondsWaited); + ImmutableList windmillMetadataServiceStubs = dispatcherStubs.get().windmillMetadataServiceStubs(); @@ -199,7 +192,7 @@ public void onJobConfig(StreamingGlobalConfig config) { public synchronized void 
consumeWindmillDispatcherEndpoints( ImmutableSet dispatcherEndpoints) { - consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /*forceRecreateStubs=*/ false); + consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /* forceRecreateStubs= */ false); } private synchronized void consumeWindmillDispatcherEndpoints( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 0e9a0c6316ee..b5b49c8ee976 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -31,10 +31,11 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest; @@ -49,6 +50,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; 
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedBatch; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; @@ -59,12 +61,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcGetDataStream +@ThreadSafe +final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); + private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = + StreamingGetDataRequest.newBuilder().build(); + /** @implNote {@link QueuedBatch} objects in the queue are is guarded by {@code this} */ private final Deque batches; + private final Map pending; private final AtomicLong idGenerator; private final ThrottleTimer getDataThrottleTimer; @@ -90,6 +97,7 @@ private GrpcGetDataStream( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { super( + LOG, "GetDataStream", startGetDataRpcFn, backoff, @@ -107,7 +115,7 @@ private GrpcGetDataStream( this.processHeartbeatResponses = processHeartbeatResponses; } - public static GrpcGetDataStream create( + static GrpcGetDataStream create( String backendWorkerToken, Function, StreamObserver> startGetDataRpcFn, @@ -121,32 +129,44 @@ public static GrpcGetDataStream create( int streamingRpcBatchLimit, boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { - GrpcGetDataStream getDataStream = - new GrpcGetDataStream( - backendWorkerToken, - startGetDataRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getDataThrottleTimer, - jobHeader, - idGenerator, - streamingRpcBatchLimit, - sendKeyedGetDataRequests, - processHeartbeatResponses); - getDataStream.startStream(); - return getDataStream; + return new GrpcGetDataStream( + 
backendWorkerToken, + startGetDataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit, + sendKeyedGetDataRequests, + processHeartbeatResponses); + } + + private static WindmillStreamShutdownException shutdownExceptionFor(QueuedBatch batch) { + return new WindmillStreamShutdownException( + "Stream was closed when attempting to send " + batch.requestsCount() + " requests."); + } + + private static WindmillStreamShutdownException shutdownExceptionFor(QueuedRequest request) { + return new WindmillStreamShutdownException( + "Cannot send request=[" + request + "] on closed stream."); + } + + private void sendIgnoringClosed(StreamingGetDataRequest getDataRequest) + throws WindmillStreamShutdownException { + trySend(getDataRequest); } @Override - protected synchronized void onNewStream() { - send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed.get()) { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { + trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); + if (clientClosed) { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. 
- verify(!hasPendingRequests()); + verify(!hasPendingRequests(), "Pending requests not expected if we've half-closed."); } else { for (AppendableInputStream responseStream : pending.values()) { responseStream.cancel(); @@ -160,7 +180,6 @@ protected boolean hasPendingRequests() { } @Override - @SuppressWarnings("dereference.of.nullable") protected void onResponse(StreamingGetDataResponse chunk) { checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); @@ -168,8 +187,15 @@ protected void onResponse(StreamingGetDataResponse chunk) { onHeartbeatResponse(chunk.getComputationHeartbeatResponseList()); for (int i = 0; i < chunk.getRequestIdCount(); ++i) { - AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - verify(responseStream != null, "No pending response stream"); + @Nullable AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); + if (responseStream == null) { + synchronized (this) { + // shutdown()/shutdownInternal() cleans up pending, else we expect a pending + // responseStream for every response. 
+ verify(isShutdown, "No pending response stream"); + } + continue; + } responseStream.append(chunk.getSerializedResponse(i).newInput()); if (chunk.getRemainingBytesForResponse() == 0) { responseStream.complete(); @@ -187,23 +213,22 @@ private long uniqueId() { } @Override - public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) { + public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) + throws WindmillStreamShutdownException { return issueRequest( QueuedRequest.forComputation(uniqueId(), computation, request), KeyedGetDataResponse::parseFrom); } @Override - public GlobalData requestGlobalData(GlobalDataRequest request) { + public GlobalData requestGlobalData(GlobalDataRequest request) + throws WindmillStreamShutdownException { return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); } @Override - public void refreshActiveWork(Map> heartbeats) { - if (isShutdown()) { - throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); - } - + public void refreshActiveWork(Map> heartbeats) + throws WindmillStreamShutdownException { StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); if (sendKeyedGetDataRequests) { long builderBytes = 0; @@ -214,7 +239,7 @@ public void refreshActiveWork(Map> heartbea if (builderBytes > 0 && (builderBytes + bytes > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE || builder.getRequestIdCount() >= streamingRpcBatchLimit)) { - send(builder.build()); + sendIgnoringClosed(builder.build()); builderBytes = 0; builder.clear(); } @@ -233,7 +258,7 @@ public void refreshActiveWork(Map> heartbea } if (builderBytes > 0) { - send(builder.build()); + sendIgnoringClosed(builder.build()); } } else { // No translation necessary, but we must still respect `RPC_STREAM_CHUNK_SIZE`. 
@@ -248,7 +273,7 @@ public void refreshActiveWork(Map> heartbea if (computationHeartbeatBuilder.getHeartbeatRequestsCount() > 0) { builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build()); } - send(builder.build()); + sendIgnoringClosed(builder.build()); builderBytes = 0; builder.clear(); computationHeartbeatBuilder.clear().setComputationId(entry.getKey()); @@ -260,7 +285,7 @@ public void refreshActiveWork(Map> heartbea } if (builderBytes > 0) { - send(builder.build()); + sendIgnoringClosed(builder.build()); } } } @@ -271,12 +296,26 @@ public void onHeartbeatResponse(List resp } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws WindmillStreamShutdownException { if (hasPendingRequests()) { - send(StreamingGetDataRequest.newBuilder().build()); + trySend(HEALTH_CHECK_REQUEST); } } + @Override + protected synchronized void shutdownInternal() { + // Stream has been explicitly closed. Drain pending input streams and request batches. + // Future calls to send RPCs will fail. + pending.values().forEach(AppendableInputStream::cancel); + pending.clear(); + batches.forEach( + batch -> { + batch.markFinalized(); + batch.notifyFailed(); + }); + batches.clear(); + } + @Override public void appendSpecificHtml(PrintWriter writer) { writer.format( @@ -301,20 +340,23 @@ public void appendSpecificHtml(PrintWriter writer) { writer.append("]"); } - private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) { + private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) + throws WindmillStreamShutdownException { while (true) { request.resetResponseStream(); try { queueRequestAndWait(request); return parseFn.parse(request.getResponseStream()); - } catch (CancellationException e) { - // Retry issuing the request since the response stream was cancelled. 
- continue; + } catch (AppendableInputStream.InvalidInputStreamStateException | CancellationException e) { + throwIfShutdown(request, e); + if (!(e instanceof CancellationException)) { + throw e; + } } catch (IOException e) { LOG.error("Parsing GetData response failed: ", e); - continue; } catch (InterruptedException e) { Thread.currentThread().interrupt(); + throwIfShutdown(request, e); throw new RuntimeException(e); } finally { pending.remove(request.id()); @@ -322,18 +364,32 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn= streamingRpcBatchLimit + || batch.requestsCount() >= streamingRpcBatchLimit || batch.byteSize() + request.byteSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { if (batch != null) { - waitForSendLatch = batch.getLatch(); + prevBatch = batch; } batch = new QueuedBatch(); batches.addLast(batch); @@ -342,64 +398,80 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept batch.addRequest(request); } if (responsibleForSend) { - if (waitForSendLatch == null) { + if (prevBatch == null) { // If there was not a previous batch wait a little while to improve // batching. - Thread.sleep(1); + sleeper.sleep(1); } else { - waitForSendLatch.await(); + prevBatch.waitForSendOrFailNotification(); } // Finalize the batch so that no additional requests will be added. Leave the batch in the // queue so that a subsequent batch will wait for its completion. - synchronized (batches) { - verify(batch == batches.peekFirst()); + synchronized (this) { + if (isShutdown) { + throw shutdownExceptionFor(batch); + } + + verify(batch == batches.peekFirst(), "GetDataStream request batch removed before send()."); batch.markFinalized(); } - sendBatch(batch.requests()); - synchronized (batches) { - verify(batch == batches.pollFirst()); + trySendBatch(batch); + } else { + // Wait for this batch to be sent before parsing the response. 
+ batch.waitForSendOrFailNotification(); + } + } + + void trySendBatch(QueuedBatch batch) throws WindmillStreamShutdownException { + try { + sendBatch(batch); + synchronized (this) { + if (isShutdown) { + throw shutdownExceptionFor(batch); + } + + verify( + batch == batches.pollFirst(), + "Sent GetDataStream request batch removed before send() was complete."); } // Notify all waiters with requests in this batch as well as the sender // of the next batch (if one exists). - batch.countDown(); - } else { - // Wait for this batch to be sent before parsing the response. - batch.await(); + batch.notifySent(); + } catch (Exception e) { + // Free waiters if the send() failed. + batch.notifyFailed(); + // Propagate the exception to the calling thread. + throw e; } } - @SuppressWarnings("NullableProblems") - private void sendBatch(List requests) { - StreamingGetDataRequest batchedRequest = flushToBatch(requests); + private void sendBatch(QueuedBatch batch) throws WindmillStreamShutdownException { + if (batch.isEmpty()) { + return; + } + + // Synchronization of pending inserts is necessary with send to ensure duplicates are not + // sent on stream reconnect. synchronized (this) { - // Synchronization of pending inserts is necessary with send to ensure duplicates are not - // sent on stream reconnect. - for (QueuedRequest request : requests) { + if (isShutdown) { + throw shutdownExceptionFor(batch); + } + + for (QueuedRequest request : batch.requestsReadOnly()) { // Map#put returns null if there was no previous mapping for the key, meaning we have not // seen it before. - verify(pending.put(request.id(), request.getResponseStream()) == null); + verify( + pending.put(request.id(), request.getResponseStream()) == null, + "Request already sent."); } - try { - send(batchedRequest); - } catch (IllegalStateException e) { + + if (!trySend(batch.asGetDataRequest())) { // The stream broke before this call went through; onNewStream will retry the fetch. 
- LOG.warn("GetData stream broke before call started.", e); + LOG.warn("GetData stream broke before call started."); } } } - @SuppressWarnings("argument") - private StreamingGetDataRequest flushToBatch(List requests) { - // Put all global data requests first because there is only a single repeated field for - // request ids and the initial ids correspond to global data requests if they are present. - requests.sort(QueuedRequest.globalRequestsFirst()); - StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); - for (QueuedRequest request : requests) { - request.addToStreamingGetDataRequest(builder); - } - return builder.build(); - } - @FunctionalInterface private interface ParseFn { ResponseT parse(InputStream input) throws IOException; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java index cda9537127d9..ef7f5b20bb07 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java @@ -17,20 +17,35 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; + import com.google.auto.value.AutoOneOf; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.stream.Stream; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Utility data classes for {@link GrpcGetDataStream}. */ final class GrpcGetDataStreamRequests { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStreamRequests.class); + private static final int STREAM_CANCELLED_ERROR_LOG_LIMIT = 3; + private GrpcGetDataStreamRequests() {} + private static String debugFormat(long value) { + return String.format("%016x", value); + } + static class QueuedRequest { private final long id; private final ComputationOrGlobalDataRequest dataRequest; @@ -81,6 +96,10 @@ void resetResponseStream() { this.responseStream = new AppendableInputStream(); } + public ComputationOrGlobalDataRequest getDataRequest() { + return dataRequest; + } + void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder builder) { builder.addRequestId(id); if (dataRequest.isForComputation()) { @@ -89,20 +108,51 @@ void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder build builder.addGlobalDataRequest(dataRequest.global()); } } + + @Override + public final String toString() { + return "QueuedRequest{" + "dataRequest=" + dataRequest + ", id=" + id + '}'; + } } + /** + * Represents a batch of queued requests. Methods are not thread-safe unless commented otherwise. 
+ */ static class QueuedBatch { private final List requests = new ArrayList<>(); private final CountDownLatch sent = new CountDownLatch(1); private long byteSize = 0; - private boolean finalized = false; + private volatile boolean finalized = false; + private volatile boolean failed = false; - CountDownLatch getLatch() { - return sent; + /** Returns a read-only view of requests. */ + List requestsReadOnly() { + return Collections.unmodifiableList(requests); } - List requests() { - return requests; + /** + * Converts the batch to a {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest}. + */ + Windmill.StreamingGetDataRequest asGetDataRequest() { + Windmill.StreamingGetDataRequest.Builder builder = + Windmill.StreamingGetDataRequest.newBuilder(); + + requests.stream() + // Put all global data requests first because there is only a single repeated field for + // request ids and the initial ids correspond to global data requests if they are present. + .sorted(QueuedRequest.globalRequestsFirst()) + .forEach(request -> request.addToStreamingGetDataRequest(builder)); + + return builder.build(); + } + + boolean isEmpty() { + return requests.isEmpty(); + } + + int requestsCount() { + return requests.size(); } long byteSize() { @@ -117,17 +167,83 @@ void markFinalized() { finalized = true; } + /** Adds a request to the batch. */ void addRequest(QueuedRequest request) { requests.add(request); byteSize += request.byteSize(); } - void countDown() { + /** + * Let waiting for threads know that the request has been successfully sent. + * + * @implNote Thread safe. + */ + void notifySent() { + sent.countDown(); + } + + /** + * Let waiting for threads know that a failure occurred. + * + * @implNote Thread safe. 
+ */ + void notifyFailed() { + failed = true; sent.countDown(); } - void await() throws InterruptedException { + /** + * Block until notified of a successful send via {@link #notifySent()} or a non-retryable + * failure via {@link #notifyFailed()}. On failure, throw an exception for waiters. + * + * @implNote Thread safe. + */ + void waitForSendOrFailNotification() + throws InterruptedException, WindmillStreamShutdownException { sent.await(); + if (failed) { + ImmutableList cancelledRequests = createStreamCancelledErrorMessages(); + if (!cancelledRequests.isEmpty()) { + LOG.error("Requests failed for the following batches: {}", cancelledRequests); + throw new WindmillStreamShutdownException( + "Requests failed for batch containing " + + String.join(", ", cancelledRequests) + + " ... requests. This is most likely due to the stream being explicitly closed" + + " which happens when the work is marked as invalid on the streaming" + + " backend when key ranges shuffle around. This is transient and corresponding" + + " work will eventually be retried."); + } + + throw new WindmillStreamShutdownException("Stream was shutdown while waiting for send."); + } + } + + private ImmutableList createStreamCancelledErrorMessages() { + return requests.stream() + .flatMap( + request -> { + switch (request.getDataRequest().getKind()) { + case GLOBAL: + return Stream.of("GetSideInput=" + request.getDataRequest().global()); + case COMPUTATION: + return request.getDataRequest().computation().getRequestsList().stream() + .map( + keyedRequest -> + "KeyedGetState=[" + + "shardingKey=" + + debugFormat(keyedRequest.getShardingKey()) + + "cacheToken=" + + debugFormat(keyedRequest.getCacheToken()) + + "workToken" + + debugFormat(keyedRequest.getWorkToken()) + + "]"); + default: + // Will never happen switch is exhaustive. 
+ throw new IllegalStateException(); + } + }) + .limit(STREAM_CANCELLED_ERROR_LOG_LIMIT) + .collect(toImmutableList()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 09ecbf3f3051..fcfefab71c8c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -29,6 +29,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GetWorkResponseChunkAssembler.AssembledWorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; @@ -36,11 +37,15 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; final class GrpcGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkStream.class); + private final GetWorkRequest request; private final WorkItemReceiver receiver; private final 
ThrottleTimer getWorkThrottleTimer; @@ -62,6 +67,7 @@ private GrpcGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { super( + LOG, "GetWorkStream", startGetWorkRpcFn, backoff, @@ -90,19 +96,16 @@ public static GrpcGetWorkStream create( int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { - GrpcGetWorkStream getWorkStream = - new GrpcGetWorkStream( - backendWorkerToken, - startGetWorkRpcFn, - request, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getWorkThrottleTimer, - receiver); - getWorkStream.startStream(); - return getWorkStream; + return new GrpcGetWorkStream( + backendWorkerToken, + startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + receiver); } private void sendRequestExtension(long moreItems, long moreBytes) { @@ -114,25 +117,27 @@ private void sendRequestExtension(long moreItems, long moreBytes) { .setMaxBytes(moreBytes)) .build(); - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + executeSafely( + () -> { + try { + trySend(extension); + } catch (WindmillStreamShutdownException e) { + // Stream was closed. 
+ } + }); } @Override - protected synchronized void onNewStream() { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { workItemAssemblers.clear(); inflightMessages.set(request.getMaxItems()); inflightBytes.set(request.getMaxBytes()); - send(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); + trySend(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); } + @Override + protected void shutdownInternal() {} + @Override protected boolean hasPendingRequests() { return false; @@ -147,8 +152,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - public void sendHealthCheck() { - send( + public void sendHealthCheck() throws WindmillStreamShutdownException { + trySend( StreamingGetWorkRequest.newBuilder() .setRequestExtension( StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build()) @@ -194,15 +199,7 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op } - - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setBytes(request.getMaxBytes() - inflightBytes.get()) - .setItems(request.getMaxItems() - inflightMessages.get()) - .build(); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index 44e21a9b18ed..4ce2f651f0b7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -29,6 +29,7 @@ 
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; @@ -47,9 +48,6 @@ public final class GrpcGetWorkerMetadataStream private final Consumer serverMappingConsumer; private final Object metadataLock; - @GuardedBy("metadataLock") - private long metadataVersion; - @GuardedBy("metadataLock") private WorkerMetadataResponse latestResponse; @@ -61,10 +59,10 @@ private GrpcGetWorkerMetadataStream( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - long metadataVersion, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingConsumer) { super( + LOG, "GetWorkerMetadataStream", startGetWorkerMetadataRpcFn, backoff, @@ -73,7 +71,6 @@ private GrpcGetWorkerMetadataStream( logEveryNStreamFailures, ""); this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build(); - this.metadataVersion = metadataVersion; this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer; this.serverMappingConsumer = serverMappingConsumer; this.latestResponse = WorkerMetadataResponse.getDefaultInstance(); @@ -88,23 +85,17 @@ public static GrpcGetWorkerMetadataStream create( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - int metadataVersion, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingUpdater) { - GrpcGetWorkerMetadataStream getWorkerMetadataStream = - new GrpcGetWorkerMetadataStream( - startGetWorkerMetadataRpcFn, - backoff, - streamObserverFactory, - 
streamRegistry, - logEveryNStreamFailures, - jobHeader, - metadataVersion, - getWorkerMetadataThrottleTimer, - serverMappingUpdater); - LOG.info("Started GetWorkerMetadataStream. {}", getWorkerMetadataStream); - getWorkerMetadataStream.startStream(); - return getWorkerMetadataStream; + return new GrpcGetWorkerMetadataStream( + startGetWorkerMetadataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + jobHeader, + getWorkerMetadataThrottleTimer, + serverMappingUpdater); } /** @@ -118,25 +109,23 @@ protected void onResponse(WorkerMetadataResponse response) { /** * Acquires the {@link #metadataLock} Returns {@link Optional} if the - * metadataVersion in the response is not stale (older or equal to {@link #metadataVersion}), else - * returns empty {@link Optional}. + * metadataVersion in the response is not stale (older or equal to current {@link + * WorkerMetadataResponse#getMetadataVersion()}), else returns empty {@link Optional}. */ private Optional extractWindmillEndpointsFrom( WorkerMetadataResponse response) { synchronized (metadataLock) { - if (response.getMetadataVersion() > this.metadataVersion) { - this.metadataVersion = response.getMetadataVersion(); + if (response.getMetadataVersion() > latestResponse.getMetadataVersion()) { this.latestResponse = response; return Optional.of(WindmillEndpoints.from(response)); } else { // If the currentMetadataVersion is greater than or equal to one in the response, the // response data is stale, and we do not want to do anything. - LOG.info( - "Received WorkerMetadataResponse={}; Received metadata version={}; Current metadata version={}. " + LOG.debug( + "Received metadata version={}; Current metadata version={}. 
" + "Skipping update because received stale metadata", - response, response.getMetadataVersion(), - this.metadataVersion); + latestResponse.getMetadataVersion()); } } @@ -144,10 +133,13 @@ private Optional extractWindmillEndpointsFrom( } @Override - protected synchronized void onNewStream() { - send(workerMetadataRequest); + protected void onNewStream() throws WindmillStreamShutdownException { + trySend(workerMetadataRequest); } + @Override + protected void shutdownInternal() {} + @Override protected boolean hasPendingRequests() { return false; @@ -159,16 +151,16 @@ protected void startThrottleTimer() { } @Override - protected void sendHealthCheck() { - send(HEALTH_CHECK_REQUEST); + protected void sendHealthCheck() throws WindmillStreamShutdownException { + trySend(HEALTH_CHECK_REQUEST); } @Override protected void appendSpecificHtml(PrintWriter writer) { synchronized (metadataLock) { writer.format( - "GetWorkerMetadataStream: version=[%d] , job_header=[%s], latest_response=[%s]", - this.metadataVersion, workerMetadataRequest.getHeader(), this.latestResponse); + "GetWorkerMetadataStream: job_header=[%s], current_metadata=[%s]", + workerMetadataRequest.getHeader(), latestResponse); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java index 310495982679..f35b9b23d091 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java @@ -290,13 +290,13 @@ private ResponseT callWithBackoff(Supplier function) { e.getStatus()); } if (!BackOffUtils.next(Sleeper.DEFAULT, backoff)) { - throw new 
RpcException(e); + throw new WindmillRpcException(e); } } catch (IOException | InterruptedException i) { if (i instanceof InterruptedException) { Thread.currentThread().interrupt(); } - RpcException rpcException = new RpcException(e); + WindmillRpcException rpcException = new WindmillRpcException(e); rpcException.addSuppressed(i); throw rpcException; } @@ -310,7 +310,7 @@ public GetWorkResponse getWork(GetWorkRequest request) { return callWithBackoff(() -> syncApplianceStub.getWork(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("GetWork")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("GetWork")); } @Override @@ -319,7 +319,7 @@ public GetDataResponse getData(GetDataRequest request) { return callWithBackoff(() -> syncApplianceStub.getData(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("GetData")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("GetData")); } @Override @@ -327,32 +327,53 @@ public CommitWorkResponse commitWork(CommitWorkRequest request) { if (syncApplianceStub != null) { return callWithBackoff(() -> syncApplianceStub.commitWork(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("CommitWork")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("CommitWork")); } + /** + * @implNote Returns a {@link GetWorkStream} in the started state (w/ the initial header already + * sent). 
+ */ @Override public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { - return windmillStreamFactory.createGetWorkStream( - dispatcherClient.getWindmillServiceStub(), - GetWorkRequest.newBuilder(request) - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .build(), - throttleTimers.getWorkThrottleTimer(), - receiver); + GetWorkStream getWorkStream = + windmillStreamFactory.createGetWorkStream( + dispatcherClient.getWindmillServiceStub(), + GetWorkRequest.newBuilder(request) + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .build(), + throttleTimers.getWorkThrottleTimer(), + receiver); + getWorkStream.start(); + return getWorkStream; } + /** + * @implNote Returns a {@link GetDataStream} in the started state (w/ the initial header already + * sent). + */ @Override public GetDataStream getDataStream() { - return windmillStreamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), throttleTimers.getDataThrottleTimer()); + GetDataStream getDataStream = + windmillStreamFactory.createGetDataStream( + dispatcherClient.getWindmillServiceStub(), throttleTimers.getDataThrottleTimer()); + getDataStream.start(); + return getDataStream; } + /** + * @implNote Returns a {@link CommitWorkStream} in the started state (w/ the initial header + * already sent). 
+ */ @Override public CommitWorkStream commitWorkStream() { - return windmillStreamFactory.createCommitWorkStream( - dispatcherClient.getWindmillServiceStub(), throttleTimers.commitWorkThrottleTimer()); + CommitWorkStream commitWorkStream = + windmillStreamFactory.createCommitWorkStream( + dispatcherClient.getWindmillServiceStub(), throttleTimers.commitWorkThrottleTimer()); + commitWorkStream.start(); + return commitWorkStream; } @Override @@ -361,7 +382,7 @@ public GetConfigResponse getConfig(GetConfigRequest request) { return callWithBackoff(() -> syncApplianceStub.getConfig(request)); } - throw new RpcException( + throw new WindmillRpcException( new UnsupportedOperationException("GetConfig not supported in Streaming Engine.")); } @@ -371,7 +392,7 @@ public ReportStatsResponse reportStats(ReportStatsRequest request) { return callWithBackoff(() -> syncApplianceStub.reportStats(request)); } - throw new RpcException( + throw new WindmillRpcException( new UnsupportedOperationException("ReportStats not supported in Streaming Engine.")); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 92f031db9972..df69af207899 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -17,10 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; +import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; import com.google.auto.value.AutoBuilder; import java.io.PrintWriter; +import java.util.Collection; import java.util.List; import java.util.Set; import java.util.Timer; @@ -29,6 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; @@ -55,7 +57,9 @@ import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.AbstractStub; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Duration; import org.joda.time.Instant; @@ -66,6 +70,8 @@ @ThreadSafe @Internal public class GrpcWindmillStreamFactory implements StatusDataProvider { + + private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; private static final Duration MIN_BACKOFF = Duration.millis(1); private static final Duration DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30); private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1; @@ -73,6 +79,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider { private static final int DEFAULT_WINDMILL_MESSAGES_BETWEEN_IS_READY_CHECKS = 1; private static final int NO_HEALTH_CHECKS = -1; private static final String NO_BACKEND_WORKER_TOKEN = ""; + private static final String DISPATCHER_DEBUG_NAME = "Dispatcher"; private final JobHeader jobHeader; private final int logEveryNStreamFailures; @@ -173,8 +180,20 @@ public static 
GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { private static > T withDefaultDeadline(T stub) { // Deadlines are absolute points in time, so generate a new one everytime this function is // called. - return stub.withDeadlineAfter( - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); + return stub.withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); + } + + private static void printSummaryHtmlForWorker( + String workerToken, Collection> streams, PrintWriter writer) { + writer.write( + "" + (workerToken.isEmpty() ? DISPATCHER_DEBUG_NAME : workerToken) + ""); + writer.write("
"); + streams.forEach( + stream -> { + stream.appendSummaryHtml(writer); + writer.write("
"); + }); + writer.write("
"); } public GetWorkStream createGetWorkStream( @@ -198,13 +217,13 @@ public GetWorkStream createDirectGetWorkStream( WindmillConnection connection, GetWorkRequest request, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), - responseObserver -> withDefaultDeadline(connection.stub()).getWorkStream(responseObserver), + responseObserver -> connection.stub().getWorkStream(responseObserver), request, grpcBackOff.get(), newStreamObserverFactory(), @@ -234,6 +253,23 @@ public GetDataStream createGetDataStream( processHeartbeatResponses); } + public GetDataStream createDirectGetDataStream( + WindmillConnection connection, ThrottleTimer getDataThrottleTimer) { + return GrpcGetDataStream.create( + connection.backendWorkerToken(), + responseObserver -> connection.stub().getDataStream(responseObserver), + grpcBackOff.get(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + jobHeader, + streamIdGenerator, + streamingRpcBatchLimit, + sendKeyedGetDataRequests, + processHeartbeatResponses); + } + public CommitWorkStream createCommitWorkStream( CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer commitWorkThrottleTimer) { return GrpcCommitWorkStream.create( @@ -249,18 +285,32 @@ public CommitWorkStream createCommitWorkStream( streamingRpcBatchLimit); } + public CommitWorkStream createDirectCommitWorkStream( + WindmillConnection connection, ThrottleTimer commitWorkThrottleTimer) { + return GrpcCommitWorkStream.create( + connection.backendWorkerToken(), + responseObserver -> connection.stub().commitWorkStream(responseObserver), + grpcBackOff.get(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + jobHeader, + 
streamIdGenerator, + streamingRpcBatchLimit); + } + public GetWorkerMetadataStream createGetWorkerMetadataStream( - CloudWindmillMetadataServiceV1Alpha1Stub stub, + Supplier stub, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer onNewWindmillEndpoints) { return GrpcGetWorkerMetadataStream.create( - responseObserver -> withDefaultDeadline(stub).getWorkerMetadata(responseObserver), + responseObserver -> withDefaultDeadline(stub.get()).getWorkerMetadata(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, jobHeader, - 0, getWorkerMetadataThrottleTimer, onNewWindmillEndpoints); } @@ -273,10 +323,17 @@ private StreamObserverFactory newStreamObserverFactory() { @Override public void appendSummaryHtml(PrintWriter writer) { writer.write("Active Streams:
"); - for (AbstractWindmillStream stream : streamRegistry) { - stream.appendSummaryHtml(writer); - writer.write("
"); - } + streamRegistry.stream() + .collect( + toImmutableListMultimap( + AbstractWindmillStream::backendWorkerToken, Function.identity())) + .asMap() + .forEach((workerToken, streams) -> printSummaryHtmlForWorker(workerToken, streams, writer)); + } + + @VisibleForTesting + final ImmutableSet> streamRegistry() { + return ImmutableSet.copyOf(streamRegistry); } @Internal diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 9d57df1af317..8710d66d2c80 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -22,8 +22,10 @@ import java.util.concurrent.TimeoutException; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillRpcException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,27 +39,33 @@ * becomes ready. 
*/ @ThreadSafe -public final class DirectStreamObserver implements StreamObserver { +final class DirectStreamObserver implements TerminatingStreamObserver { private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class); - private final Phaser phaser; + private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30; + private final Phaser isReadyNotifier; + private final long deadlineSeconds; + private final int messagesBetweenIsReadyChecks; private final Object lock = new Object(); @GuardedBy("lock") private final CallStreamObserver outboundObserver; - private final long deadlineSeconds; - private final int messagesBetweenIsReadyChecks; + @GuardedBy("lock") + private boolean isOutboundObserverClosed = false; + + @GuardedBy("lock") + private boolean isUserClosed = false; @GuardedBy("lock") private int messagesSinceReady = 0; - public DirectStreamObserver( - Phaser phaser, + DirectStreamObserver( + Phaser isReadyNotifier, CallStreamObserver outboundObserver, long deadlineSeconds, int messagesBetweenIsReadyChecks) { - this.phaser = phaser; + this.isReadyNotifier = isReadyNotifier; this.outboundObserver = outboundObserver; this.deadlineSeconds = deadlineSeconds; // We always let the first message pass through without blocking because it is performed under @@ -66,6 +74,12 @@ public DirectStreamObserver( this.messagesBetweenIsReadyChecks = Math.max(1, messagesBetweenIsReadyChecks); } + /** + * @throws StreamObserverCancelledException if the StreamObserver was closed via {@link + * #onError(Throwable)}, {@link #onCompleted()}, or {@link #terminate(Throwable)} while + * waiting for {@code outboundObserver#isReady}. + * @throws WindmillRpcException if we time out for waiting for {@code outboundObserver#isReady}. 
+ */ @Override public void onNext(T value) { int awaitPhase = -1; @@ -74,6 +88,24 @@ public void onNext(T value) { while (true) { try { synchronized (lock) { + int currentPhase = isReadyNotifier.getPhase(); + // Phaser is terminated so don't use the outboundObserver. Since onError and onCompleted + // are synchronized after terminating the phaser if we observe that the phaser is not + // terminated the onNext calls below are guaranteed to not be called on a closed observer. + if (currentPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } + + // Closing is performed under "lock" after terminating, so if termination was not observed + // above, the observer should not be closed. + assert !isOutboundObserverClosed; + + // If we awaited previously and timed out, wait for the same phase. Otherwise we're + // careful to observe the phase before observing isReady. + if (awaitPhase < 0) { + awaitPhase = currentPhase; + } + // We only check isReady periodically to effectively allow for increasing the outbound // buffer periodically. This reduces the overhead of blocking while still restricting // memory because there is a limited # of streams, and we have a max messages size of 2MB. @@ -81,25 +113,40 @@ public void onNext(T value) { outboundObserver.onNext(value); return; } - // If we awaited previously and timed out, wait for the same phase. Otherwise we're - // careful to observe the phase before observing isReady. - if (awaitPhase < 0) { - awaitPhase = phaser.getPhase(); - } + if (outboundObserver.isReady()) { messagesSinceReady = 0; outboundObserver.onNext(value); return; } } + // A callback has been registered to advance the phaser whenever the observer // transitions to is ready. Since we are waiting for a phase observed before the // outboundObserver.isReady() returned false, we expect it to advance after the // channel has become ready. 
This doesn't always seem to be the case (despite // documentation stating otherwise) so we poll periodically and enforce an overall // timeout related to the stream deadline. - phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS); + int nextPhase = + isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS); + // If nextPhase is a value less than 0, the phaser has been terminated. + if (nextPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } + synchronized (lock) { + int currentPhase = isReadyNotifier.getPhase(); + // Phaser is terminated so don't use the outboundObserver. Since onError and onCompleted + // are synchronized after terminating the phaser if we observe that the phaser is not + // terminated the onNext calls below are guaranteed to not be called on a closed observer. + if (currentPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } + + // Closing is performed under "lock" after terminating, so if termination was not observed + // above, the observer should not be closed. 
+ assert !isOutboundObserverClosed; + messagesSinceReady = 0; outboundObserver.onNext(value); return; @@ -107,36 +154,78 @@ public void onNext(T value) { } catch (TimeoutException e) { totalSecondsWaited += waitSeconds; if (totalSecondsWaited > deadlineSeconds) { - LOG.error( - "Exceeded timeout waiting for the outboundObserver to become ready meaning " - + "that the stream deadline was not respected."); - throw new RuntimeException(e); + String errorMessage = constructStreamCancelledErrorMessage(totalSecondsWaited); + LOG.error(errorMessage); + throw new WindmillRpcException(errorMessage, e); } - if (totalSecondsWaited > 30) { + + if (totalSecondsWaited > OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS) { LOG.info( "Output channel stalled for {}s, outbound thread {}.", totalSecondsWaited, Thread.currentThread().getName()); } + waitSeconds = waitSeconds * 2; } catch (InterruptedException e) { Thread.currentThread().interrupt(); - throw new RuntimeException(e); + throw new StreamObserverCancelledException(e); } } } + /** @throws IllegalStateException if called multiple times or after {@link #onCompleted()}. */ @Override public void onError(Throwable t) { + isReadyNotifier.forceTermination(); synchronized (lock) { - outboundObserver.onError(t); + Preconditions.checkState(!isUserClosed); + isUserClosed = true; + if (!isOutboundObserverClosed) { + outboundObserver.onError(t); + isOutboundObserverClosed = true; + } } } + /** + * @throws IllegalStateException if called multiple times or after {@link #onError(Throwable)}. + */ @Override public void onCompleted() { + isReadyNotifier.forceTermination(); synchronized (lock) { - outboundObserver.onCompleted(); + Preconditions.checkState(!isUserClosed); + isUserClosed = true; + if (!isOutboundObserverClosed) { + outboundObserver.onCompleted(); + isOutboundObserverClosed = true; + } } } + + @Override + public void terminate(Throwable terminationException) { + // Free the blocked threads in onNext(). 
+ isReadyNotifier.forceTermination(); + synchronized (lock) { + if (!isOutboundObserverClosed) { + outboundObserver.onError(terminationException); + isOutboundObserverClosed = true; + } + } + } + + private String constructStreamCancelledErrorMessage(long totalSecondsWaited) { + return deadlineSeconds > 0 + ? "Waited " + + totalSecondsWaited + + "s which exceeds given deadline of " + + deadlineSeconds + + "s for the outboundObserver to become ready meaning " + + "that the stream deadline was not respected." + : "Output channel has been blocked for " + + totalSecondsWaited + + "s. Restarting stream internally."; + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java index 4ea209f31b1d..70fd3497a37f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java @@ -21,11 +21,15 @@ @Internal public final class StreamObserverCancelledException extends RuntimeException { - public StreamObserverCancelledException(Throwable cause) { + StreamObserverCancelledException(Throwable cause) { super(cause); } - public StreamObserverCancelledException(String message, Throwable cause) { + StreamObserverCancelledException(String message, Throwable cause) { super(message, cause); } + + StreamObserverCancelledException(String message) { + super(message); + } } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java index cb4415bdab18..01e854492bf9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java @@ -33,7 +33,7 @@ public static StreamObserverFactory direct( return new Direct(deadlineSeconds, messagesBetweenIsReadyChecks); } - public abstract StreamObserver from( + public abstract TerminatingStreamObserver from( Function, StreamObserver> clientFactory, StreamObserver responseObserver); @@ -47,7 +47,7 @@ private static class Direct extends StreamObserverFactory { } @Override - public StreamObserver from( + public TerminatingStreamObserver from( Function, StreamObserver> clientFactory, StreamObserver inboundObserver) { AdvancingPhaser phaser = new AdvancingPhaser(1); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java new file mode 100644 index 000000000000..5fb4f95e3e1e --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers; + +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; + +@Internal +public interface TerminatingStreamObserver extends StreamObserver { + + /** + * Terminates the StreamObserver. + * + * @implSpec Unlike {@link #onError(Throwable)} and {@link #onCompleted()} which can only + * be called once during the lifetime of each {@link StreamObserver}, terminate() + * implementations are meant to be idempotent and can be called multiple times as well as + * being interleaved with other stream operations. 
+ */ + void terminate(Throwable terminationException); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java index db012c6bb412..c03459ee732e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; import java.io.PrintWriter; +import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.function.Function; @@ -31,6 +32,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListener; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListeners; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,14 +52,11 @@ public final class ChannelCache implements StatusDataProvider { private ChannelCache( Function channelFactory, - RemovalListener onChannelRemoved) { + RemovalListener onChannelRemoved, + Executor channelCloser) { this.channelCache = CacheBuilder.newBuilder() - .removalListener( - RemovalListeners.asynchronous( - onChannelRemoved, - Executors.newCachedThreadPool( - new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build()))) + 
.removalListener(RemovalListeners.asynchronous(onChannelRemoved, channelCloser)) .build( new CacheLoader() { @Override @@ -72,11 +71,13 @@ public static ChannelCache create( return new ChannelCache( channelFactory, // Shutdown the channels as they get removed from the cache, so they do not leak. - notification -> shutdownChannel(notification.getValue())); + notification -> shutdownChannel(notification.getValue()), + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build())); } @VisibleForTesting - static ChannelCache forTesting( + public static ChannelCache forTesting( Function channelFactory, Runnable onChannelShutdown) { return new ChannelCache( channelFactory, @@ -85,7 +86,11 @@ static ChannelCache forTesting( notification -> { shutdownChannel(notification.getValue()); onChannelShutdown.run(); - }); + }, + // Run the removal synchronously on the calling thread to prevent waiting on asynchronous + // tasks to run and make unit tests deterministic. In testing, we verify that things are + // removed from the cache. 
+ MoreExecutors.directExecutor()); } private static void shutdownChannel(ManagedChannel channel) { @@ -108,6 +113,7 @@ public void remove(WindmillServiceAddress windmillServiceAddress) { public void clear() { channelCache.invalidateAll(); + channelCache.cleanUp(); } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 9aec29a3ba4d..f0ea2f550a74 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -36,7 +36,6 @@ /** Utility class used to create different RPC Channels. 
*/ public final class WindmillChannelFactory { public static final String LOCALHOST = "localhost"; - private static final int DEFAULT_GRPC_PORT = 443; private static final int MAX_REMOTE_TRACE_EVENTS = 100; private WindmillChannelFactory() {} @@ -55,8 +54,6 @@ public static Channel localhostChannel(int port) { public static ManagedChannel remoteChannel( WindmillServiceAddress windmillServiceAddress, int windmillServiceRpcChannelTimeoutSec) { switch (windmillServiceAddress.getKind()) { - case IPV6: - return remoteChannel(windmillServiceAddress.ipv6(), windmillServiceRpcChannelTimeoutSec); case GCP_SERVICE_ADDRESS: return remoteChannel( windmillServiceAddress.gcpServiceAddress(), windmillServiceRpcChannelTimeoutSec); @@ -67,7 +64,8 @@ public static ManagedChannel remoteChannel( windmillServiceRpcChannelTimeoutSec); default: throw new UnsupportedOperationException( - "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported WindmillServiceAddresses."); + "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS are supported" + + " WindmillServiceAddresses."); } } @@ -105,17 +103,6 @@ public static Channel remoteChannel( } } - public static ManagedChannel remoteChannel( - Inet6Address directEndpoint, int windmillServiceRpcChannelTimeoutSec) { - try { - return createRemoteChannel( - NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint, DEFAULT_GRPC_PORT)), - windmillServiceRpcChannelTimeoutSec); - } catch (SSLException sslException) { - throw new WindmillChannelCreationException(directEndpoint.toString(), sslException); - } - } - @SuppressWarnings("nullness") private static ManagedChannel createRemoteChannel( NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java index f660112721ba..fdcb0339d23d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java @@ -25,7 +25,7 @@ * CommitWork are both blocked for x, totalTime will be 2x. However, if 2 GetWork streams are both * blocked for x totalTime will be x. All methods are thread safe. */ -public final class ThrottleTimer { +public final class ThrottleTimer implements ThrottledTimeTracker { // This is -1 if not currently being throttled or the time in // milliseconds when throttling for this type started. private long startTime = -1; @@ -56,6 +56,7 @@ public synchronized boolean throttled() { } /** Returns the combined total of all throttle times and resets those times to 0. */ + @Override public synchronized long getAndResetThrottleTime() { if (throttled()) { stop(); diff --git a/runners/flink/1.16/job-server-container/build.gradle b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottledTimeTracker.java similarity index 54% rename from runners/flink/1.16/job-server-container/build.gradle rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottledTimeTracker.java index afdb68a0fc91..9bb8fb0a7b5f 100644 --- a/runners/flink/1.16/job-server-container/build.gradle +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottledTimeTracker.java @@ -4,23 +4,29 @@ * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.beam.runners.dataflow.worker.windmill.client.throttling; -def basePath = '../../job-server-container' +import org.apache.beam.sdk.annotations.Internal; -project.ext { - resource_path = basePath -} +/** + * Tracks time spent in a throttled state due to {@code Status.RESOURCE_EXHAUSTED} errors returned + * from gRPC calls. + */ +@Internal +@FunctionalInterface +public interface ThrottledTimeTracker { -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server_container.gradle" + /** Returns the combined total of all throttle times and resets those times to 0. 
*/ + long getAndResetThrottleTime(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java index aed03f33e6d6..b17631a8bd0a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java @@ -137,12 +137,7 @@ protected Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForK keyCoder.encode(key, keyStream, Coder.Context.OUTER); ByteString keyBytes = keyStream.toByteString(); // Leaving data blank means that we delete the tag. - commitBuilder - .addValueUpdatesBuilder() - .setTag(keyBytes) - .setStateFamily(stateFamily) - .getValueBuilder() - .setTimestamp(Long.MAX_VALUE); + commitBuilder.addValueUpdatesBuilder().setTag(keyBytes).setStateFamily(stateFamily); V cachedValue = cachedValues.remove(key); if (cachedValue != null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java index 403bb99efb4c..8a1ba2556cf2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -17,18 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide; import java.math.RoundingMode; -import java.util.Map; -import java.util.Map.Entry; -import java.util.function.Function; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,22 +29,11 @@ @Internal final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor { private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class); - private final Supplier activeWorkBudgetSupplier; - - EvenGetWorkBudgetDistributor(Supplier activeWorkBudgetSupplier) { - this.activeWorkBudgetSupplier = activeWorkBudgetSupplier; - } - - private static boolean isBelowFiftyPercentOfTarget( - GetWorkBudget remaining, GetWorkBudget target) { - return remaining.items() < roundToLong(target.items() * 0.5, RoundingMode.CEILING) - || remaining.bytes() < roundToLong(target.bytes() * 0.5, RoundingMode.CEILING); - } @Override public void distributeBudget( - ImmutableCollection budgetOwners, GetWorkBudget getWorkBudget) { - if (budgetOwners.isEmpty()) { + ImmutableCollection budgetSpenders, GetWorkBudget getWorkBudget) { + if (budgetSpenders.isEmpty()) { LOG.debug("Cannot distribute budget to no owners."); return; } @@ -61,38 +43,15 @@ public void distributeBudget( return; } - Map desiredBudgets = computeDesiredBudgets(budgetOwners, getWorkBudget); - - for (Entry streamAndDesiredBudget : desiredBudgets.entrySet()) { - GetWorkBudgetSpender getWorkBudgetSpender = streamAndDesiredBudget.getKey(); - 
GetWorkBudget desired = streamAndDesiredBudget.getValue(); - GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget(); - if (isBelowFiftyPercentOfTarget(remaining, desired)) { - GetWorkBudget adjustment = desired.subtract(remaining); - getWorkBudgetSpender.adjustBudget(adjustment); - } - } + GetWorkBudget budgetPerStream = computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget); + budgetSpenders.forEach(getWorkBudgetSpender -> getWorkBudgetSpender.setBudget(budgetPerStream)); } - private ImmutableMap computeDesiredBudgets( + private GetWorkBudget computeDesiredPerStreamBudget( ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { - GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get(); - LOG.info("Current active work budget: {}", activeWorkBudget); - // TODO: Fix possibly non-deterministic handing out of budgets. - // Rounding up here will drift upwards over the lifetime of the streams. - GetWorkBudget budgetPerStream = - GetWorkBudget.builder() - .setItems( - divide( - totalGetWorkBudget.items() - activeWorkBudget.items(), - streams.size(), - RoundingMode.CEILING)) - .setBytes( - divide( - totalGetWorkBudget.bytes() - activeWorkBudget.bytes(), - streams.size(), - RoundingMode.CEILING)) - .build(); - return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream)); + return GetWorkBudget.builder() + .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING)) + .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING)) + .build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java index 43c0d46139da..2013c9ff1cb7 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java @@ -17,13 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; @Internal public final class GetWorkBudgetDistributors { - public static GetWorkBudgetDistributor distributeEvenly( - Supplier activeWorkBudgetSupplier) { - return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier); + public static GetWorkBudgetDistributor distributeEvenly() { + return new EvenGetWorkBudgetDistributor(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java index e39aa8dbc8a5..d81c7d0593f3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java @@ -51,7 +51,7 @@ public GetWorkBudgetRefresher( Supplier isBudgetRefreshPaused, Runnable redistributeBudget) { this.budgetRefreshTrigger = new AdvancingPhaser(1); this.budgetRefreshExecutor = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setNameFormat(BUDGET_REFRESH_THREAD) .setUncaughtExceptionHandler( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java index 254b2589062e..decf101a641b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java @@ -22,11 +22,9 @@ * org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget} */ public interface GetWorkBudgetSpender { - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(long items, long bytes); - default void adjustBudget(GetWorkBudget adjustment) { - adjustBudget(adjustment.items(), adjustment.bytes()); + default void setBudget(GetWorkBudget budget) { + setBudget(budget.items(), budget.bytes()); } - - GetWorkBudget remainingBudget(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 965a29126dc2..c74874c465a6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -70,7 +70,7 @@ */ @Internal @ThreadSafe -public final class StreamingWorkScheduler { +public class StreamingWorkScheduler { private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); private final DataflowWorkerHarnessOptions options; @@ -225,6 +225,7 @@ private void processWork(ComputationState computationState, Work work) { 
Windmill.WorkItem workItem = work.getWorkItem(); String computationId = computationState.getComputationId(); ByteString key = workItem.getKey(); + work.setProcessingThreadName(Thread.currentThread().getName()); work.setState(Work.State.PROCESSING); setUpWorkLoggingContext(work.getLatencyTrackingId(), computationId); LOG.debug("Starting processing for {}:\n{}", computationId, work); @@ -288,6 +289,7 @@ private void processWork(ComputationState computationState, Work work) { } resetWorkLoggingContext(work.getLatencyTrackingId()); + work.setProcessingThreadName(""); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java index 33a55d1927f8..ed5f2db7f480 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java @@ -20,8 +20,8 @@ import java.util.Objects; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.sdk.annotations.Internal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,7 +61,7 @@ public void sendHeartbeats(Heartbeats heartbeats) { Thread.currentThread().setName(originalThreadName + "-" + backendWorkerToken); } 
getDataStream.refreshActiveWork(heartbeats.heartbeatRequests().asMap()); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { LOG.warn( "Trying to refresh work w/ {} heartbeats on stream={} after work has moved off of worker." + " heartbeats", diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java index 071bf7fa3d43..6d768e8a972c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java @@ -21,12 +21,14 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap; /** Heartbeat requests and the work that was used to generate the heartbeat requests. 
*/ +@Internal @AutoValue -abstract class Heartbeats { +public abstract class Heartbeats { static Heartbeats.Builder builder() { return new AutoValue_Heartbeats.Builder(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java index fa36b11ffe55..f54091dc2b95 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java @@ -42,7 +42,7 @@ private StreamPoolHeartbeatSender( this.heartbeatStreamPool.set(heartbeatStreamPool); } - public static StreamPoolHeartbeatSender Create( + public static StreamPoolHeartbeatSender create( @Nonnull WindmillStreamPool heartbeatStreamPool) { return new StreamPoolHeartbeatSender(heartbeatStreamPool); } @@ -55,7 +55,7 @@ public static StreamPoolHeartbeatSender Create( * enabled. * @param getDataPool stream to use when using separate streams for heartbeat is disabled. 
*/ - public static StreamPoolHeartbeatSender Create( + public static StreamPoolHeartbeatSender create( @Nonnull WindmillStreamPool dedicatedHeartbeatPool, @Nonnull WindmillStreamPool getDataPool, @Nonnull StreamingGlobalConfigHandle configHandle) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index b3f7467cdbd3..1da48bd2b7ce 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -236,6 +236,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} @@ -245,18 +248,10 @@ public void halfClose() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op. 
} - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setItems(request.getMaxItems()) - .setBytes(request.getMaxBytes()) - .build(); - } - @Override public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { while (done.getCount() > 0) { @@ -307,6 +302,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} @@ -388,6 +386,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index dadf02171235..6eeb7bd6bbfc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -96,6 +96,7 @@ import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.util.Structs; +import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ComputationStateCache; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; @@ -104,6 +105,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandleImpl; +import 
org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; import org.apache.beam.runners.dataflow.worker.testing.RestoreDataflowLoggingMDC; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; @@ -129,6 +131,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WatermarkHold; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; +import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactoryFactory; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; @@ -178,6 +183,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.UnsignedLong; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -285,6 +291,7 @@ public Long get() { private final FakeWindmillServer server = new FakeWindmillServer( errorCollector, computationId -> computationStateCache.get(computationId)); + private StreamingCounters streamingCounters; public StreamingDataflowWorkerTest(Boolean streamingEngine) { 
this.streamingEngine = streamingEngine; @@ -304,9 +311,20 @@ private static CounterUpdate getCounter(Iterable counters, String return null; } + private Iterable buildCounters() { + return Iterables.concat( + streamingCounters + .pendingDeltaCounters() + .extractModifiedDeltaUpdates(DataflowCounterUpdateExtractor.INSTANCE), + streamingCounters + .pendingCumulativeCounters() + .extractUpdates(false, DataflowCounterUpdateExtractor.INSTANCE)); + } + @Before public void setUp() { server.clearCommitsReceived(); + streamingCounters = StreamingCounters.create(); } @After @@ -856,7 +874,13 @@ private StreamingDataflowWorker makeWorker( streamingDataflowWorkerTestParams.clock(), streamingDataflowWorkerTestParams.executorSupplier(), mockGlobalConfigHandle, - streamingDataflowWorkerTestParams.localRetryTimeoutMs()); + streamingDataflowWorkerTestParams.localRetryTimeoutMs(), + streamingCounters, + new FakeWindmillStubFactoryFactory( + new FakeWindmillStubFactory( + () -> + WindmillChannelFactory.inProcessChannel( + "StreamingDataflowWorkerTestChannel")))); this.computationStateCache = worker.getComputationStateCache(); return worker; } @@ -1715,7 +1739,7 @@ public void testMergeWindows() throws Exception { intervalWindowBytes(WINDOW_AT_ZERO))); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); // These tags and data are opaque strings and this is a change detector test. 
// The "/u" indicates the user's namespace, versus "/s" for system namespace @@ -1836,7 +1860,7 @@ public void testMergeWindows() throws Exception { expectedBytesRead += dataBuilder.build().getSerializedSize(); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = buildCounters(); actualOutput = result.get(2L); assertEquals(1, actualOutput.getOutputMessagesCount()); @@ -2004,7 +2028,7 @@ public void testMergeWindowsCaching() throws Exception { intervalWindowBytes(WINDOW_AT_ZERO))); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); // These tags and data are opaque strings and this is a change detector test. // The "/u" indicates the user's namespace, versus "/s" for system namespace @@ -2125,7 +2149,7 @@ public void testMergeWindowsCaching() throws Exception { expectedBytesRead += dataBuilder.build().getSerializedSize(); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = buildCounters(); actualOutput = result.get(2L); assertEquals(1, actualOutput.getOutputMessagesCount()); @@ -2430,7 +2454,7 @@ public void testUnboundedSources() throws Exception { null)); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); Windmill.WorkItemCommitRequest commit = result.get(1L); UnsignedLong finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); @@ -2492,7 +2516,7 @@ public void testUnboundedSources() throws Exception { null)); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = buildCounters(); commit = result.get(2L); finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); @@ -2540,7 +2564,7 @@ public void testUnboundedSources() throws Exception { null)); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = 
buildCounters(); commit = result.get(3L); finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); @@ -2710,7 +2734,7 @@ public void testUnboundedSourceWorkRetry() throws Exception { server.whenGetWorkCalled().thenReturn(work); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); Windmill.WorkItemCommitRequest commit = result.get(1L); UnsignedLong finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java index 2d5a8d8266ae..37c5ad261280 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java @@ -366,7 +366,6 @@ public void testExtractPerWorkerMetricUpdates_populatedMetrics() { .setMetricsNamespace("BigQuerySink") .setMetricValues(Collections.singletonList(expectedCounter)); - // Expected histogram metric List bucketCounts = Collections.singletonList(1L); Linear linearOptions = new Linear().setNumberOfBuckets(10).setWidth(10.0).setStart(0.0); @@ -393,6 +392,44 @@ public void testExtractPerWorkerMetricUpdates_populatedMetrics() { assertThat(updates, containsInAnyOrder(histograms, counters)); } + @Test + public void testExtractPerWorkerMetricUpdatesKafka_populatedMetrics() { + StreamingStepMetricsContainer.setEnablePerWorkerMetrics(true); + + MetricName histogramMetricName = MetricName.named("KafkaSink", "histogram"); + HistogramData.LinearBuckets linearBuckets = HistogramData.LinearBuckets.of(0, 10, 10); + 
c2.getPerWorkerHistogram(histogramMetricName, linearBuckets).update(5.0); + + Iterable updates = + StreamingStepMetricsContainer.extractPerWorkerMetricUpdates(registry); + + // Expected histogram metric + List bucketCounts = Collections.singletonList(1L); + + Linear linearOptions = new Linear().setNumberOfBuckets(10).setWidth(10.0).setStart(0.0); + BucketOptions bucketOptions = new BucketOptions().setLinear(linearOptions); + + DataflowHistogramValue linearHistogram = + new DataflowHistogramValue() + .setCount(1L) + .setBucketOptions(bucketOptions) + .setBucketCounts(bucketCounts); + + MetricValue expectedHistogram = + new MetricValue() + .setMetric("histogram") + .setMetricLabels(new HashMap<>()) + .setValueHistogram(linearHistogram); + + PerStepNamespaceMetrics histograms = + new PerStepNamespaceMetrics() + .setOriginalStep("s2") + .setMetricsNamespace("KafkaSink") + .setMetricValues(Collections.singletonList(expectedHistogram)); + + assertThat(updates, containsInAnyOrder(histograms)); + } + @Test public void testExtractPerWorkerMetricUpdates_emptyMetrics() { StreamingStepMetricsContainer.setEnablePerWorkerMetrics(true); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java index ff114ef2f078..c1e5000f03da 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static 
org.hamcrest.Matchers.theInstance; @@ -153,6 +154,21 @@ private static class TestStatefulDoFn extends DoFn, Void> { public void processElement(ProcessContext c) {} } + private static class TestStatefulDoFnWithWindowExpiration + extends DoFn, Void> { + + public static final String STATE_ID = "state-id"; + + @StateId(STATE_ID) + private final StateSpec> spec = StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void processElement(ProcessContext c) {} + + @OnWindowExpiration + public void onWindowExpiration() {} + } + private static final TupleTag MAIN_OUTPUT = new TupleTag<>("1"); private UserParDoFnFactory factory = UserParDoFnFactory.createDefault(); @@ -373,6 +389,92 @@ public void testCleanupRegistered() throws Exception { firstWindow.maxTimestamp().plus(Duration.millis(1L))); } + /** + * Regression test for global window + OnWindowExpiration + allowed lateness > max allowed time + */ + @Test + public void testCleanupTimerForGlobalWindowWithAllowedLateness() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + CounterSet counters = new CounterSet(); + DoFn initialFn = new TestStatefulDoFnWithWindowExpiration(); + Duration allowedLateness = Duration.standardDays(2); + CloudObject cloudObject = + getCloudObject( + initialFn, WindowingStrategy.globalDefault().withAllowedLateness(allowedLateness)); + + StateInternals stateInternals = InMemoryStateInternals.forKey("dummy"); + + TimerInternals timerInternals = mock(TimerInternals.class); + + DataflowStepContext stepContext = mock(DataflowStepContext.class); + when(stepContext.timerInternals()).thenReturn(timerInternals); + DataflowStepContext userStepContext = mock(DataflowStepContext.class); + when(stepContext.namespacedToUser()).thenReturn(userStepContext); + when(stepContext.stateInternals()).thenReturn(stateInternals); + when(userStepContext.stateInternals()).thenReturn((StateInternals) stateInternals); + + DataflowExecutionContext executionContext = + 
mock(DataflowExecutionContext.class); + TestOperationContext operationContext = TestOperationContext.create(counters); + when(executionContext.getStepContext(operationContext)).thenReturn(stepContext); + when(executionContext.getSideInputReader(any(), any(), any())) + .thenReturn(NullSideInputReader.empty()); + + ParDoFn parDoFn = + factory.create( + options, + cloudObject, + Collections.emptyList(), + MAIN_OUTPUT, + ImmutableMap.of(MAIN_OUTPUT, 0), + executionContext, + operationContext); + + Receiver rcvr = new OutputReceiver(); + parDoFn.startBundle(rcvr); + + GlobalWindow globalWindow = GlobalWindow.INSTANCE; + parDoFn.processElement( + WindowedValue.of("foo", new Instant(1), globalWindow, PaneInfo.NO_FIRING)); + + assertThat( + globalWindow.maxTimestamp().plus(allowedLateness), + greaterThan(BoundedWindow.TIMESTAMP_MAX_VALUE)); + verify(stepContext) + .setStateCleanupTimer( + SimpleParDoFn.CLEANUP_TIMER_ID, + globalWindow, + GlobalWindow.Coder.INSTANCE, + BoundedWindow.TIMESTAMP_MAX_VALUE, + BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.millis(1))); + + StateNamespace globalWindowNamespace = + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, globalWindow); + StateTag> tag = + StateTags.tagForSpec( + TestStatefulDoFnWithWindowExpiration.STATE_ID, StateSpecs.value(StringUtf8Coder.of())); + + when(userStepContext.getNextFiredTimer((Coder) GlobalWindow.Coder.INSTANCE)).thenReturn(null); + when(stepContext.getNextFiredTimer((Coder) GlobalWindow.Coder.INSTANCE)) + .thenReturn( + TimerData.of( + SimpleParDoFn.CLEANUP_TIMER_ID, + globalWindowNamespace, + BoundedWindow.TIMESTAMP_MAX_VALUE, + BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.millis(1)), + TimeDomain.EVENT_TIME)) + .thenReturn(null); + + // Set up non-empty state. We don't mock + verify calls to clear() but instead + // check that state is actually empty. We mustn't care how it is accomplished. 
+ stateInternals.state(globalWindowNamespace, tag).write("first"); + + // And this should clean up the state for the global window + parDoFn.processTimers(); + + assertThat(stateInternals.state(globalWindowNamespace, tag).read(), nullValue()); + } + @Test public void testCleanupWorks() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java index 4f035c88774c..c71001fbeee7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -30,27 +31,29 @@ @RunWith(JUnit4.class) public class WeightBoundedQueueTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final int MAX_WEIGHT = 10; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Test public void testPut_hasCapacity() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue = 1; queue.put(insertedValue); - assertEquals(insertedValue, queue.queuedElementsWeight()); + assertEquals(insertedValue, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); 
assertEquals(insertedValue, (int) queue.poll()); } @Test public void testPut_noCapacity() throws InterruptedException { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); // Insert value that takes all the capacity into the queue. queue.put(MAX_WEIGHT); @@ -71,7 +74,7 @@ public void testPut_noCapacity() throws InterruptedException { // Should only see the first value in the queue, since the queue is at capacity. thread2 // should be blocked. - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); // Poll the queue, pulling off the only value inside and freeing up the capacity in the queue. @@ -80,14 +83,15 @@ public void testPut_noCapacity() throws InterruptedException { // Wait for the putThread which was previously blocked due to the queue being at capacity. 
putThread.join(); - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); } @Test public void testPoll() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue1 = 1; int insertedValue2 = 2; @@ -95,7 +99,7 @@ public void testPoll() { queue.put(insertedValue1); queue.put(insertedValue2); - assertEquals(insertedValue1 + insertedValue2, queue.queuedElementsWeight()); + assertEquals(insertedValue1 + insertedValue2, weightedSemaphore.currentWeight()); assertEquals(2, queue.size()); assertEquals(insertedValue1, (int) queue.poll()); assertEquals(1, queue.size()); @@ -104,7 +108,8 @@ public void testPoll() { @Test public void testPoll_withTimeout() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int pollWaitTimeMillis = 10000; int insertedValue1 = 1; @@ -132,7 +137,8 @@ public void testPoll_withTimeout() throws InterruptedException { @Test public void testPoll_withTimeout_timesOut() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int defaultPollResult = -10; int pollWaitTimeMillis = 100; int insertedValue1 = 1; @@ -144,13 +150,17 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { Thread pollThread = new Thread( () -> { - int polled; + @Nullable Integer polled; try { polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS); - 
pollResult.set(polled); + if (polled != null) { + pollResult.set(polled); + } } catch (InterruptedException e) { throw new RuntimeException(e); } + + assertNull(polled); }); pollThread.start(); @@ -164,7 +174,8 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { @Test public void testPoll_emptyQueue() { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); assertNull(queue.poll()); } @@ -172,7 +183,8 @@ public void testPoll_emptyQueue() { @Test public void testTake() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); AtomicInteger value = new AtomicInteger(); // Should block until value is available @@ -194,4 +206,39 @@ public void testTake() throws InterruptedException { assertEquals(MAX_WEIGHT, value.get()); } + + @Test + public void testPut_sharedWeigher() throws InterruptedException { + WeightedSemaphore weigher = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue1 = WeightedBoundedQueue.create(weigher); + WeightedBoundedQueue queue2 = WeightedBoundedQueue.create(weigher); + + // Insert value that takes all the weight into the queue1. + queue1.put(MAX_WEIGHT); + + // Try to insert a value into the queue2. This will block since there is no capacity in the + // weigher. + Thread putThread = new Thread(() -> queue2.put(MAX_WEIGHT)); + putThread.start(); + // Should only see the first value in the queue, since the queue is at capacity. putThread + // should be blocked. The weight should be the same however, since queue1 and queue2 are sharing + // the weigher. 
+ Thread.sleep(100); + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue1.size()); + assertEquals(0, queue2.size()); + + // Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher. + queue1.poll(); + + // Wait for the putThread which was previously blocked due to the weigher being at capacity. + putThread.join(); + + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue2.size()); + queue2.poll(); + assertEquals(0, queue2.size()); + assertEquals(0, weigher.currentWeight()); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java index 9fa17588c94d..3a0ae7bb2084 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java @@ -47,7 +47,6 @@ @RunWith(JUnit4.class) public class StreamingEngineComputationConfigFetcherTest { - private final WorkUnitClient mockDataflowServiceClient = mock(WorkUnitClient.class, new Returns(Optional.empty())); private StreamingEngineComputationConfigFetcher streamingEngineConfigFetcher; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index ed8815c48e76..606d2b9dbdbc 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -30,23 +30,20 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashSet; -import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; @@ -66,12 +63,10 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessSocketAddress; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.junit.After; import org.junit.Before; @@ -83,6 +78,7 @@ @RunWith(JUnit4.class) public class FanOutStreamingEngineWorkerHarnessTest { + private static final String CHANNEL_NAME = "FanOutStreamingEngineWorkerHarnessTest"; private static final WindmillServiceAddress DEFAULT_WINDMILL_SERVICE_ADDRESS = WindmillServiceAddress.create(HostAndPort.fromParts(WindmillChannelFactory.LOCALHOST, 443)); private static final ImmutableMap DEFAULT = @@ -92,7 +88,6 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString()) .build()); - private static final long CLIENT_ID = 1L; private static final String JOB_ID = "jobId"; private static final String PROJECT_ID = "projectId"; private static final String WORKER_ID = "workerId"; @@ -101,17 +96,15 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) + .setClientId(1L) .build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); private final GrpcWindmillStreamFactory streamFactory = spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build()); private final ChannelCachingStubFactory stubFactory = new 
FakeWindmillStubFactory( - () -> - grpcCleanup.register( - WindmillChannelFactory.inProcessChannel("StreamingEngineClientTest"))); + () -> grpcCleanup.register(WindmillChannelFactory.inProcessChannel(CHANNEL_NAME))); private final GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.forTesting( PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class), @@ -134,7 +127,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) - .setClientId(CLIENT_ID) + .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) .build(); @@ -149,59 +142,59 @@ private static WorkerMetadataResponse.Endpoint metadataResponseEndpoint(String w @Before public void setUp() throws IOException { - stubFactory.shutdown(); + getWorkerMetadataReady = new CountDownLatch(1); + fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady); fakeStreamingEngineServer = - grpcCleanup.register( - InProcessServerBuilder.forName("StreamingEngineClientTest") - .fallbackHandlerRegistry(serviceRegistry) - .executor(Executors.newFixedThreadPool(1)) - .build()); + grpcCleanup + .register( + InProcessServerBuilder.forName(CHANNEL_NAME) + .directExecutor() + .addService(fakeGetWorkerMetadataStub) + .addService(new WindmillServiceFakeStub()) + .build()) + .start(); - fakeStreamingEngineServer.start(); dispatcherClient.consumeWindmillDispatcherEndpoints( ImmutableSet.of( - HostAndPort.fromString( - new InProcessSocketAddress("StreamingEngineClientTest").toString()))); - getWorkerMetadataReady = new CountDownLatch(1); - fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady); - serviceRegistry.addService(fakeGetWorkerMetadataStub); + HostAndPort.fromString(new InProcessSocketAddress(CHANNEL_NAME).toString()))); } @After public void cleanUp() { Preconditions.checkNotNull(fanOutStreamingEngineWorkProvider).shutdown(); - 
fakeStreamingEngineServer.shutdownNow(); stubFactory.shutdown(); + fakeStreamingEngineServer.shutdownNow(); } - private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( + private FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( GetWorkBudget getWorkBudget, GetWorkBudgetDistributor getWorkBudgetDistributor, - WorkItemScheduler workItemScheduler) { - return FanOutStreamingEngineWorkerHarness.forTesting( - JOB_HEADER, - getWorkBudget, - streamFactory, - workItemScheduler, - stubFactory, - getWorkBudgetDistributor, - dispatcherClient, - CLIENT_ID, - ignored -> mock(WorkCommitter.class), - new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); + WorkItemScheduler workItemScheduler) + throws InterruptedException { + FanOutStreamingEngineWorkerHarness harness = + FanOutStreamingEngineWorkerHarness.forTesting( + JOB_HEADER, + getWorkBudget, + streamFactory, + workItemScheduler, + stubFactory, + getWorkBudgetDistributor, + dispatcherClient, + ignored -> mock(WorkCommitter.class), + new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); + getWorkerMetadataReady.await(); + return harness; } @Test public void testStreamsStartCorrectly() throws InterruptedException { long items = 10L; long bytes = 10L; - int numBudgetDistributionsExpected = 1; - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected)); + TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor()); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(items).setBytes(bytes).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -209,26 +202,20 @@ public void testStreamsStartCorrectly() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - WorkerMetadataResponse firstWorkerMetadata = + 
fakeGetWorkerMetadataStub.injectWorkerMetadata( WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) .addWorkEndpoints(metadataResponseEndpoint(workerToken)) .addWorkEndpoints(metadataResponseEndpoint(workerToken2)) .putAllGlobalDataEndpoints(DEFAULT) - .build(); - - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + .build()); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); - assertEquals(2, currentConnections.windmillConnections().size()); - assertEquals(2, currentConnections.windmillStreams().size()); + assertEquals(2, currentBackends.windmillStreams().size()); Set workerTokens = - currentConnections.windmillConnections().values().stream() - .map(WindmillConnection::backendWorkerToken) + currentBackends.windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertTrue(workerTokens.contains(workerToken)); @@ -248,39 +235,16 @@ public void testStreamsStartCorrectly() throws InterruptedException { any(), eq(noOpProcessWorkItemFn())); - verify(streamFactory, times(2)).createGetDataStream(any(), any()); - verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); - } - - @Test - public void testScheduledBudgetRefresh() throws InterruptedException { - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(2)); - fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( - GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), - getWorkBudgetDistributor, - noOpProcessWorkItemFn()); - - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata( - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(1) - 
.addWorkEndpoints(metadataResponseEndpoint("workerToken")) - .putAllGlobalDataEndpoints(DEFAULT) - .build()); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any()); + verify(streamFactory, times(2)).createDirectGetDataStream(any(), any()); + verify(streamFactory, times(2)).createDirectCommitWorkStream(any(), any()); } @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { - int metadataCount = 2; - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(metadataCount)); + TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor()); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -309,32 +273,26 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() WorkerMetadataResponse.Endpoint.newBuilder() .setBackendWorkerToken(workerToken3) .build()) - .putAllGlobalDataEndpoints(DEFAULT) .build(); - getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); - assertEquals(1, currentConnections.windmillConnections().size()); - assertEquals(1, currentConnections.windmillStreams().size()); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); + assertEquals(1, currentBackends.windmillStreams().size()); Set workerTokens = - fanOutStreamingEngineWorkProvider.getCurrentConnections().windmillConnections().values() - .stream() - 
.map(WindmillConnection::backendWorkerToken) + fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertFalse(workerTokens.contains(workerToken)); assertFalse(workerTokens.contains(workerToken2)); + assertTrue(currentBackends.globalDataStreams().isEmpty()); } @Test public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - String workerToken3 = "workerToken3"; WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() @@ -354,42 +312,78 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); - WorkerMetadataResponse thirdWorkerMetadata = - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(3) - .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder() - .setBackendWorkerToken(workerToken3) - .build()) - .putAllGlobalDataEndpoints(DEFAULT) - .build(); - - List workerMetadataResponses = - Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, thirdWorkerMetadata); - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size())); + TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor()); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); - getWorkerMetadataReady.await(); - - // Make sure we are injecting the metadata from smallest to largest. 
- workerMetadataResponses.stream() - .sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion)) - .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata); + fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size())) - .distributeBudget(any(), any()); + verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } - private void waitForWorkerMetadataToBeConsumed( - TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException { - getWorkBudgetDistributor.waitForBudgetDistribution(); + private static class WindmillServiceFakeStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + @Override + public StreamObserver getDataStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingGetDataRequest getDataRequest) {} + + @Override + public void onError(Throwable throwable) { + responseObserver.onError(throwable); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingGetWorkRequest getWorkRequest) {} + + @Override + public void onError(Throwable throwable) { + responseObserver.onError(throwable); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} + + @Override + public void onError(Throwable throwable) { + 
responseObserver.onError(throwable); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } } private static class GetWorkerMetadataTestStub @@ -422,7 +416,11 @@ public void onError(Throwable throwable) { } @Override - public void onCompleted() {} + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } }; } @@ -434,22 +432,10 @@ private void injectWorkerMetadata(WorkerMetadataResponse response) { } private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { - private final CountDownLatch getWorkBudgetDistributorTriggered; - - private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { - this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); - } - - @SuppressWarnings("ReturnValueIgnored") - private void waitForBudgetDistribution() throws InterruptedException { - getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); - } - @Override public void distributeBudget( ImmutableCollection streams, GetWorkBudget getWorkBudget) { - streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes())); - getWorkBudgetDistributorTriggered.countDown(); + streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes())); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java new file mode 100644 index 000000000000..4df3bf7cd823 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.WorkerUncaughtExceptionHandler; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.util.common.worker.JvmRuntime; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class SingleSourceWorkerHarnessTest { + private static final Logger LOG = 
LoggerFactory.getLogger(SingleSourceWorkerHarnessTest.class); + private final WorkCommitter workCommitter = mock(WorkCommitter.class); + private final GetDataClient getDataClient = mock(GetDataClient.class); + private final HeartbeatSender heartbeatSender = mock(HeartbeatSender.class); + private final Runnable waitForResources = () -> {}; + private final Function> computationStateFetcher = + ignored -> Optional.empty(); + private final StreamingWorkScheduler streamingWorkScheduler = mock(StreamingWorkScheduler.class); + + private SingleSourceWorkerHarness createWorkerHarness( + SingleSourceWorkerHarness.GetWorkSender getWorkSender, JvmRuntime runtime) { + // In non-test scenario this is set in DataflowWorkerHarnessHelper.initializeLogging(...). + Thread.setDefaultUncaughtExceptionHandler(new WorkerUncaughtExceptionHandler(runtime, LOG)); + return SingleSourceWorkerHarness.builder() + .setWorkCommitter(workCommitter) + .setGetDataClient(getDataClient) + .setHeartbeatSender(heartbeatSender) + .setWaitForResources(waitForResources) + .setStreamingWorkScheduler(streamingWorkScheduler) + .setComputationStateFetcher(computationStateFetcher) + // no-op throttle time supplier. 
+ .setThrottledTimeTracker(() -> 0L) + .setGetWorkSender(getWorkSender) + .build(); + } + + @Test + public void testDispatchLoop_unexpectedFailureKillsJvm_appliance() { + SingleSourceWorkerHarness.GetWorkSender getWorkSender = + SingleSourceWorkerHarness.GetWorkSender.forAppliance( + () -> { + throw new RuntimeException("something bad happened"); + }); + + FakeJvmRuntime fakeJvmRuntime = new FakeJvmRuntime(); + createWorkerHarness(getWorkSender, fakeJvmRuntime).start(); + assertTrue(fakeJvmRuntime.waitForRuntimeDeath(5, TimeUnit.SECONDS)); + fakeJvmRuntime.assertJvmTerminated(); + } + + @Test + public void testDispatchLoop_unexpectedFailureKillsJvm_streamingEngine() { + SingleSourceWorkerHarness.GetWorkSender getWorkSender = + SingleSourceWorkerHarness.GetWorkSender.forStreamingEngine( + workItemReceiver -> { + throw new RuntimeException("something bad happened"); + }); + + FakeJvmRuntime fakeJvmRuntime = new FakeJvmRuntime(); + createWorkerHarness(getWorkSender, fakeJvmRuntime).start(); + assertTrue(fakeJvmRuntime.waitForRuntimeDeath(5, TimeUnit.SECONDS)); + fakeJvmRuntime.assertJvmTerminated(); + } + + private static class FakeJvmRuntime implements JvmRuntime { + private final CountDownLatch haltedLatch = new CountDownLatch(1); + private volatile int exitStatus = 0; + + @Override + public void halt(int status) { + exitStatus = status; + haltedLatch.countDown(); + } + + public boolean waitForRuntimeDeath(long timeout, TimeUnit unit) { + try { + return haltedLatch.await(timeout, unit); + } catch (InterruptedException e) { + return false; + } + } + + private void assertJvmTerminated() { + assertThat(exitStatus).isEqualTo(WorkerUncaughtExceptionHandler.JVM_TERMINATED_STATUS_CODE); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java index 7e65a495638f..f348e4cf1bdb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java @@ -39,14 +39,15 @@ @RunWith(JUnit4.class) public class StreamingWorkerStatusReporterTest { - private final long DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME = 1000; - private final long DEFAULT_HARNESS_REPORTING_PERIOD = 10000; - private final long DEFAULT_PER_WORKER_METRICS_PERIOD = 30000; + private static final long DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME = 1000; + private static final long DEFAULT_HARNESS_REPORTING_PERIOD = 10000; + private static final long DEFAULT_PER_WORKER_METRICS_PERIOD = 30000; private BoundedQueueExecutor mockExecutor; private WorkUnitClient mockWorkUnitClient; private FailureTracker mockFailureTracker; private MemoryMonitor mockMemoryMonitor; + private StreamingWorkerStatusReporter reporter; @Before public void setUp() { @@ -54,23 +55,11 @@ public void setUp() { this.mockWorkUnitClient = mock(WorkUnitClient.class); this.mockFailureTracker = mock(FailureTracker.class); this.mockMemoryMonitor = mock(MemoryMonitor.class); + this.reporter = buildWorkerStatusReporterForTest(); } @Test public void testOverrideMaximumThreadCount() throws Exception { - StreamingWorkerStatusReporter reporter = - StreamingWorkerStatusReporter.forTesting( - true, - mockWorkUnitClient, - () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME, - () -> Collections.emptyList(), - mockFailureTracker, - StreamingCounters.create(), - mockMemoryMonitor, - mockExecutor, - (threadName) -> Executors.newSingleThreadScheduledExecutor(), - DEFAULT_HARNESS_REPORTING_PERIOD, - 
DEFAULT_PER_WORKER_METRICS_PERIOD); StreamingScalingReportResponse streamingScalingReportResponse = new StreamingScalingReportResponse().setMaximumThreadCount(10); WorkerMessageResponse workerMessageResponse = @@ -84,23 +73,25 @@ public void testOverrideMaximumThreadCount() throws Exception { @Test public void testHandleEmptyWorkerMessageResponse() throws Exception { - StreamingWorkerStatusReporter reporter = - StreamingWorkerStatusReporter.forTesting( - true, - mockWorkUnitClient, - () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME, - () -> Collections.emptyList(), - mockFailureTracker, - StreamingCounters.create(), - mockMemoryMonitor, - mockExecutor, - (threadName) -> Executors.newSingleThreadScheduledExecutor(), - DEFAULT_HARNESS_REPORTING_PERIOD, - DEFAULT_PER_WORKER_METRICS_PERIOD); - WorkerMessageResponse workerMessageResponse = new WorkerMessageResponse(); when(mockWorkUnitClient.reportWorkerMessage(any())) - .thenReturn(Collections.singletonList(workerMessageResponse)); + .thenReturn(Collections.singletonList(new WorkerMessageResponse())); reporter.reportPeriodicWorkerMessage(); verify(mockExecutor, Mockito.times(0)).setMaximumPoolSize(anyInt(), anyInt()); } + + private StreamingWorkerStatusReporter buildWorkerStatusReporterForTest() { + return StreamingWorkerStatusReporter.builder() + .setPublishCounters(true) + .setDataflowServiceClient(mockWorkUnitClient) + .setWindmillQuotaThrottleTime(() -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME) + .setAllStageInfo(Collections::emptyList) + .setFailureTracker(mockFailureTracker) + .setStreamingCounters(StreamingCounters.create()) + .setMemoryMonitor(mockMemoryMonitor) + .setWorkExecutor(mockExecutor) + .setExecutorFactory((threadName) -> Executors.newSingleThreadScheduledExecutor()) + .setWindmillHarnessUpdateReportingPeriodMillis(DEFAULT_HARNESS_REPORTING_PERIOD) + .setPerWorkerMetricsUpdateReportingPeriodMillis(DEFAULT_PER_WORKER_METRICS_PERIOD) + .build(); + } } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index dc6cc5641055..aa2767d5472d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -17,13 +17,13 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; @@ -96,7 +96,7 @@ public void testStartStream_startsAllStreams() { newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); verify(streamFactory) .createDirectGetWorkStream( @@ -113,8 +113,8 @@ public void testStartStream_startsAllStreams() { any(), eq(workItemScheduler)); - verify(streamFactory).createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); - verify(streamFactory).createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + verify(streamFactory).createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); + verify(streamFactory).createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test @@ -126,9 
+126,9 @@ public void testStartStream_onlyStartsStreamsOnce() { newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - windmillStreamSender.startStreams(); - windmillStreamSender.startStreams(); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); + windmillStreamSender.start(); + windmillStreamSender.start(); verify(streamFactory, times(1)) .createDirectGetWorkStream( @@ -146,9 +146,9 @@ public void testStartStream_onlyStartsStreamsOnce() { eq(workItemScheduler)); verify(streamFactory, times(1)) - .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); verify(streamFactory, times(1)) - .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test @@ -160,10 +160,10 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - Thread startStreamThread = new Thread(windmillStreamSender::startStreams); + Thread startStreamThread = new Thread(windmillStreamSender::start); startStreamThread.start(); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); startStreamThread.join(); @@ -183,23 +183,52 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted eq(workItemScheduler)); verify(streamFactory, times(1)) - .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); verify(streamFactory, times(1)) - .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test - public void testCloseAllStreams_doesNotCloseUnstartedStreams() { + public void testCloseAllStreams_closesAllStreams() { + 
long itemBudget = 1L; + long byteBudget = 1L; + GetWorkRequest getWorkRequestWithBudget = + GET_WORK_REQUEST.toBuilder().setMaxItems(itemBudget).setMaxBytes(byteBudget).build(); + GrpcWindmillStreamFactory mockStreamFactory = mock(GrpcWindmillStreamFactory.class); + GetWorkStream mockGetWorkStream = mock(GetWorkStream.class); + GetDataStream mockGetDataStream = mock(GetDataStream.class); + CommitWorkStream mockCommitWorkStream = mock(CommitWorkStream.class); + + when(mockStreamFactory.createDirectGetWorkStream( + eq(connection), + eq(getWorkRequestWithBudget), + any(ThrottleTimer.class), + any(), + any(), + any(), + eq(workItemScheduler))) + .thenReturn(mockGetWorkStream); + + when(mockStreamFactory.createDirectGetDataStream(eq(connection), any(ThrottleTimer.class))) + .thenReturn(mockGetDataStream); + when(mockStreamFactory.createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class))) + .thenReturn(mockCommitWorkStream); + WindmillStreamSender windmillStreamSender = - newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(), + mockStreamFactory); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.start(); + windmillStreamSender.close(); - verifyNoInteractions(streamFactory); + verify(mockGetWorkStream).shutdown(); + verify(mockGetDataStream).shutdown(); + verify(mockCommitWorkStream).shutdown(); } @Test - public void testCloseAllStreams_closesAllStreams() { + public void testCloseAllStreams_doesNotStartStreamsAfterClose() { long itemBudget = 1L; long byteBudget = 1L; GetWorkRequest getWorkRequestWithBudget = @@ -219,9 +248,9 @@ public void testCloseAllStreams_closesAllStreams() { eq(workItemScheduler))) .thenReturn(mockGetWorkStream); - when(mockStreamFactory.createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class))) + when(mockStreamFactory.createDirectGetDataStream(eq(connection), 
any(ThrottleTimer.class))) .thenReturn(mockGetDataStream); - when(mockStreamFactory.createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class))) + when(mockStreamFactory.createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class))) .thenReturn(mockCommitWorkStream); WindmillStreamSender windmillStreamSender = @@ -229,14 +258,30 @@ public void testCloseAllStreams_closesAllStreams() { GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(), mockStreamFactory); - windmillStreamSender.startStreams(); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); + + verify(mockGetWorkStream, times(0)).start(); + verify(mockGetDataStream, times(0)).start(); + verify(mockCommitWorkStream, times(0)).start(); verify(mockGetWorkStream).shutdown(); verify(mockGetDataStream).shutdown(); verify(mockCommitWorkStream).shutdown(); } + @Test + public void testStartStream_afterCloseThrows() { + long itemBudget = 1L; + long byteBudget = 1L; + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); + + windmillStreamSender.close(); + assertThrows(IllegalStateException.class, windmillStreamSender::start); + } + private WindmillStreamSender newWindmillStreamSender(GetWorkBudget budget) { return newWindmillStreamSender(budget, streamFactory); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index ad77958837a1..734925289920 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -293,40 +293,4 
@@ public void testRenderSummaryHtml() { + "Work Queue Bytes: 0/10000000
/n"; assertEquals(expectedSummaryHtml, executor.summaryHtml()); } - - @Test - public void testExecute_updatesThreadNameForExecutableWork() throws InterruptedException { - CountDownLatch waitForWorkExecution = new CountDownLatch(1); - ExecutableWork executableWork = - createWork( - work -> { - assertTrue( - Thread.currentThread() - .getName() - .contains( - BoundedQueueExecutor.debugFormattedWorkToken( - work.getWorkItem().getWorkToken()))); - waitForWorkExecution.countDown(); - }); - executor.execute(executableWork, executableWork.getWorkItem().getSerializedSize()); - waitForWorkExecution.await(); - } - - @Test - public void testForceExecute_updatesThreadNameForExecutableWork() throws InterruptedException { - CountDownLatch waitForWorkExecution = new CountDownLatch(1); - ExecutableWork executableWork = - createWork( - work -> { - assertTrue( - Thread.currentThread() - .getName() - .contains( - BoundedQueueExecutor.debugFormattedWorkToken( - work.getWorkItem().getWorkToken()))); - waitForWorkExecution.countDown(); - }); - executor.forceExecute(executableWork, executableWork.getWorkItem().getSerializedSize()); - waitForWorkExecution.await(); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java new file mode 100644 index 000000000000..05fbc6f969df --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import java.io.PrintWriter; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class AbstractWindmillStreamTest { + private static final long DEADLINE_SECONDS = 10; + private final Set> streamRegistry = ConcurrentHashMap.newKeySet(); + private final StreamObserverFactory 
streamObserverFactory = + StreamObserverFactory.direct(DEADLINE_SECONDS, 1); + + @Before + public void setUp() { + streamRegistry.clear(); + } + + private TestStream newStream( + Function, StreamObserver> clientFactory) { + return new TestStream(clientFactory, streamRegistry, streamObserverFactory); + } + + @Test + public void testShutdown_notBlockedBySend() throws InterruptedException, ExecutionException { + CountDownLatch sendBlocker = new CountDownLatch(1); + Function, StreamObserver> clientFactory = + ignored -> + new CallStreamObserver() { + @Override + public void onNext(Integer integer) { + try { + sendBlocker.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setOnReadyHandler(Runnable runnable) {} + + @Override + public void disableAutoInboundFlowControl() {} + + @Override + public void request(int i) {} + + @Override + public void setMessageCompression(boolean b) {} + }; + + TestStream testStream = newStream(clientFactory); + testStream.start(); + ExecutorService sendExecutor = Executors.newSingleThreadExecutor(); + Future sendFuture = + sendExecutor.submit( + () -> + assertThrows(WindmillStreamShutdownException.class, () -> testStream.testSend(1))); + testStream.shutdown(); + + // Sleep a bit to give sendExecutor time to execute the send(). 
+ Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + + sendBlocker.countDown(); + assertThat(sendFuture.get()).isInstanceOf(WindmillStreamShutdownException.class); + } + + private static class TestStream extends AbstractWindmillStream { + private final AtomicInteger numStarts = new AtomicInteger(); + + private TestStream( + Function, StreamObserver> clientFactory, + Set> streamRegistry, + StreamObserverFactory streamObserverFactory) { + super( + LoggerFactory.getLogger(AbstractWindmillStreamTest.class), + "Test", + clientFactory, + FluentBackoff.DEFAULT.backoff(), + streamObserverFactory, + streamRegistry, + 1, + "Test"); + } + + @Override + protected void onResponse(Integer response) {} + + @Override + protected void onNewStream() { + numStarts.incrementAndGet(); + } + + @Override + protected boolean hasPendingRequests() { + return false; + } + + @Override + protected void startThrottleTimer() {} + + public void testSend(Integer i) + throws ResettableThrowingStreamObserver.StreamClosedException, + WindmillStreamShutdownException { + trySend(i); + } + + @Override + protected void sendHealthCheck() {} + + @Override + protected void appendSpecificHtml(PrintWriter writer) {} + + @Override + protected void shutdownInternal() {} + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java new file mode 100644 index 000000000000..790c155d94d6 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class ResettableThrowingStreamObserverTest { + private final TerminatingStreamObserver delegate = newDelegate(); + + private static TerminatingStreamObserver newDelegate() { + return spy( + new TerminatingStreamObserver() { + @Override + public void onNext(Integer integer) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + + @Override + public void terminate(Throwable terminationException) {} + }); + } + + @Test + public void 
testPoison_beforeDelegateSet() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + verifyNoInteractions(delegate); + } + + @Test + public void testPoison_afterDelegateSet() throws WindmillStreamShutdownException { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.reset(); + observer.poison(); + verify(delegate).terminate(isA(WindmillStreamShutdownException.class)); + } + + @Test + public void testReset_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::reset); + } + + @Test + public void testOnNext_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, () -> observer.onNext(1)); + } + + @Test + public void testOnError_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows( + WindmillStreamShutdownException.class, + () -> observer.onError(new RuntimeException("something bad happened."))); + } + + @Test + public void testOnCompleted_afterPoisonedThrows() { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::onCompleted); + } + + @Test + public void testReset_usesNewDelegate() + throws WindmillStreamShutdownException, + ResettableThrowingStreamObserver.StreamClosedException { + List> delegates = new ArrayList<>(); + ResettableThrowingStreamObserver observer = + newStreamObserver( + () -> { + TerminatingStreamObserver delegate = newDelegate(); + delegates.add(delegate); + return delegate; + }); + observer.reset(); + observer.onNext(1); + observer.reset(); + observer.onNext(2); + + StreamObserver firstObserver = 
delegates.get(0); + StreamObserver secondObserver = delegates.get(1); + + verify(firstObserver).onNext(eq(1)); + verify(secondObserver).onNext(eq(2)); + } + + private ResettableThrowingStreamObserver newStreamObserver( + Supplier> delegate) { + return new ResettableThrowingStreamObserver<>(delegate, LoggerFactory.getLogger(getClass())); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java new file mode 100644 index 000000000000..564b2e664505 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.function.Supplier; +import org.joda.time.DateTime; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class StreamDebugMetricsTest { + + @Test + public void testSummaryMetrics_noRestarts() throws InterruptedException { + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.recordStart(); + streamDebugMetrics.recordSend(); + streamDebugMetrics.recordResponse(); + Thread.sleep(1000); + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertFalse(metricsSnapshot.shutdownTime().isPresent()); + assertTrue(metricsSnapshot.sleepLeft() <= 0); + assertThat(metricsSnapshot.streamAge()).isGreaterThan(0); + assertThat(metricsSnapshot.timeSinceLastSend()).isGreaterThan(0); + assertThat(metricsSnapshot.timeSinceLastResponse()).isGreaterThan(0); + assertFalse(metricsSnapshot.restartMetrics().isPresent()); + + streamDebugMetrics.recordShutdown(); + StreamDebugMetrics.Snapshot metricsSnapshotAfterShutdown = + streamDebugMetrics.getSummaryMetrics(); + assertTrue(metricsSnapshotAfterShutdown.shutdownTime().isPresent()); + } + + @Test + public void testSummaryMetrics_sleep() { + long sleepMs = 100; + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + streamDebugMetrics.recordSleep(sleepMs); + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertEquals(sleepMs, metricsSnapshot.sleepLeft()); + } + + @Test + public void 
testSummaryMetrics_withRestarts() { + String restartReason = "something bad happened"; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.incrementAndGetErrors(); + streamDebugMetrics.incrementAndGetRestarts(); + streamDebugMetrics.recordRestartReason(restartReason); + + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertTrue(metricsSnapshot.restartMetrics().isPresent()); + StreamDebugMetrics.RestartMetrics restartMetrics = metricsSnapshot.restartMetrics().get(); + assertThat(restartMetrics.lastRestartReason()).isEqualTo(restartReason); + assertThat(restartMetrics.restartCount()).isEqualTo(1); + assertThat(restartMetrics.errorCount()).isEqualTo(1); + assertTrue(restartMetrics.lastRestartTime().isPresent()); + assertThat(restartMetrics.lastRestartTime().get()).isLessThan(DateTime.now()); + assertThat(restartMetrics.lastRestartTime().get().toInstant()).isGreaterThan(Instant.EPOCH); + } + + @Test + public void testResponseDebugString_neverReceivedResponse() { + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + assertFalse(streamDebugMetrics.responseDebugString(Instant.now().getMillis()).isPresent()); + } + + @Test + public void testResponseDebugString_withResponse() { + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.recordResponse(); + assertTrue(streamDebugMetrics.responseDebugString(Instant.now().getMillis()).isPresent()); + } + + @Test + public void testGetStartTime() { + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + assertEquals(0, streamDebugMetrics.getStartTimeMs()); + streamDebugMetrics.recordStart(); + assertThat(streamDebugMetrics.getStartTimeMs()).isEqualTo(aLongTimeAgo.getMillis()); + } + + @Test + public void 
testGetLastSendTime() { + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + assertEquals(0, streamDebugMetrics.getLastSendTimeMs()); + streamDebugMetrics.recordSend(); + assertThat(streamDebugMetrics.getLastSendTimeMs()).isEqualTo(aLongTimeAgo.getMillis()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java index bdad382c9af2..fdd213223987 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java @@ -260,5 +260,8 @@ public String backendWorkerToken() { public void shutdown() { halfClose(); } + + @Override + public void start() {} } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 546a2883e3b2..4072b582c831 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,6 +53,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; @@ -60,13 +61,15 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class StreamingEngineWorkCommitterTest { - + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public ErrorCollector errorCollector = new ErrorCollector(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private WorkCommitter workCommitter; private FakeWindmillServer fakeWindmillServer; private Supplier> commitWorkStreamFactory; @@ -121,6 +124,7 @@ public void setUp() throws IOException { private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { return StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setOnCommitComplete(onCommitComplete) .build(); @@ -261,6 +265,10 @@ public void testStop_drainsCommitQueue() { Supplier fakeCommitWorkStream = () -> new CommitWorkStream() { + + @Override + public void start() {} + @Override public RequestBatcher batcher() { return new RequestBatcher() { @@ -342,6 +350,7 @@ public void testMultipleCommitSendersSingleStream() { Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); workCommitter = StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setNumCommitSenders(5) .setOnCommitComplete(completeCommits::add) diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java new file mode 100644 index 000000000000..7de824b86fd2 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.spy; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; + +@RunWith(JUnit4.class) +public class 
GrpcCommitWorkStreamTest { + private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String COMPUTATION_ID = "computationId"; + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + + private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(value) + .setWorkToken(value) + .setCacheToken(value) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub testStub) { + serviceRegistry.addService(testStub); + GrpcCommitWorkStream commitWorkStream = + (GrpcCommitWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createCommitWorkStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer()); + commitWorkStream.start(); + return commitWorkStream; + } + + @Test + public void testShutdown_abortsQueuedCommits() throws InterruptedException { + int numCommits = 5; + CountDownLatch commitProcessed = new 
CountDownLatch(numCommits); + Set onDone = new HashSet<>(); + + TestCommitWorkStreamRequestObserver requestObserver = + spy(new TestCommitWorkStreamRequestObserver()); + CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + InOrder requestObserverVerifier = inOrder(requestObserver); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + commitStatus -> { + onDone.add(commitStatus); + commitProcessed.countDown(); + }); + } + } catch (StreamObserverCancelledException ignored) { + } + + // Verify that we sent the commits above in a request + the initial header. + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> !request.getCommitChunkList().isEmpty())); + requestObserverVerifier.verifyNoMoreInteractions(); + + // We won't get responses so we will have some pending requests. 
+ assertTrue(commitWorkStream.hasPendingRequests()); + commitWorkStream.shutdown(); + commitProcessed.await(); + + assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED); + } + + @Test + public void testCommitWorkItem_afterShutdown() { + int numCommits = 5; + + CommitWorkStreamTestStub testStub = + new CommitWorkStreamTestStub(new TestCommitWorkStreamRequestObserver()); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue(batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {})); + } + } + commitWorkStream.shutdown(); + + AtomicReference commitStatus = new AtomicReference<>(); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue( + batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatus::set)); + } + } + + assertThat(commitStatus.get()).isEqualTo(Windmill.CommitStatus.ABORTED); + } + + @Test + public void testSend_notCalledAfterShutdown() { + int numCommits = 5; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + + TestCommitWorkStreamRequestObserver requestObserver = + spy(new TestCommitWorkStreamRequestObserver()); + InOrder requestObserverVerifier = inOrder(requestObserver); + + CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + commitStatus -> commitProcessed.countDown())); + } + // Shutdown the stream before we exit the try-with-resources block which will try to send() + // the batched request. 
+ commitWorkStream.shutdown(); + } + + // send() uses the requestObserver to send requests. We expect 1 send since startStream() sends + // the header, which happens before we shutdown. + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); + requestObserverVerifier.verify(requestObserver).onError(any()); + requestObserverVerifier.verifyNoMoreInteractions(); + } + + private static class TestCommitWorkStreamRequestObserver + implements StreamObserver { + private @Nullable StreamObserver responseObserver; + + @Override + public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } + } + + private static class CommitWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestCommitWorkStreamRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private CommitWorkStreamTestStub(TestCommitWorkStreamRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + ((ServerCallStreamObserver) responseObserver) + .setOnCancelHandler(() -> {}); + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java new file mode 100644 index 000000000000..6584ed1c5ae6 --- 
/dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER = + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setClientId(1L) + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + 
@Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + private static void assertHeader( + Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget expectedInitialBudget) { + assertTrue(getWorkRequest.hasRequest()); + assertFalse(getWorkRequest.hasRequestExtension()); + assertThat(getWorkRequest.getRequest()) + .isEqualTo( + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(expectedInitialBudget.items()) + .setMaxBytes(expectedInitialBudget.bytes()) + .build()); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + checkNotNull(stream).shutdown(); + inProcessChannel.shutdownNow(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + GrpcDirectGetWorkStream getWorkStream = + (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + 
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + getWorkStream.start(); + return getWorkStream; + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + // Header and extension. 
+ assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget()); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Header and extension. 
+ assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(1); + assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget()); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new 
GetWorkStreamTestStub(requestObserver); + Set scheduledWorkItems = new HashSet<>(); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> + scheduledWorkItems.add(work)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER); + stream.startThrottleTimer(); + assertTrue(throttleTimer.throttled()); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); + assertFalse(throttleTimer.throttled()); + } + + private static class GetWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetWorkRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public 
StreamObserver getWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) { + checkNotNull(responseObserver).onNext(responseChunk); + } + } + + private static class TestGetWorkRequestObserver + implements StreamObserver { + private final List requests = + Collections.synchronizedList(new ArrayList<>()); + private final CountDownLatch waitForRequests; + private @Nullable volatile StreamObserver + responseObserver; + + public TestGetWorkRequestObserver(CountDownLatch waitForRequests) { + this.waitForRequests = waitForRequests; + } + + @Override + public void onNext(Windmill.StreamingGetWorkRequest request) { + requests.add(request); + waitForRequests.countDown(); + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + + List sent() { + return requests; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java index 3f746d91a868..c04456906ea2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java @@ -34,7 +34,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; import org.apache.beam.sdk.options.PipelineOptionsFactory; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.hamcrest.Matcher; import org.junit.Test; @@ -55,9 +54,6 @@ public static class RespectsJobSettingTest { public void createsNewStubWhenIsolatedChannelsConfigIsChanged() { DataflowWorkerHarnessOptions options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); // Create first time with Isolated channels disabled @@ -91,27 +87,18 @@ public static class RespectsPipelineOptionsTest { public static Collection data() { List list = new ArrayList<>(); for (Boolean pipelineOption : new Boolean[] {true, false}) { - list.add(new Object[] {/*experimentEnabled=*/ false, pipelineOption}); - list.add(new Object[] {/*experimentEnabled=*/ true, pipelineOption}); + list.add(new Object[] {pipelineOption}); } return list; } @Parameter(0) - public Boolean experimentEnabled; - - @Parameter(1) public Boolean pipelineOption; @Test public void ignoresIsolatedChannelsConfigWithPipelineOption() { DataflowWorkerHarnessOptions options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - if (experimentEnabled) { - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); - } options.setUseWindmillIsolatedChannels(pipelineOption); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java new file mode 100644 index 000000000000..dc2dce7807a9 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcGetDataStreamRequestsTest { + + @Test + public void testQueuedRequest_globalRequestsFirstComparator() { + List requests = new ArrayList<>(); + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(1L) + .setShardingKey(1L) + .setWorkToken(1L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + requests.add( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 1, "computation1", keyedGetDataRequest1)); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(2L) + .setShardingKey(2L) + .setWorkToken(2L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + requests.add( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 2, "computation2", keyedGetDataRequest2)); + + Windmill.GlobalDataRequest globalDataRequest = + Windmill.GlobalDataRequest.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag("globalData") + .setVersion(ByteString.EMPTY) + .build()) + .setComputationId("computation1") + .build(); + 
requests.add(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + + requests.sort(GrpcGetDataStreamRequests.QueuedRequest.globalRequestsFirst()); + + // First one should be the global request. + assertTrue(requests.get(0).getDataRequest().isGlobal()); + } + + @Test + public void testQueuedBatch_asGetDataRequest() { + GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new GrpcGetDataStreamRequests.QueuedBatch(); + + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(1L) + .setShardingKey(1L) + .setWorkToken(1L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + queuedBatch.addRequest( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 1, "computation1", keyedGetDataRequest1)); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(2L) + .setShardingKey(2L) + .setWorkToken(2L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + queuedBatch.addRequest( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 2, "computation2", keyedGetDataRequest2)); + + Windmill.GlobalDataRequest globalDataRequest = + Windmill.GlobalDataRequest.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag("globalData") + .setVersion(ByteString.EMPTY) + .build()) + .setComputationId("computation1") + .build(); + queuedBatch.addRequest(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + + Windmill.StreamingGetDataRequest getDataRequest = queuedBatch.asGetDataRequest(); + + assertThat(getDataRequest.getRequestIdCount()).isEqualTo(3); + assertThat(getDataRequest.getGlobalDataRequestList()).containsExactly(globalDataRequest); + assertThat(getDataRequest.getStateRequestList()) + .containsExactly( + Windmill.ComputationGetDataRequest.newBuilder() + .setComputationId("computation1") + .addRequests(keyedGetDataRequest1) + .build(), + 
Windmill.ComputationGetDataRequest.newBuilder() + .setComputationId("computation2") + .addRequests(keyedGetDataRequest2) + .build()); + } + + @Test + public void testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOnWaiters() { + GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new GrpcGetDataStreamRequests.QueuedBatch(); + CompletableFuture waitFuture = + CompletableFuture.supplyAsync( + () -> + assertThrows( + WindmillStreamShutdownException.class, + queuedBatch::waitForSendOrFailNotification)); + // Wait a few seconds for the above future to get scheduled and run. + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + queuedBatch.notifyFailed(); + waitFuture.join(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java new file mode 100644 index 000000000000..3125def64b32 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcGetDataStreamTest { + private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetDataStreamTest"; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { + serviceRegistry.addService(testStub); + GrpcGetDataStream getDataStream = + (GrpcGetDataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setSendKeyedGetDataRequests(false) + .build() + .createGetDataStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer()); + getDataStream.start(); + return getDataStream; + } + + @Test + public void testRequestKeyedData() { + GetDataStreamTestStub testStub = + new GetDataStreamTestStub(new 
TestGetDataStreamRequestObserver()); + GrpcGetDataStream getDataStream = createGetDataStream(testStub); + // These will block until they are successfully sent. + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData( + "computationId", + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setCacheToken(1) + .setWorkToken(1) + .build()); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + // Sleep a bit to allow future to run. + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + + Windmill.KeyedGetDataResponse response = + Windmill.KeyedGetDataResponse.newBuilder() + .setShardingKey(1) + .setKey(ByteString.EMPTY) + .build(); + + testStub.injectResponse( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(response.toByteString()) + .setRemainingBytesForResponse(0) + .build()); + + assertThat(sendFuture.join()).isEqualTo(response); + } + + @Test + public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdownException() { + GetDataStreamTestStub testStub = + new GetDataStreamTestStub(new TestGetDataStreamRequestObserver()); + GrpcGetDataStream getDataStream = createGetDataStream(testStub); + int numSendThreads = 5; + ExecutorService getDataStreamSenders = Executors.newFixedThreadPool(numSendThreads); + CountDownLatch waitForSendAttempt = new CountDownLatch(1); + // These will block until they are successfully sent. + List> sendFutures = + IntStream.range(0, 5) + .sequential() + .mapToObj( + i -> + (Runnable) + () -> { + // Prevent some threads from sending until we close the stream. 
+ if (i % 2 == 0) { + try { + waitForSendAttempt.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + try { + getDataStream.requestKeyedData( + "computationId", + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(i) + .setCacheToken(i) + .setWorkToken(i) + .build()); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }) + // Run the code above on multiple threads. + .map(runnable -> CompletableFuture.runAsync(runnable, getDataStreamSenders)) + .collect(Collectors.toList()); + + getDataStream.shutdown(); + + // Free up waiting threads so that they can try to send on a closed stream. + waitForSendAttempt.countDown(); + + for (int i = 0; i < numSendThreads; i++) { + CompletableFuture sendFuture = sendFutures.get(i); + try { + // Wait for future to complete. + sendFuture.join(); + } catch (Exception ignored) { + } + if (i % 2 == 0) { + assertTrue(sendFuture.isCompletedExceptionally()); + ExecutionException e = assertThrows(ExecutionException.class, sendFuture::get); + assertThat(e) + .hasCauseThat() + .hasCauseThat() + .isInstanceOf(WindmillStreamShutdownException.class); + } + } + } + + private static class TestGetDataStreamRequestObserver + implements StreamObserver { + private @Nullable StreamObserver responseObserver; + + @Override + public void onNext(Windmill.StreamingGetDataRequest streamingGetDataRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } + } + + private static class GetDataStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetDataStreamRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetDataStreamTestStub(TestGetDataStreamRequestObserver requestObserver) { + this.requestObserver = 
requestObserver; + } + + @Override + public StreamObserver getDataStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + ((ServerCallStreamObserver) responseObserver) + .setOnCancelHandler(() -> {}); + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetDataResponse getDataResponse) { + checkNotNull(responseObserver).onNext(getDataResponse); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 4439c409b32f..d74735ee3052 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.verify; @@ -38,10 +37,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; @@ -82,28 +79,24 @@ public class GrpcGetWorkerMetadataStreamTest { private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - private final Set> streamRegistry = new HashSet<>(); + private final GrpcWindmillStreamFactory streamFactory = + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER).build(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private GrpcGetWorkerMetadataStream stream; private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( GetWorkerMetadataTestStub getWorkerMetadataTestStub, - int metadataVersion, Consumer endpointsConsumer) { serviceRegistry.addService(getWorkerMetadataTestStub); - return GrpcGetWorkerMetadataStream.create( - responseObserver -> - CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel) - .getWorkerMetadata(responseObserver), - FluentBackoff.DEFAULT.backoff(), - StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1), - streamRegistry, - 1, // logEveryNStreamFailures - TEST_JOB_HEADER, - metadataVersion, - new ThrottleTimer(), - endpointsConsumer); + GrpcGetWorkerMetadataStream getWorkerMetadataStream = + (GrpcGetWorkerMetadataStream) + streamFactory.createGetWorkerMetadataStream( + () -> 
CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer(), + endpointsConsumer); + getWorkerMetadataStream.start(); + return getWorkerMetadataStream; } @Before @@ -146,8 +139,7 @@ public void testGetWorkerMetadata() { new TestWindmillEndpointsConsumer(); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = -1; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(mockResponse); assertThat(testWindmillEndpointsConsumer.globalDataEndpoints.keySet()) @@ -175,8 +167,7 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() { GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = 0; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(initialResponse); List newDirectPathEndpoints = @@ -222,8 +213,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { Mockito.spy(new TestWindmillEndpointsConsumer()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = 0; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(freshEndpoints); List staleDirectPathEndpoints = @@ -252,7 +242,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new 
TestGetWorkMetadataRequestObserver()); - stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer()); + stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer()); testStub.injectWorkerMetadata( WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) @@ -260,17 +250,17 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS) .build()); - assertTrue(streamRegistry.contains(stream)); + assertTrue(streamFactory.streamRegistry().contains(stream)); stream.halfClose(); - assertFalse(streamRegistry.contains(stream)); + assertFalse(streamFactory.streamRegistry().contains(stream)); } @Test - public void testSendHealthCheck() { + public void testSendHealthCheck() throws WindmillStreamShutdownException { TestGetWorkMetadataRequestObserver requestObserver = Mockito.spy(new TestGetWorkMetadataRequestObserver()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver); - stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer()); + stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer()); stream.sendHealthCheck(); verify(requestObserver).onNext(WorkerMetadataRequest.getDefaultInstance()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 239e3979a3b7..a595524ca582 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -22,12 +22,14 @@ import 
java.io.InputStream; import java.io.SequenceInputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -71,6 +73,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactoryFactory; @@ -115,6 +118,7 @@ public class GrpcWindmillServerTest { private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServerTest.class); private static final int STREAM_CHUNK_SIZE = 2 << 20; private final long clientId = 10L; + private final Set openedChannels = new HashSet<>(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Rule public GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @@ -131,16 +135,18 @@ public void setUp() throws Exception { @After public void tearDown() throws Exception { server.shutdownNow(); + openedChannels.forEach(ManagedChannel::shutdownNow); } private void startServerAndClient(List experiments) throws Exception { String name = "Fake server for " + getClass(); this.server = - InProcessServerBuilder.forName(name) - 
.fallbackHandlerRegistry(serviceRegistry) - .executor(Executors.newFixedThreadPool(1)) - .build() - .start(); + grpcCleanup.register( + InProcessServerBuilder.forName(name) + .fallbackHandlerRegistry(serviceRegistry) + .executor(Executors.newFixedThreadPool(1)) + .build() + .start()); this.client = GrpcWindmillServer.newTestInstance( @@ -149,7 +155,12 @@ private void startServerAndClient(List experiments) throws Exception { clientId, new FakeWindmillStubFactoryFactory( new FakeWindmillStubFactory( - () -> grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name))))); + () -> { + ManagedChannel channel = + grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name)); + openedChannels.add(channel); + return channel; + }))); } private void maybeInjectError(Stream stream) { @@ -460,8 +471,9 @@ private void flushResponse() { "Sending batched response of {} ids", responseBuilder.getRequestIdCount()); try { responseObserver.onNext(responseBuilder.build()); - } catch (IllegalStateException e) { + } catch (Exception e) { // Stream is already closed. + LOG.warn(Arrays.toString(e.getStackTrace())); } responseBuilder.clear(); } @@ -480,16 +492,24 @@ private void flushResponse() { final String s = i % 5 == 0 ? 
largeString(i) : "tag"; executor.submit( () -> { - errorCollector.checkThat( - stream.requestKeyedData("computation", makeGetDataRequest(key, s)), - Matchers.equalTo(makeGetDataResponse(s))); + try { + errorCollector.checkThat( + stream.requestKeyedData("computation", makeGetDataRequest(key, s)), + Matchers.equalTo(makeGetDataResponse(s))); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } done.countDown(); }); executor.execute( () -> { - errorCollector.checkThat( - stream.requestGlobalData(makeGlobalDataRequest(key)), - Matchers.equalTo(makeGlobalDataResponse(key))); + try { + errorCollector.checkThat( + stream.requestGlobalData(makeGlobalDataRequest(key)), + Matchers.equalTo(makeGlobalDataResponse(key))); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } done.countDown(); }); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java new file mode 100644 index 000000000000..6bc713aa7747 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; +import org.apache.beam.sdk.fn.stream.AdvancingPhaser; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.VerifyException; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class DirectStreamObserverTest { + + @Test + 
public void testOnNext_onCompleted() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>( + new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + Future onNextFuture = + onNextExecutor.submit( + () -> { + streamObserver.onNext(1); + streamObserver.onNext(1); + streamObserver.onNext(1); + }); + + // Wait for all of the onNext's to run. + onNextFuture.get(); + + verify(delegate, times(3)).onNext(eq(1)); + + streamObserver.onCompleted(); + verify(delegate, times(1)).onCompleted(); + } + + @Test + public void testOnNext_onError() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>( + new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + Future onNextFuture = + onNextExecutor.submit( + () -> { + streamObserver.onNext(1); + streamObserver.onNext(1); + streamObserver.onNext(1); + }); + + // Wait for all of the onNext's to run. 
+ onNextFuture.get(); + + verify(delegate, times(3)).onNext(eq(1)); + + RuntimeException error = new RuntimeException(); + streamObserver.onError(error); + verify(delegate, times(1)).onError(same(error)); + } + + @Test + public void testOnCompleted_executedOnce() { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + + streamObserver.onCompleted(); + assertThrows(IllegalStateException.class, streamObserver::onCompleted); + } + + @Test + public void testOnError_executedOnce() { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + + RuntimeException error = new RuntimeException(); + streamObserver.onError(error); + assertThrows(IllegalStateException.class, () -> streamObserver.onError(error)); + verify(delegate, times(1)).onError(same(error)); + } + + @Test + public void testOnNext_waitForReady() throws InterruptedException, ExecutionException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch blockLatch = new CountDownLatch(1); + Future<@Nullable Object> onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + try { + // We will check isReady on the next message, will block here. + streamObserver.onNext(1); + streamObserver.onNext(1); + blockLatch.countDown(); + return null; + } catch (Throwable e) { + return e; + } + }); + + while (delegate.getNumIsReadyChecks() <= 1) { + // Wait for isReady check to block. 
+ Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS); + } + + delegate.setIsReady(true); + blockLatch.await(); + verify(delegate, times(3)).onNext(eq(1)); + assertNull(onNextFuture.get()); + + streamObserver.onCompleted(); + verify(delegate, times(1)).onCompleted(); + } + + @Test + public void testTerminate_waitingForReady() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(2)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch blockLatch = new CountDownLatch(1); + Future onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + blockLatch.countDown(); + try { + // We will check isReady on the next message, will block here. + streamObserver.onNext(1); + } catch (Throwable e) { + return e; + } + + return new VerifyException(); + }); + RuntimeException terminationException = new RuntimeException("terminated"); + + assertTrue(blockLatch.await(5, TimeUnit.SECONDS)); + streamObserver.terminate(terminationException); + assertThat(onNextFuture.get()).isInstanceOf(StreamObserverCancelledException.class); + verify(delegate).onError(same(terminationException)); + // onNext should only have been called once. 
+ verify(delegate, times(1)).onNext(any()); + } + + @Test + public void testOnNext_interruption() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(2)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch streamObserverExitLatch = new CountDownLatch(1); + Future onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + // We will check isReady on the next message, will block here. + StreamObserverCancelledException e = + assertThrows( + StreamObserverCancelledException.class, () -> streamObserver.onNext(1)); + streamObserverExitLatch.countDown(); + return e; + }); + + // Assert that onNextFuture is blocked. + assertFalse(onNextFuture.isDone()); + assertThat(streamObserverExitLatch.getCount()).isEqualTo(1); + + onNextExecutor.shutdownNow(); + assertTrue(streamObserverExitLatch.await(5, TimeUnit.SECONDS)); + assertThat(onNextFuture.get()).hasCauseThat().isInstanceOf(InterruptedException.class); + + // onNext should only have been called once. + verify(delegate, times(1)).onNext(any()); + } + + @Test + public void testOnNext_timeOut() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(2)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, 1, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch streamObserverExitLatch = new CountDownLatch(1); + Future onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + // We will check isReady on the next message, will block here. 
+ WindmillServerStub.WindmillRpcException e = + assertThrows( + WindmillServerStub.WindmillRpcException.class, + () -> streamObserver.onNext(1)); + streamObserverExitLatch.countDown(); + return e; + }); + + // Assert that onNextFuture is blocked. + assertFalse(onNextFuture.isDone()); + assertThat(streamObserverExitLatch.getCount()).isEqualTo(1); + + assertTrue(streamObserverExitLatch.await(10, TimeUnit.SECONDS)); + assertThat(onNextFuture.get()).hasCauseThat().isInstanceOf(TimeoutException.class); + + // onNext should only have been called once. + verify(delegate, times(1)).onNext(any()); + } + + private static class TestStreamObserver extends CallStreamObserver { + private final CountDownLatch sendBlocker; + private final int blockAfter; + private final AtomicInteger seen = new AtomicInteger(0); + private final AtomicInteger numIsReadyChecks = new AtomicInteger(0); + private volatile boolean isReady = false; + + private TestStreamObserver(int blockAfter) { + this.blockAfter = blockAfter; + this.sendBlocker = new CountDownLatch(1); + } + + @Override + public void onNext(Integer integer) { + try { + if (seen.incrementAndGet() == blockAfter) { + sendBlocker.await(); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + + @Override + public boolean isReady() { + numIsReadyChecks.incrementAndGet(); + return isReady; + } + + public int getNumIsReadyChecks() { + return numIsReadyChecks.get(); + } + + private void setIsReady(boolean isReadyOverride) { + isReady = isReadyOverride; + } + + @Override + public void setOnReadyHandler(Runnable runnable) {} + + @Override + public void disableAutoInboundFlowControl() {} + + @Override + public void request(int i) {} + + @Override + public void setMessageCompression(boolean b) {} + } +} diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java index 9f8a901cb629..1781261e3400 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java @@ -105,19 +105,10 @@ public ManagedChannel apply(WindmillServiceAddress windmillServiceAddress) { @Test public void testRemoveAndClose() throws InterruptedException { String channelName = "existingChannel"; - CountDownLatch verifyRemovalListenerAsync = new CountDownLatch(1); CountDownLatch notifyWhenChannelClosed = new CountDownLatch(1); cache = ChannelCache.forTesting( - ignored -> newChannel(channelName), - () -> { - try { - verifyRemovalListenerAsync.await(); - notifyWhenChannelClosed.countDown(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - }); + ignored -> newChannel(channelName), notifyWhenChannelClosed::countDown); WindmillServiceAddress someAddress = mock(WindmillServiceAddress.class); ManagedChannel cachedChannel = cache.get(someAddress); @@ -125,7 +116,6 @@ public void testRemoveAndClose() throws InterruptedException { // Assert that the removal happened before we check to see if the shutdowns happen to confirm // that removals are async. assertTrue(cache.isEmpty()); - verifyRemovalListenerAsync.countDown(); // Assert that the channel gets shutdown. 
notifyWhenChannelClosed.await(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java index d06ed0f526c7..8d2623c382e9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java @@ -30,6 +30,8 @@ import java.io.Closeable; import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; import java.nio.charset.StandardCharsets; import java.util.AbstractMap; import java.util.AbstractMap.SimpleEntry; @@ -305,6 +307,26 @@ private K userKeyFromProtoKey(ByteString tag, Coder keyCoder) throws IOEx return keyCoder.decode(keyBytes.newInput(), Context.OUTER); } + private static void assertBuildable( + Windmill.WorkItemCommitRequest.Builder commitWorkRequestBuilder) { + Windmill.WorkItemCommitRequest.Builder clone = commitWorkRequestBuilder.clone(); + if (!clone.hasKey()) { + clone.setKey(ByteString.EMPTY); // key is required to build + } + if (!clone.hasWorkToken()) { + clone.setWorkToken(1357924680L); // workToken is required to build + } + + try { + clone.build(); + } catch (Exception e) { + StringWriter sw = new StringWriter(); + e.printStackTrace(new PrintWriter(sw)); + fail( + "Failed to build commitRequest from: " + commitWorkRequestBuilder + "\n" + sw.toString()); + } + } + @Test public void testMapAddBeforeGet() throws Exception { StateTag> addr = @@ -647,6 +669,8 @@ public void testMapAddPersist() throws Exception { .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) .collect(Collectors.toList()), Matchers.containsInAnyOrder(new 
SimpleEntry<>(tag1, 1), new SimpleEntry<>(tag2, 2))); + + assertBuildable(commitBuilder); } @Test @@ -670,6 +694,8 @@ public void testMapRemovePersist() throws Exception { .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) .collect(Collectors.toList()), Matchers.containsInAnyOrder(new SimpleEntry<>(tag1, null), new SimpleEntry<>(tag2, null))); + + assertBuildable(commitBuilder); } @Test @@ -695,6 +721,8 @@ public void testMapClearPersist() throws Exception { assertEquals( protoKeyFromUserKey(null, StringUtf8Coder.of()), commitBuilder.getTagValuePrefixDeletes(0).getTagPrefix()); + + assertBuildable(commitBuilder); } @Test @@ -736,6 +764,8 @@ public void testMapComplexPersist() throws Exception { commitBuilder = Windmill.WorkItemCommitRequest.newBuilder(); assertEquals(0, commitBuilder.getTagValuePrefixDeletesCount()); assertEquals(0, commitBuilder.getValueUpdatesCount()); + + assertBuildable(commitBuilder); } @Test @@ -953,6 +983,8 @@ public void testMultimapRemovePersistPut() { multimapState.put(key, 5); assertThat(multimapState.get(key).read(), Matchers.containsInAnyOrder(4, 5)); + + assertBuildable(commitBuilder); } @Test @@ -1766,6 +1798,8 @@ public void testMultimapPutAndPersist() { builder, new MultimapEntryUpdate(key1, Arrays.asList(1, 2), false), new MultimapEntryUpdate(key2, Collections.singletonList(2), false)); + + assertBuildable(commitBuilder); } @Test @@ -1799,6 +1833,8 @@ public void testMultimapRemovePutAndPersist() { builder, new MultimapEntryUpdate(key1, Arrays.asList(1, 2), true), new MultimapEntryUpdate(key2, Collections.singletonList(4), true)); + + assertBuildable(commitBuilder); } @Test @@ -1825,6 +1861,8 @@ public void testMultimapRemoveAndPersist() { builder, new MultimapEntryUpdate(key1, Collections.emptyList(), true), new MultimapEntryUpdate(key2, Collections.emptyList(), true)); + + assertBuildable(commitBuilder); } @Test @@ -1856,6 +1894,8 @@ public void testMultimapPutRemoveClearAndPersist() { 
Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); assertEquals(0, builder.getUpdatesCount()); assertTrue(builder.getDeleteAll()); + + assertBuildable(commitBuilder); } @Test @@ -1894,6 +1934,8 @@ false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); assertTagMultimapUpdates( builder, new MultimapEntryUpdate(key1, Collections.singletonList(4), false)); + + assertBuildable(commitBuilder); } @Test @@ -1938,6 +1980,8 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) ByteArrayCoder.of().decode(entryUpdate.getEntryName().newInput(), Context.OUTER); assertArrayEquals(key1, decodedKey); assertTrue(entryUpdate.getDeleteAll()); + + assertBuildable(commitBuilder); } @Test @@ -2053,6 +2097,8 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) Windmill.WorkItemCommitRequest.Builder commitBuilder = Windmill.WorkItemCommitRequest.newBuilder(); underTest.persist(commitBuilder); + + assertBuildable(commitBuilder); } @Test @@ -2253,6 +2299,8 @@ public void testOrderedListAddPersist() throws Exception { assertEquals("hello", updates.getInserts(0).getEntries(0).getValue().toStringUtf8()); assertEquals(1000, updates.getInserts(0).getEntries(0).getSortKey()); assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); + + assertBuildable(commitBuilder); } @Test @@ -2284,6 +2332,8 @@ public void testOrderedListClearPersist() throws Exception { assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); assertEquals(IdTracker.NEW_RANGE_MIN_ID + 1, updates.getInserts(0).getEntries(1).getId()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2331,6 +2381,8 @@ public void testOrderedListDeleteRangePersist() { assertEquals(4000, updates.getInserts(0).getEntries(1).getSortKey()); assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); 
assertEquals(IdTracker.NEW_RANGE_MIN_ID + 1, updates.getInserts(0).getEntries(1).getId()); + + assertBuildable(commitBuilder); } @Test @@ -2539,6 +2591,8 @@ public void testOrderedListPersistEmpty() throws Exception { assertEquals(1, updates.getDeletesCount()); assertEquals(WindmillOrderedList.MIN_TS_MICROS, updates.getDeletes(0).getRange().getStart()); assertEquals(WindmillOrderedList.MAX_TS_MICROS, updates.getDeletes(0).getRange().getLimit()); + + assertBuildable(commitBuilder); } @Test @@ -2653,6 +2707,8 @@ public void testBagAddPersist() throws Exception { assertEquals("hello", bagUpdates.getValues(0).toStringUtf8()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2678,6 +2734,8 @@ public void testBagClearPersist() throws Exception { assertEquals("world", tagBag.getValues(0).toStringUtf8()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2693,6 +2751,8 @@ public void testBagPersistEmpty() throws Exception { // 1 bag update = the clear assertEquals(1, commitBuilder.getBagUpdatesCount()); + + assertBuildable(commitBuilder); } @Test @@ -2806,6 +2866,8 @@ public void testCombiningAddPersist() throws Exception { 11, CoderUtils.decodeFromByteArray(accumCoder, bagUpdates.getValues(0).toByteArray())[0]); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2835,6 +2897,8 @@ public void testCombiningAddPersistWithCompact() throws Exception { assertTrue(bagUpdates.getDeleteAll()); assertEquals( 111, CoderUtils.decodeFromByteArray(accumCoder, bagUpdates.getValues(0).toByteArray())[0]); + + assertBuildable(commitBuilder); } @Test @@ -2862,6 +2926,8 @@ public void testCombiningClearPersist() throws Exception { 11, CoderUtils.decodeFromByteArray(accumCoder, tagBag.getValues(0).toByteArray())[0]); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2990,6 +3056,8 @@ public void 
testWatermarkPersistEarliest() throws Exception { assertEquals(TimeUnit.MILLISECONDS.toMicros(1000), watermarkHold.getTimestamps(0)); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3016,6 +3084,8 @@ public void testWatermarkPersistLatestEmpty() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3042,6 +3112,8 @@ public void testWatermarkPersistLatestWindmillWins() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3068,6 +3140,8 @@ public void testWatermarkPersistLatestLocalAdditionsWin() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3091,6 +3165,8 @@ public void testWatermarkPersistEndOfWindow() throws Exception { // Blind adds should not need to read the future. Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3116,6 +3192,8 @@ public void testWatermarkClearPersist() throws Exception { assertEquals(TimeUnit.MILLISECONDS.toMicros(1000), clearAndUpdate.getTimestamps(0)); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3133,6 +3211,8 @@ public void testWatermarkPersistEmpty() throws Exception { // 1 bag update corresponds to deletion. There shouldn't be a bag update adding items. 
assertEquals(1, commitBuilder.getWatermarkHoldsCount()); + + assertBuildable(commitBuilder); } @Test @@ -3200,6 +3280,8 @@ public void testValueSetPersist() throws Exception { assertTrue(valueUpdate.isInitialized()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3220,6 +3302,8 @@ public void testValueClearPersist() throws Exception { assertEquals(0, valueUpdate.getValue().getData().size()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3234,6 +3318,8 @@ public void testValueNoChangePersist() throws Exception { assertEquals(0, commitBuilder.getValueUpdatesCount()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java index af3a3e8295bb..19e05efb50c6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java @@ -31,7 +31,7 @@ public final class FakeWindmillStubFactory implements ChannelCachingStubFactory private final ChannelCache channelCache; public FakeWindmillStubFactory(Supplier channelFactory) { - this.channelCache = ChannelCache.create(ignored -> channelFactory.get()); + this.channelCache = ChannelCache.forTesting(ignored -> channelFactory.get(), () -> {}); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java index 3cda4559c100..c76d5a584184 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -17,9 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -40,169 +38,79 @@ public class EvenGetWorkBudgetDistributorTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { - return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); - } + private static GetWorkBudgetSpender createGetWorkBudgetOwner() { + // Lambdas are final and cannot be spied. 
+ return spy( + new GetWorkBudgetSpender() { - private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { - return createBudgetDistributor( - GetWorkBudget.builder() - .setItems(activeWorkItemsAndBytes) - .setBytes(activeWorkItemsAndBytes) - .build()); + @Override + public void setBudget(long items, long bytes) {} + }); } @Test public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { - createBudgetDistributor(1L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); } @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())); - createBudgetDistributor(1L) + GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner(); + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget(ImmutableList.of(getWorkBudgetSpender), GetWorkBudget.noBudget()); verifyNoInteractions(getWorkBudgetSpender); } @Test - public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(10L).setBytes(10L).build())); - createBudgetDistributor(0L) - .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(5L).setBytes(5L).build())); - createBudgetDistributor(10L) + public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { + int 
totalStreams = 10; + long totalItems = 10L; + long totalBytes = 100L; + List streams = new ArrayList<>(); + for (int i = 0; i < totalStreams; i++) { + streams.add(createGetWorkBudgetOwner()); + } + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(20L).setBytes(20L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq( - totalGetWorkBudget.items() - - streamRemainingBudget.items() - - activeWorkItemsAndBytes), - eq(totalGetWorkBudget.bytes() - 
streamRemainingBudget.bytes())); - } - - @Test - public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); + ImmutableList.copyOf(streams), + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq( - totalGetWorkBudget.bytes() - - streamRemainingBudget.bytes() - - activeWorkItemsAndBytes)); + streams.forEach( + stream -> + verify(stream, times(1)) + .setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build()))); } @Test - public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { - long totalItemsAndBytes = 10L; + public void 
testDistributeBudget_distributesFairlyWhenNotEven() { + long totalItems = 10L; + long totalBytes = 19L; List streams = new ArrayList<>(); - for (int i = 0; i < totalItemsAndBytes; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < 3; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), - GetWorkBudget.builder() - .setItems(totalItemsAndBytes) - .setBytes(totalItemsAndBytes) - .build()); + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + .setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build()))); } @Test - public void testDistributeBudget_distributesFairlyWhenNotEven() { + public void testDistributeBudget_distributesBudgetEvenly() { long totalItemsAndBytes = 10L; List streams = new ArrayList<>(); - for (int i = 0; i < 3; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < totalItemsAndBytes; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), GetWorkBudget.builder() @@ -210,24 +118,10 @@ public void testDistributeBudget_distributesFairlyWhenNotEven() { .setBytes(totalItemsAndBytes) .build()); - long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0)); + long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); - } - - private GetWorkBudgetSpender 
createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget getWorkBudget) { - return spy( - new GetWorkBudgetSpender() { - @Override - public void adjustBudget(long itemsDelta, long bytesDelta) {} - - @Override - public GetWorkBudget remainingBudget() { - return getWorkBudget; - } - }); + .setBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java index ed915088d0a6..acbb3aebbcf5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java @@ -39,7 +39,7 @@ public class StreamPoolHeartbeatSenderTest { public void sendsHeartbeatsOnStream() { FakeWindmillServer server = new FakeWindmillServer(new ErrorCollector(), c -> Optional.empty()); StreamPoolHeartbeatSender heartbeatSender = - StreamPoolHeartbeatSender.Create( + StreamPoolHeartbeatSender.create( WindmillStreamPool.create(1, Duration.standardSeconds(10), server::getDataStream)); Heartbeats.Builder heartbeatsBuilder = Heartbeats.builder(); heartbeatsBuilder @@ -59,7 +59,7 @@ public void sendsHeartbeatsOnDedicatedStream() { FakeGlobalConfigHandle configHandle = new FakeGlobalConfigHandle(getGlobalConfig(/*useSeparateHeartbeatStreams=*/ true)); StreamPoolHeartbeatSender heartbeatSender = - StreamPoolHeartbeatSender.Create( + StreamPoolHeartbeatSender.create( WindmillStreamPool.create( 1, Duration.standardSeconds(10), dedicatedServer::getDataStream), WindmillStreamPool.create( @@ -104,7 +104,7 @@ public void 
sendsHeartbeatsOnGetDataStream() { FakeGlobalConfigHandle configHandle = new FakeGlobalConfigHandle(getGlobalConfig(/*useSeparateHeartbeatStreams=*/ false)); StreamPoolHeartbeatSender heartbeatSender = - StreamPoolHeartbeatSender.Create( + StreamPoolHeartbeatSender.create( WindmillStreamPool.create( 1, Duration.standardSeconds(10), dedicatedServer::getDataStream), WindmillStreamPool.create( diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java index 72fa991c1f73..470692e75103 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java @@ -140,6 +140,8 @@ public RemoteEnvironment createEnvironment(Environment environment, String worke try { fnHarness.get(); } catch (Throwable t) { + // Print stacktrace to stderr. 
Could be useful if underlying error not surfaced earlier + t.printStackTrace(); executor.shutdownNow(); } }); diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java index 874748d7b975..49120d38f1f1 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java @@ -248,7 +248,7 @@ public void launchSdkHarness(PipelineOptions options) throws Exception { } }); InstructionRequestHandler controlClient = - clientPool.getSource().take(WORKER_ID, java.time.Duration.ofSeconds(2)); + clientPool.getSource().take(WORKER_ID, java.time.Duration.ofSeconds(10)); this.controlClient = SdkHarnessClient.usingFnApiClient(controlClient, dataServer.getService()); } diff --git a/runners/portability/java/build.gradle b/runners/portability/java/build.gradle index b684299c3174..a82759b4e4a0 100644 --- a/runners/portability/java/build.gradle +++ b/runners/portability/java/build.gradle @@ -253,7 +253,7 @@ tasks.register("validatesRunnerSickbay", Test) { } task ulrDockerValidatesRunner { - dependsOn createUlrValidatesRunnerTask("ulrDockerValidatesRunnerTests", "DOCKER", ":sdks:java:container:java8:docker") + dependsOn createUlrValidatesRunnerTask("ulrDockerValidatesRunnerTests", "DOCKER", ":sdks:java:container:${project.ext.currentJavaVersion}:docker") } task ulrLoopbackValidatesRunner { diff --git a/runners/prism/build.gradle b/runners/prism/build.gradle index 711a1aa2dd75..1009b9856e71 100644 --- a/runners/prism/build.gradle +++ b/runners/prism/build.gradle @@ -42,6 +42,9 @@ ext.set('buildTarget', buildTarget) def buildTask = tasks.named("build") { // goPrepare is a task registered in applyGoNature. 
dependsOn("goPrepare") + // Allow Go to manage the caching, not gradle. + outputs.cacheIf { false } + outputs.upToDateWhen { false } doLast { exec { workingDir = modDir diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index de9a30ad8189..82eb62b9e207 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -16,6 +16,8 @@ * limitations under the License. */ +import groovy.json.JsonOutput + plugins { id 'org.apache.beam.module' } applyJavaNature( @@ -43,3 +45,235 @@ tasks.test { var prismBuildTask = dependsOn(':runners:prism:build') systemProperty 'prism.buildTarget', prismBuildTask.project.property('buildTarget').toString() } + +// Below is configuration to support running the Java Validates Runner tests. + +configurations { + validatesRunner +} + +dependencies { + implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation library.java.hamcrest + permitUnusedDeclared library.java.hamcrest + implementation library.java.joda_time + implementation library.java.slf4j_api + implementation library.java.vendored_guava_32_1_2_jre + + testImplementation library.java.hamcrest + testImplementation library.java.junit + testImplementation library.java.mockito_core + testImplementation library.java.slf4j_jdk14 + + validatesRunner project(path: ":sdks:java:core", configuration: "shadowTest") + validatesRunner project(path: ":runners:core-java", configuration: "testRuntimeMigration") + validatesRunner project(path: project.path, configuration: "testRuntimeMigration") +} + +project.evaluationDependsOn(":sdks:java:core") +project.evaluationDependsOn(":runners:core-java") + +def sickbayTests = [ + // PortableMetrics doesn't implement "getCommittedOrNull" from Metrics + // Preventing Prism from passing these tests. + // In particular, it doesn't subclass MetricResult with an override, and + // it explicitly passes "false" to committed supported in create.
+ // + // There is not currently a category for excluding these _only_ in committed mode + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testAllCommittedMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedCounterMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedDistributionMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedStringSetMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedGaugeMetrics', + + // Triggers / Accumulation modes not yet implemented in prism. + // https://github.com/apache/beam/issues/31438 + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testGlobalCombineWithDefaultsAndTriggers', + 'org.apache.beam.sdk.transforms.CombineTest$BasicTests.testHotKeyCombiningWithAccumulationMode', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testNoWindowFnDoesNotReassignWindows', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testCombiningAccumulatingProcessingTime', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerEarly', + 'org.apache.beam.sdk.transforms.ParDoTest$BundleInvariantsTests.testWatermarkUpdateMidBundle', + 'org.apache.beam.sdk.transforms.ViewTest.testTriggeredLatestSingleton', + // Requires Allowed Lateness, among others. 
+ 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerSetWithinAllowedLateness', + 'org.apache.beam.sdk.testing.TestStreamTest.testFirstElementLate', + 'org.apache.beam.sdk.testing.TestStreamTest.testDiscardingMode', + 'org.apache.beam.sdk.testing.TestStreamTest.testEarlyPanesOfWindow', + 'org.apache.beam.sdk.testing.TestStreamTest.testElementsAtAlmostPositiveInfinity', + 'org.apache.beam.sdk.testing.TestStreamTest.testLateDataAccumulating', + 'org.apache.beam.sdk.testing.TestStreamTest.testMultipleStreams', + 'org.apache.beam.sdk.testing.TestStreamTest.testProcessingTimeTrigger', + + // Coding error somehow: short write: reached end of stream after reading 5 bytes; 98 bytes expected + 'org.apache.beam.sdk.testing.TestStreamTest.testMultiStage', + + // Prism not firing sessions correctly (seems to be merging inappropriately) + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testSessionsCombine', + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testSessionsCombineWithContext', + + // Java side dying during execution. + // https://github.com/apache/beam/issues/32930 + 'org.apache.beam.sdk.transforms.FlattenTest.testFlattenMultipleCoders', + // Stream corruption error java side: failed:java.io.StreamCorruptedException: invalid stream header: 206E6F74 + // Likely due to prism's coder changes. + 'org.apache.beam.sdk.transforms.FlattenTest.testFlattenWithDifferentInputAndOutputCoders2', + + // java.lang.IllegalStateException: Output with tag Tag must have a schema in order to call getRowReceiver + // Ultimately because getRowReceiver code path SDK side isn't friendly to LengthPrefix wrapping of row coders.
+ // https://github.com/apache/beam/issues/32931 + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWrite', + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWriteMultiOutput', + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWriteWithSchemaRegistry', + + // Technically these tests "succeed" + // the test is just complaining that an AssertionException isn't a RuntimeException + // + // java.lang.RuntimeException: test error in finalize + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInFinishBatch', + // java.lang.RuntimeException: test error in process + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInProcessElement', + // java.lang.RuntimeException: test error in initialize + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInStartBatch', + + // Only known window fns supported, not general window merging + // Custom window fns not yet implemented in prism. + // https://github.com/apache/beam/issues/31921 + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindows', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindowsKeyedCollection', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindowsWithoutCustomWindowTypes', + 'org.apache.beam.sdk.transforms.windowing.WindowingTest.testMergingWindowing', + 'org.apache.beam.sdk.transforms.windowing.WindowingTest.testNonPartitioningWindowing', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$WindowTests.testGroupByKeyMergingWindows', + + // Possibly a different error being hidden behind the main error. 
+ // org.apache.beam.sdk.util.WindowedValue$ValueInGlobalWindow cannot be cast to class java.lang.String + // TODO(https://github.com/apache/beam/issues/29973) + 'org.apache.beam.sdk.transforms.ReshuffleTest.testReshufflePreservesMetadata', + // TODO(https://github.com/apache/beam/issues/31231) + 'org.apache.beam.sdk.transforms.RedistributeTest.testRedistributePreservesMetadata', + + // Prism isn't handling Java's side input views properly. + // https://github.com/apache/beam/issues/32932 + // java.lang.IllegalArgumentException: PCollection with more than one element accessed as a singleton view. + // Consider using Combine.globally().asSingleton() to combine the PCollection into a single value + 'org.apache.beam.sdk.transforms.ViewTest.testDiscardingNonSingletonSideInput', + // java.util.NoSuchElementException: Empty PCollection accessed as a singleton view. + 'org.apache.beam.sdk.transforms.ViewTest.testDiscardingNonSingletonSideInput', + // java.lang.IllegalArgumentException: Duplicate values for a + 'org.apache.beam.sdk.transforms.ViewTest.testMapSideInputWithNullValuesCatchesDuplicates', + // java.lang.IllegalArgumentException: PCollection with more than one element accessed as a singleton view.... + 'org.apache.beam.sdk.transforms.ViewTest.testNonSingletonSideInput', + // java.util.NoSuchElementException: Empty PCollection accessed as a singleton view. + 'org.apache.beam.sdk.transforms.ViewTest.testEmptySingletonSideInput', + // Prism side encoding error.
+ // java.lang.IllegalStateException: java.io.EOFException + 'org.apache.beam.sdk.transforms.ViewTest.testSideInputWithNestedIterables', + + // Requires Time Sorted Input + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInput', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithTestStream', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithLateDataAndAllowedLateness', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testTwoRequiresTimeSortedInputWithLateData', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithLateData', + + // Missing output due to processing time timer skew. + 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testProcessElementSkew', + + // TestStream + BundleFinalization. + // Tests seem to assume individual element bundles from test stream, but prism will aggregate them, preventing + // a subsequent firing. Tests ultimately hang until timeout. + // Either a test problem, or a misunderstanding of how test stream must work problem in prism. + // Biased to test problem, due to how they are constructed. + 'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalization', + 'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalizationWithSideInputs', + + // Filtered by PortableRunner tests. 
+ // Teardown not called in exceptions + // https://github.com/apache/beam/issues/20372 + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful', +] + +/** + * Runs Java ValidatesRunner tests against the Prism Runner + * with the specified environment type. 
+ */ +def createPrismValidatesRunnerTask = { name, environmentType -> + Task vrTask = tasks.create(name: name, type: Test, group: "Verification") { + description "PrismRunner Java $environmentType ValidatesRunner suite" + classpath = configurations.validatesRunner + + var prismBuildTask = dependsOn(':runners:prism:build') + systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ + "--runner=TestPrismRunner", + "--experiments=beam_fn_api", + "--defaultEnvironmentType=${environmentType}", + "--prismLogLevel=warn", + "--prismLocation=${prismBuildTask.project.property('buildTarget').toString()}", + "--enableWebUI=false", + ]) + testClassesDirs = files(project(":sdks:java:core").sourceSets.test.output.classesDirs) + useJUnit { + includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + // Should be run only in a properly configured SDK harness environment + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' + excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' + + // Not yet implemented in Prism + // https://github.com/apache/beam/issues/32211 + excludeCategories 'org.apache.beam.sdk.testing.UsesOnWindowExpiration' + // https://github.com/apache/beam/issues/32929 + excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState' + + // Not supported in Portable Java SDK yet. 
+ // https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState + excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' + } + filter { + for (String test : sickbayTests) { + excludeTestsMatching test + } + } + } + return vrTask +} + +tasks.register("validatesRunnerSickbay", Test) { + group = "Verification" + description "Validates Prism local runner (Sickbay Tests)" + + var prismBuildTask = dependsOn(':runners:prism:build') + systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ + "--runner=TestPrismRunner", + "--experiments=beam_fn_api", + "--enableWebUI=false", + "--prismLogLevel=warn", + "--prismLocation=${prismBuildTask.project.property('buildTarget').toString()}" + ]) + + classpath = configurations.validatesRunner + testClassesDirs = files(project(":sdks:java:core").sourceSets.test.output.classesDirs) + + filter { + for (String test : sickbayTests) { + includeTestsMatching test + } + } +} + +task prismDockerValidatesRunner { + Task vrTask = createPrismValidatesRunnerTask("prismDockerValidatesRunnerTests", "DOCKER") + vrTask.dependsOn ":sdks:java:container:${project.ext.currentJavaVersion}:docker" +} + +task prismLoopbackValidatesRunner { + dependsOn createPrismValidatesRunnerTask("prismLoopbackValidatesRunnerTests", "LOOPBACK") +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java index fda5db923a7f..111d937fcbf6 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java @@ -31,6 +31,8 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import 
org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.slf4j.Logger; @@ -48,6 +50,7 @@ abstract class PrismExecutor { static final String IDLE_SHUTDOWN_TIMEOUT = "-idle_shutdown_timeout=%s"; static final String JOB_PORT_FLAG_TEMPLATE = "-job_port=%s"; static final String SERVE_HTTP_FLAG_TEMPLATE = "-serve_http=%s"; + static final String LOG_LEVEL_FLAG_TEMPLATE = "-log_level=%s"; protected @MonotonicNonNull Process process; protected ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -157,6 +160,16 @@ abstract static class Builder { abstract Builder setArguments(List arguments); + Builder addArguments(List arguments) { + Optional> original = getArguments(); + if (!original.isPresent()) { + return this.setArguments(arguments); + } + List newArguments = + Stream.concat(original.get().stream(), arguments.stream()).collect(Collectors.toList()); + return this.setArguments(newArguments); + } + abstract Optional> getArguments(); abstract PrismExecutor autoBuild(); diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java index 27aea3f64df0..b32f03e78e6a 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java @@ -110,6 +110,12 @@ String resolveSource() { String resolve() throws IOException { String from = resolveSource(); + // If the location is set, and it's not an http request or a zip, + // use the binary directly. 
+ if (!from.startsWith("http") && !from.endsWith("zip") && Files.exists(Paths.get(from))) { + return from; + } + String fromFileName = getNameWithoutExtension(from); Path to = Paths.get(userHome(), PRISM_BIN_PATH, fromFileName); diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java index 9b280d0a70d4..ceec1ad8268a 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java @@ -59,4 +59,10 @@ public interface PrismPipelineOptions extends PortablePipelineOptions { String getIdleShutdownTimeout(); void setIdleShutdownTimeout(String idleShutdownTimeout); + + @Description("Sets the log level for Prism. Can be set to 'debug', 'info', 'warn', or 'error'.") + @Default.String("warn") + String getPrismLogLevel(); + + void setPrismLogLevel(String prismLogLevel); } diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java index 6099db4b63ee..ac1e68237faf 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java @@ -101,13 +101,17 @@ PrismExecutor startPrism() throws IOException { String idleShutdownTimeoutFlag = String.format( PrismExecutor.IDLE_SHUTDOWN_TIMEOUT, prismPipelineOptions.getIdleShutdownTimeout()); + String logLevelFlag = + String.format( + PrismExecutor.LOG_LEVEL_FLAG_TEMPLATE, prismPipelineOptions.getPrismLogLevel()); String endpoint = "localhost:" + port; prismPipelineOptions.setJobEndpoint(endpoint); String command = locator.resolve(); PrismExecutor executor = PrismExecutor.builder() .setCommand(command) - .setArguments(Arrays.asList(portFlag, 
serveHttpFlag, idleShutdownTimeoutFlag)) + .setArguments( + Arrays.asList(portFlag, serveHttpFlag, idleShutdownTimeoutFlag, logLevelFlag)) .build(); executor.execute(); checkState(executor.isAlive()); diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java index eb497f0a4c43..a81e3e24ee69 100644 --- a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java @@ -59,7 +59,7 @@ public void executeWithStreamRedirectThenStop() throws IOException { sleep(3000L); executor.stop(); String output = outputStream.toString(StandardCharsets.UTF_8.name()); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + assertThat(output).contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:8073"); } @Test @@ -71,7 +71,8 @@ public void executeWithFileOutputThenStop() throws IOException { executor.stop(); try (Stream stream = Files.lines(log.toPath(), StandardCharsets.UTF_8)) { String output = stream.collect(Collectors.joining("\n")); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + assertThat(output) + .contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:8073"); } } @@ -79,21 +80,23 @@ public void executeWithFileOutputThenStop() throws IOException { public void executeWithCustomArgumentsThenStop() throws IOException { PrismExecutor executor = underTest() - .setArguments(Collections.singletonList("-" + JOB_PORT_FLAG_NAME + "=5555")) + .addArguments(Collections.singletonList("-" + JOB_PORT_FLAG_NAME + "=5555")) .build(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); executor.execute(outputStream); sleep(3000L); executor.stop(); String output = outputStream.toString(StandardCharsets.UTF_8.name()); - 
assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:5555"); + assertThat(output).contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:5555"); } @Test public void executeWithPortFinderThenStop() throws IOException {} private PrismExecutor.Builder underTest() { - return PrismExecutor.builder().setCommand(getLocalPrismBuildOrIgnoreTest()); + return PrismExecutor.builder() + .setCommand(getLocalPrismBuildOrIgnoreTest()) + .setArguments(Collections.singletonList("--log_kind=text")); // disable color control chars } private void sleep(long millis) { diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java index 095d3c9bde61..fa5ba6d37203 100644 --- a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java @@ -134,7 +134,10 @@ public void givenFilePrismLocationOption_thenResolves() throws IOException { PrismLocator underTest = new PrismLocator(options); String got = underTest.resolve(); - assertThat(got).contains(DESTINATION_DIRECTORY.toString()); + // Local file overrides should use the local binary in place, not copy + // to the cache. Doing so prevents using a locally built version. 
+ assertThat(got).doesNotContain(DESTINATION_DIRECTORY.toString()); + assertThat(got).contains(options.getPrismLocation()); Path gotPath = Paths.get(got); assertThat(Files.exists(gotPath)).isTrue(); } diff --git a/runners/samza/job-server/build.gradle b/runners/samza/job-server/build.gradle index f972f376e5c8..6fc8db98a4f9 100644 --- a/runners/samza/job-server/build.gradle +++ b/runners/samza/job-server/build.gradle @@ -90,7 +90,6 @@ def portableValidatesRunnerTask(String name, boolean docker) { excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' - excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' excludeCategories 'org.apache.beam.sdk.testing.UsesSetState' @@ -127,6 +126,8 @@ def portableValidatesRunnerTask(String name, boolean docker) { excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testParDoShiftTimestampInvalid' // TODO(https://github.com/apache/beam/issues/21144) excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testParDoShiftTimestampInvalidZeroAllowed' + // TODO(https://github.com/apache/beam/issues/32520) + excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionIn*Stateful' // TODO(https://github.com/apache/beam/issues/21145) excludeTestsMatching 'org.apache.beam.sdk.transforms.DeduplicateTest.testEventTime' // TODO(https://github.com/apache/beam/issues/21146) diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java index e0bcbed1577c..68674d202cdb 100644 --- 
a/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -59,6 +60,8 @@ public class TestSamzaRunnerWithTransformMetrics { @Test public void testSamzaRunnerWithDefaultMetrics() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); SamzaPipelineOptions options = PipelineOptionsFactory.create().as(SamzaPipelineOptions.class); InMemoryMetricsReporter inMemoryMetricsReporter = new InMemoryMetricsReporter(); options.setMetricsReporters(ImmutableList.of(inMemoryMetricsReporter)); diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java index 8670d9a46eac..73454cc95421 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.samza.runtime; +import static org.junit.Assume.assumeTrue; + import java.io.Serializable; import java.util.Arrays; import org.apache.beam.sdk.coders.KvCoder; @@ -35,11 +37,19 @@ import org.apache.beam.sdk.values.TimestampedValue; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; /** Tests for GroupByKeyOp. 
*/ public class GroupByKeyOpTest implements Serializable { + + @BeforeClass + public static void beforeClass() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); + } + @Rule public final transient TestPipeline pipeline = TestPipeline.fromOptions( diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java index 004162600179..9409efbcf394 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import java.io.File; import java.io.IOException; @@ -75,6 +76,7 @@ import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStorageEngineFactory; import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStore; import org.apache.samza.system.SystemStreamPartition; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -91,6 +93,12 @@ public class SamzaStoreStateInternalsTest implements Serializable { TestPipeline.fromOptions( PipelineOptionsFactory.fromArgs("--runner=TestSamzaRunner").create()); + @BeforeClass + public static void beforeClass() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); + } + @Test public void testMapStateIterator() { final String stateId = "foo"; diff --git a/runners/spark/job-server/container/Dockerfile b/runners/spark/job-server/container/Dockerfile index ec4a123f2b9d..f5639430a33b 100644 --- a/runners/spark/job-server/container/Dockerfile +++ 
b/runners/spark/job-server/container/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. ############################################################################### -FROM openjdk:8 +FROM eclipse-temurin:11 MAINTAINER "Apache Beam " RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libltdl7 diff --git a/runners/spark/job-server/spark_job_server.gradle b/runners/spark/job-server/spark_job_server.gradle index 6d2d4b2bafbf..90109598ed64 100644 --- a/runners/spark/job-server/spark_job_server.gradle +++ b/runners/spark/job-server/spark_job_server.gradle @@ -118,7 +118,6 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean docker, excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesPerKeyOrderedDelivery' - excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' excludeCategories 'org.apache.beam.sdk.testing.UsesSetState' excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState' @@ -187,7 +186,6 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean docker, excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesPerKeyOrderedDelivery' excludeCategories 'org.apache.beam.sdk.testing.UsesPerKeyOrderInBundle' - excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle' excludeCategories 'org.apache.beam.sdk.testing.UsesMapState' excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' excludeCategories 'org.apache.beam.sdk.testing.UsesSetState' @@ -303,3 +301,7 @@ createCrossLanguageValidatesRunnerTask( "--endpoint localhost:${jobPort}", ], ) + +shadowJar { + outputs.upToDateWhen { false } +} diff --git a/scripts/ci/pr-bot/findPrsNeedingAttention.ts b/scripts/ci/pr-bot/findPrsNeedingAttention.ts index 
621e89b2d87d..095be5ba762f 100644 --- a/scripts/ci/pr-bot/findPrsNeedingAttention.ts +++ b/scripts/ci/pr-bot/findPrsNeedingAttention.ts @@ -114,13 +114,16 @@ async function assignToNewReviewers( reviewersToExclude.push(pull.user.login); const reviewersForLabels: { [key: string]: string[] } = reviewerConfig.getReviewersForLabels(labelObjects, reviewersToExclude); + const fallbackReviewers = + reviewerConfig.getFallbackReviewers(); for (const labelObject of labelObjects) { const label = labelObject.name; let availableReviewers = reviewersForLabels[label]; if (availableReviewers && availableReviewers.length > 0) { let reviewersState = await stateClient.getReviewersForLabelState(label); let chosenReviewer = await reviewersState.assignNextCommitter( - availableReviewers + availableReviewers, + fallbackReviewers ); reviewerStateToUpdate[label] = reviewersState; prState.reviewersAssignedForLabels[label] = chosenReviewer; diff --git a/scripts/ci/pr-bot/processNewPrs.ts b/scripts/ci/pr-bot/processNewPrs.ts index db723e5623fa..5aa89e9cce4c 100644 --- a/scripts/ci/pr-bot/processNewPrs.ts +++ b/scripts/ci/pr-bot/processNewPrs.ts @@ -217,8 +217,11 @@ async function processPull( ); const availableReviewers = reviewerConfig.getReviewersForLabel(labelOfReviewer); + const fallbackReviewers = + reviewerConfig.getFallbackReviewers(); const chosenCommitter = await reviewersState.assignNextCommitter( - availableReviewers + availableReviewers, + fallbackReviewers ); prState.reviewersAssignedForLabels[labelOfReviewer] = chosenCommitter; prState.committerAssigned = true; diff --git a/scripts/ci/pr-bot/shared/reviewersForLabel.ts b/scripts/ci/pr-bot/shared/reviewersForLabel.ts index 971f3f1cd7a5..f46f5359a8b9 100644 --- a/scripts/ci/pr-bot/shared/reviewersForLabel.ts +++ b/scripts/ci/pr-bot/shared/reviewersForLabel.ts @@ -75,7 +75,10 @@ export class ReviewersForLabel { // Given the up to date list of available reviewers (excluding the author), // returns the next reviewer up based on who 
has reviewed least recently. // Updates this object to reflect their assignment. - async assignNextCommitter(availableReviewers: string[]): Promise { + async assignNextCommitter( + availableReviewers: string[], + fallbackReviewers: string[] + ): Promise { let earliestDate = Date.now(); let earliestCommitter: string = ""; @@ -94,7 +97,9 @@ export class ReviewersForLabel { } if (!earliestCommitter) { - throw new Error(`No committers available for label ${this.label}`); + console.log(`No committers available for label ${this.label}`); + console.log(`Using fallbackReviewers label instead of ${this.label}`); + return this.assignNextCommitter(fallbackReviewers, fallbackReviewers); } this.dateOfLastReviewAssignment[earliestCommitter] = Date.now(); return earliestCommitter; diff --git a/sdks/go.mod b/sdks/go.mod index c5d8822b2133..5406c2b70cbc 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -20,22 +20,22 @@ // directory. module github.com/apache/beam/sdks/v2 -go 1.21 +go 1.21.0 require ( - cloud.google.com/go/bigquery v1.63.0 + cloud.google.com/go/bigquery v1.64.0 cloud.google.com/go/bigtable v1.33.0 - cloud.google.com/go/datastore v1.19.0 + cloud.google.com/go/datastore v1.20.0 cloud.google.com/go/profiler v0.4.1 - cloud.google.com/go/pubsub v1.43.0 - cloud.google.com/go/spanner v1.67.0 - cloud.google.com/go/storage v1.43.0 - github.com/aws/aws-sdk-go-v2 v1.31.0 - github.com/aws/aws-sdk-go-v2/config v1.27.39 - github.com/aws/aws-sdk-go-v2/credentials v1.17.37 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.25 - github.com/aws/aws-sdk-go-v2/service/s3 v1.63.3 - github.com/aws/smithy-go v1.21.0 + cloud.google.com/go/pubsub v1.45.3 + cloud.google.com/go/spanner v1.73.0 + cloud.google.com/go/storage v1.47.0 + github.com/aws/aws-sdk-go-v2 v1.32.6 + github.com/aws/aws-sdk-go-v2/config v1.28.6 + github.com/aws/aws-sdk-go-v2/credentials v1.17.47 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 + github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 + 
github.com/aws/smithy-go v1.22.1 github.com/docker/go-connections v0.5.0 github.com/dustin/go-humanize v1.0.1 github.com/go-sql-driver/mysql v1.8.1 @@ -44,24 +44,24 @@ require ( github.com/johannesboyne/gofakes3 v0.0.0-20221110173912-32fb85c5aed6 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 - github.com/nats-io/nats-server/v2 v2.10.18 + github.com/nats-io/nats-server/v2 v2.10.23 github.com/nats-io/nats.go v1.37.0 github.com/proullon/ramsql v0.1.4 github.com/spf13/cobra v1.8.1 github.com/testcontainers/testcontainers-go v0.33.0 - github.com/tetratelabs/wazero v1.8.0 + github.com/tetratelabs/wazero v1.8.1 github.com/xitongsys/parquet-go v1.6.2 github.com/xitongsys/parquet-go-source v0.0.0-20220315005136-aec0fe3e777c go.mongodb.org/mongo-driver v1.17.1 - golang.org/x/net v0.29.0 - golang.org/x/oauth2 v0.23.0 - golang.org/x/sync v0.8.0 - golang.org/x/sys v0.25.0 - golang.org/x/text v0.18.0 - google.golang.org/api v0.199.0 - google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 - google.golang.org/grpc v1.67.0 - google.golang.org/protobuf v1.34.2 + golang.org/x/net v0.32.0 + golang.org/x/oauth2 v0.24.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.28.0 + golang.org/x/text v0.21.0 + google.golang.org/api v0.210.0 + google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 + google.golang.org/grpc v1.67.1 + google.golang.org/protobuf v1.35.2 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -69,18 +69,22 @@ require ( require ( github.com/avast/retry-go/v4 v4.6.0 github.com/fsouza/fake-gcs-server v1.49.2 + github.com/golang-cz/devslog v0.0.11 golang.org/x/exp v0.0.0-20231006140011-7918f672742d ) require ( - cel.dev/expr v0.16.0 // indirect - cloud.google.com/go/auth v0.9.5 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect - cloud.google.com/go/monitoring v1.21.1 // indirect + cel.dev/expr v0.16.1 // indirect + cloud.google.com/go/auth v0.11.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect + 
cloud.google.com/go/monitoring v1.21.2 // indirect dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 // indirect github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.24.1 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect github.com/apache/arrow/go/v15 v15.0.2 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect @@ -94,7 +98,7 @@ require ( github.com/moby/sys/user v0.1.0 // indirect github.com/moby/sys/userns v0.1.0 // indirect github.com/nats-io/jwt/v2 v2.5.8 // indirect - github.com/nats-io/nkeys v0.4.7 // indirect + github.com/nats-io/nkeys v0.4.8 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect @@ -104,6 +108,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect go.einride.tech/aip v0.68.0 // indirect + go.opentelemetry.io/contrib/detectors/gcp v1.29.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect go.opentelemetry.io/otel v1.29.0 // indirect @@ -112,36 +117,37 @@ require ( go.opentelemetry.io/otel/sdk v1.29.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.29.0 // indirect go.opentelemetry.io/otel/trace v1.29.0 // indirect - golang.org/x/time v0.6.0 // indirect + golang.org/x/time v0.8.0 // indirect + google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a // indirect ) 
require ( - cloud.google.com/go v0.115.1 // indirect + cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/compute/metadata v0.5.2 // indirect - cloud.google.com/go/iam v1.2.1 // indirect - cloud.google.com/go/longrunning v0.6.1 // indirect + cloud.google.com/go/iam v1.2.2 // indirect + cloud.google.com/go/longrunning v0.6.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/apache/arrow/go/arrow v0.0.0-20200730104253-651201b0f516 // indirect github.com/apache/thrift v0.17.0 // indirect github.com/aws/aws-sdk-go v1.34.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.18 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.20 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.18 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.25 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding 
v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cncf/xds/go v0.0.0-20240822171458-6449f94b4d59 // indirect + github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect github.com/docker/docker v27.3.1+incompatible // but required to resolve issue docker has with go1.20 github.com/docker/go-units v0.5.0 // indirect @@ -157,11 +163,11 @@ require ( github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect - github.com/googleapis/gax-go/v2 v2.13.0 // indirect + github.com/googleapis/gax-go/v2 v2.14.0 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -184,10 +190,10 @@ require ( github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect go.opencensus.io v0.24.0 // indirect - golang.org/x/crypto v0.27.0 // indirect + golang.org/x/crypto v0.30.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/tools v0.24.0 // 
indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241113202542-65e8d215514f // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697 // indirect ) diff --git a/sdks/go.sum b/sdks/go.sum index 5bf71f7ab771..bb96c54af087 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.16.0 h1:yloc84fytn4zmJX2GU3TkXGsaieaV7dQ057Qs4sIG2Y= -cel.dev/expr v0.16.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= +cel.dev/expr v0.16.1 h1:NR0+oFYzR1CqLFhTAqg3ql59G9VfN8fKq1TCHJ6gq1g= +cel.dev/expr v0.16.1/go.mod h1:AsGA5zb3WruAEQeQng1RZdGEXmBj0jvMWh6l5SnNuC8= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -38,8 +38,8 @@ cloud.google.com/go v0.104.0/go.mod h1:OO6xxXdJyvuJPcEPBLN9BJPD+jep5G1+2U5B5gkRY cloud.google.com/go v0.105.0/go.mod h1:PrLgOJNe5nfE9UMxKxgXj4mD3voiP+YQ6gdt6KMFOKM= cloud.google.com/go v0.107.0/go.mod h1:wpc2eNrD7hXUTy8EKS10jkxpZBjASrORK7goS+3YX2I= cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= -cloud.google.com/go v0.115.1 h1:Jo0SM9cQnSkYfp44+v+NQXHpcHqlnRJk2qxh6yvxxxQ= -cloud.google.com/go v0.115.1/go.mod h1:DuujITeaufu3gL68/lOFIirVNJwQeyf5UXyi+Wbgknc= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= cloud.google.com/go/accessapproval v1.4.0/go.mod h1:zybIuC3KpDOvotz59lFe5qxRZx6C75OtwbisN56xYB4= cloud.google.com/go/accessapproval v1.5.0/go.mod h1:HFy3tuiGvMdcd/u+Cu5b9NkO1pEICJ46IR82PoUdplw= 
cloud.google.com/go/accessapproval v1.6.0/go.mod h1:R0EiYnwV5fsRFiKZkPHr6mwyk2wxUJ30nL4j2pcFY2E= @@ -101,10 +101,10 @@ cloud.google.com/go/assuredworkloads v1.7.0/go.mod h1:z/736/oNmtGAyU47reJgGN+KVo cloud.google.com/go/assuredworkloads v1.8.0/go.mod h1:AsX2cqyNCOvEQC8RMPnoc0yEarXQk6WEKkxYfL6kGIo= cloud.google.com/go/assuredworkloads v1.9.0/go.mod h1:kFuI1P78bplYtT77Tb1hi0FMxM0vVpRC7VVoJC3ZoT0= cloud.google.com/go/assuredworkloads v1.10.0/go.mod h1:kwdUQuXcedVdsIaKgKTp9t0UJkE5+PAVNhdQm4ZVq2E= -cloud.google.com/go/auth v0.9.5 h1:4CTn43Eynw40aFVr3GpPqsQponx2jv0BQpjvajsbbzw= -cloud.google.com/go/auth v0.9.5/go.mod h1:Xo0n7n66eHyOWWCnitop6870Ilwo3PiZyodVkkH1xWM= -cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY= -cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc= +cloud.google.com/go/auth v0.11.0 h1:Ic5SZz2lsvbYcWT5dfjNWgw6tTlGi2Wc8hyQSC9BstA= +cloud.google.com/go/auth v0.11.0/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= +cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU= +cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= cloud.google.com/go/automl v1.5.0/go.mod h1:34EjfoFGMZ5sgJ9EoLsRtdPSNZLcfflJR39VbVNS2M0= cloud.google.com/go/automl v1.6.0/go.mod h1:ugf8a6Fx+zP0D59WLhqgTDsQI9w07o64uf/Is3Nh5p8= cloud.google.com/go/automl v1.7.0/go.mod h1:RL9MYCCsJEOmt0Wf3z9uzG0a7adTT1fe+aObgSpkCt8= @@ -133,8 +133,8 @@ cloud.google.com/go/bigquery v1.47.0/go.mod h1:sA9XOgy0A8vQK9+MWhEQTY6Tix87M/Zur cloud.google.com/go/bigquery v1.48.0/go.mod h1:QAwSz+ipNgfL5jxiaK7weyOhzdoAy1zFm0Nf1fysJac= cloud.google.com/go/bigquery v1.49.0/go.mod h1:Sv8hMmTFFYBlt/ftw2uN6dFdQPzBlREY9yBh7Oy7/4Q= cloud.google.com/go/bigquery v1.50.0/go.mod h1:YrleYEh2pSEbgTBZYMJ5SuSr0ML3ypjRB1zgf7pvQLU= -cloud.google.com/go/bigquery v1.63.0 h1:yQFuJXdDukmBkiUUpjX0i1CtHLFU62HqPs/VDvSzaZo= -cloud.google.com/go/bigquery v1.63.0/go.mod 
h1:TQto6OR4kw27bqjNTGkVk1Vo5PJlTgxvDJn6YEIZL/E= +cloud.google.com/go/bigquery v1.64.0 h1:vSSZisNyhr2ioJE1OuYBQrnrpB7pIhRQm4jfjc7E/js= +cloud.google.com/go/bigquery v1.64.0/go.mod h1:gy8Ooz6HF7QmA+TRtX8tZmXBKH5mCFBwUApGAb3zI7Y= cloud.google.com/go/bigtable v1.33.0 h1:2BDaWLRAwXO14DJL/u8crbV2oUbMZkIa2eGq8Yao1bk= cloud.google.com/go/bigtable v1.33.0/go.mod h1:HtpnH4g25VT1pejHRtInlFPnN5sjTxbQlsYBjh9t5l0= cloud.google.com/go/billing v1.4.0/go.mod h1:g9IdKBEFlItS8bTtlrZdVLWSSdSyFUZKXNS02zKMOZY= @@ -210,8 +210,8 @@ cloud.google.com/go/datacatalog v1.8.0/go.mod h1:KYuoVOv9BM8EYz/4eMFxrr4DUKhGIOX cloud.google.com/go/datacatalog v1.8.1/go.mod h1:RJ58z4rMp3gvETA465Vg+ag8BGgBdnRPEMMSTr5Uv+M= cloud.google.com/go/datacatalog v1.12.0/go.mod h1:CWae8rFkfp6LzLumKOnmVh4+Zle4A3NXLzVJ1d1mRm0= cloud.google.com/go/datacatalog v1.13.0/go.mod h1:E4Rj9a5ZtAxcQJlEBTLgMTphfP11/lNaAshpoBgemX8= -cloud.google.com/go/datacatalog v1.22.0 h1:7e5/0B2LYbNx0BcUJbiCT8K2wCtcB5993z/v1JeLIdc= -cloud.google.com/go/datacatalog v1.22.0/go.mod h1:4Wff6GphTY6guF5WphrD76jOdfBiflDiRGFAxq7t//I= +cloud.google.com/go/datacatalog v1.23.0 h1:9F2zIbWNNmtrSkPIyGRQNsIugG5VgVVFip6+tXSdWLg= +cloud.google.com/go/datacatalog v1.23.0/go.mod h1:9Wamq8TDfL2680Sav7q3zEhBJSPBrDxJU8WtPJ25dBM= cloud.google.com/go/dataflow v0.6.0/go.mod h1:9QwV89cGoxjjSR9/r7eFDqqjtvbKxAK2BaYU6PVk9UM= cloud.google.com/go/dataflow v0.7.0/go.mod h1:PX526vb4ijFMesO1o202EaUmouZKBpjHsTlCtB4parQ= cloud.google.com/go/dataflow v0.8.0/go.mod h1:Rcf5YgTKPtQyYz8bLYhFoIV/vP39eL7fWNcSOyFfLJE= @@ -240,8 +240,8 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7 cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/datastore v1.10.0/go.mod h1:PC5UzAmDEkAmkfaknstTYbNpgE49HAgW2J1gcgUfmdM= cloud.google.com/go/datastore v1.11.0/go.mod h1:TvGxBIHCS50u8jzG+AW/ppf87v1of8nwzFNgEZU1D3c= -cloud.google.com/go/datastore v1.19.0 h1:p5H3bUQltOa26GcMRAxPoNwoqGkq5v8ftx9/ZBB35MI= 
-cloud.google.com/go/datastore v1.19.0/go.mod h1:KGzkszuj87VT8tJe67GuB+qLolfsOt6bZq/KFuWaahc= +cloud.google.com/go/datastore v1.20.0 h1:NNpXoyEqIJmZFc0ACcwBEaXnmscUpcG4NkKnbCePmiM= +cloud.google.com/go/datastore v1.20.0/go.mod h1:uFo3e+aEpRfHgtp5pp0+6M0o147KoPaYNaPAKpfh8Ew= cloud.google.com/go/datastream v1.2.0/go.mod h1:i/uTP8/fZwgATHS/XFu0TcNUhuA0twZxxQ3EyCUQMwo= cloud.google.com/go/datastream v1.3.0/go.mod h1:cqlOX8xlyYF/uxhiKn6Hbv6WjwPPuI9W2M9SAXwaLLQ= cloud.google.com/go/datastream v1.4.0/go.mod h1:h9dpzScPhDTs5noEMQVWP8Wx8AFBRyS0s8KWPx/9r0g= @@ -327,8 +327,8 @@ cloud.google.com/go/iam v0.8.0/go.mod h1:lga0/y3iH6CX7sYqypWJ33hf7kkfXJag67naqGE cloud.google.com/go/iam v0.11.0/go.mod h1:9PiLDanza5D+oWFZiH1uG+RnRCfEGKoyl6yo4cgWZGY= cloud.google.com/go/iam v0.12.0/go.mod h1:knyHGviacl11zrtZUoDuYpDgLjvr28sLQaG0YB2GYAY= cloud.google.com/go/iam v0.13.0/go.mod h1:ljOg+rcNfzZ5d6f1nAUJ8ZIxOaZUVoS14bKCtaLZ/D0= -cloud.google.com/go/iam v1.2.1 h1:QFct02HRb7H12J/3utj0qf5tobFh9V4vR6h9eX5EBRU= -cloud.google.com/go/iam v1.2.1/go.mod h1:3VUIJDPpwT6p/amXRC5GY8fCCh70lxPygguVtI0Z4/g= +cloud.google.com/go/iam v1.2.2 h1:ozUSofHUGf/F4tCNy/mu9tHLTaxZFLOUiKzjcgWHGIA= +cloud.google.com/go/iam v1.2.2/go.mod h1:0Ys8ccaZHdI1dEUilwzqng/6ps2YB6vRsjIe00/+6JY= cloud.google.com/go/iap v1.4.0/go.mod h1:RGFwRJdihTINIe4wZ2iCP0zF/qu18ZwyKxrhMhygBEc= cloud.google.com/go/iap v1.5.0/go.mod h1:UH/CGgKd4KyohZL5Pt0jSKE4m3FR51qg6FKQ/z/Ix9A= cloud.google.com/go/iap v1.6.0/go.mod h1:NSuvI9C/j7UdjGjIde7t7HBz+QTwBcapPE07+sSRcLk= @@ -348,8 +348,8 @@ cloud.google.com/go/kms v1.8.0/go.mod h1:4xFEhYFqvW+4VMELtZyxomGSYtSQKzM178ylFW4 cloud.google.com/go/kms v1.9.0/go.mod h1:qb1tPTgfF9RQP8e1wq4cLFErVuTJv7UsSC915J8dh3w= cloud.google.com/go/kms v1.10.0/go.mod h1:ng3KTUtQQU9bPX3+QGLsflZIHlkbn8amFAMY63m8d24= cloud.google.com/go/kms v1.10.1/go.mod h1:rIWk/TryCkR59GMC3YtHtXeLzd634lBbKenvyySAyYI= -cloud.google.com/go/kms v1.19.0 h1:x0OVJDl6UH1BSX4THKlMfdcFWoE4ruh90ZHuilZekrU= -cloud.google.com/go/kms v1.19.0/go.mod 
h1:e4imokuPJUc17Trz2s6lEXFDt8bgDmvpVynH39bdrHM= +cloud.google.com/go/kms v1.20.1 h1:og29Wv59uf2FVaZlesaiDAqHFzHaoUyHI3HYp9VUHVg= +cloud.google.com/go/kms v1.20.1/go.mod h1:LywpNiVCvzYNJWS9JUcGJSVTNSwPwi0vBAotzDqn2nc= cloud.google.com/go/language v1.4.0/go.mod h1:F9dRpNFQmJbkaop6g0JhSBXCNlO90e1KWx5iDdxbWic= cloud.google.com/go/language v1.6.0/go.mod h1:6dJ8t3B+lUYfStgls25GusK04NLh3eDLQnWM3mdEbhI= cloud.google.com/go/language v1.7.0/go.mod h1:DJ6dYN/W+SQOjF8e1hLQXMF21AkH2w9wiPzPCJa2MIE= @@ -360,11 +360,13 @@ cloud.google.com/go/lifesciences v0.6.0/go.mod h1:ddj6tSX/7BOnhxCSd3ZcETvtNr8NZ6 cloud.google.com/go/lifesciences v0.8.0/go.mod h1:lFxiEOMqII6XggGbOnKiyZ7IBwoIqA84ClvoezaA/bo= cloud.google.com/go/logging v1.6.1/go.mod h1:5ZO0mHHbvm8gEmeEUHrmDlTDSu5imF6MUP9OfilNXBw= cloud.google.com/go/logging v1.7.0/go.mod h1:3xjP2CjkM3ZkO73aj4ASA5wRPGGCRrPIAeNqVNkzY8M= +cloud.google.com/go/logging v1.12.0 h1:ex1igYcGFd4S/RZWOCU51StlIEuey5bjqwH9ZYjHibk= +cloud.google.com/go/logging v1.12.0/go.mod h1:wwYBt5HlYP1InnrtYI0wtwttpVU1rifnMT7RejksUAM= cloud.google.com/go/longrunning v0.1.1/go.mod h1:UUFxuDWkv22EuY93jjmDMFT5GPQKeFVJBIF6QlTqdsE= cloud.google.com/go/longrunning v0.3.0/go.mod h1:qth9Y41RRSUE69rDcOn6DdK3HfQfsUI0YSmW3iIlLJc= cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo= -cloud.google.com/go/longrunning v0.6.1 h1:lOLTFxYpr8hcRtcwWir5ITh1PAKUD/sG2lKrTSYjyMc= -cloud.google.com/go/longrunning v0.6.1/go.mod h1:nHISoOZpBcmlwbJmiVk5oDRz0qG/ZxPynEGs1iZ79s0= +cloud.google.com/go/longrunning v0.6.2 h1:xjDfh1pQcWPEvnfjZmwjKQEcHnpz6lHjfy7Fo0MK+hc= +cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI= cloud.google.com/go/managedidentities v1.3.0/go.mod h1:UzlW3cBOiPrzucO5qWkNkh0w33KFtBJU281hacNvsdE= cloud.google.com/go/managedidentities v1.4.0/go.mod h1:NWSBYbEMgqmbZsLIyKvxrYbtqOsxY1ZrGM+9RgDqInM= cloud.google.com/go/managedidentities v1.5.0/go.mod h1:+dWcZ0JlUmpuxpIDfyP5pP5y0bLdRwOS4Lp7gMni/LA= @@ 
-388,8 +390,8 @@ cloud.google.com/go/monitoring v1.7.0/go.mod h1:HpYse6kkGo//7p6sT0wsIC6IBDET0RhI cloud.google.com/go/monitoring v1.8.0/go.mod h1:E7PtoMJ1kQXWxPjB6mv2fhC5/15jInuulFdYYtlcvT4= cloud.google.com/go/monitoring v1.12.0/go.mod h1:yx8Jj2fZNEkL/GYZyTLS4ZtZEZN8WtDEiEqG4kLK50w= cloud.google.com/go/monitoring v1.13.0/go.mod h1:k2yMBAB1H9JT/QETjNkgdCGD9bPF712XiLTVr+cBrpw= -cloud.google.com/go/monitoring v1.21.1 h1:zWtbIoBMnU5LP9A/fz8LmWMGHpk4skdfeiaa66QdFGc= -cloud.google.com/go/monitoring v1.21.1/go.mod h1:Rj++LKrlht9uBi8+Eb530dIrzG/cU/lB8mt+lbeFK1c= +cloud.google.com/go/monitoring v1.21.2 h1:FChwVtClH19E7pJ+e0xUhJPGksctZNVOk2UhMmblmdU= +cloud.google.com/go/monitoring v1.21.2/go.mod h1:hS3pXvaG8KgWTSz+dAdyzPrGUYmi2Q+WFX8g2hqVEZU= cloud.google.com/go/networkconnectivity v1.4.0/go.mod h1:nOl7YL8odKyAOtzNX73/M5/mGZgqqMeryi6UPZTk/rA= cloud.google.com/go/networkconnectivity v1.5.0/go.mod h1:3GzqJx7uhtlM3kln0+x5wyFvuVH1pIBJjhCpjzSt75o= cloud.google.com/go/networkconnectivity v1.6.0/go.mod h1:OJOoEXW+0LAxHh89nXd64uGG+FbQoeH8DtxCHVOMlaM= @@ -449,8 +451,8 @@ cloud.google.com/go/pubsub v1.26.0/go.mod h1:QgBH3U/jdJy/ftjPhTkyXNj543Tin1pRYcd cloud.google.com/go/pubsub v1.27.1/go.mod h1:hQN39ymbV9geqBnfQq6Xf63yNhUAhv9CZhzp5O6qsW0= cloud.google.com/go/pubsub v1.28.0/go.mod h1:vuXFpwaVoIPQMGXqRyUQigu/AX1S3IWugR9xznmcXX8= cloud.google.com/go/pubsub v1.30.0/go.mod h1:qWi1OPS0B+b5L+Sg6Gmc9zD1Y+HaM0MdUr7LsupY1P4= -cloud.google.com/go/pubsub v1.43.0 h1:s3Qx+F96J7Kwey/uVHdK3QxFLIlOvvw4SfMYw2jFjb4= -cloud.google.com/go/pubsub v1.43.0/go.mod h1:LNLfqItblovg7mHWgU5g84Vhza4J8kTxx0YqIeTzcXY= +cloud.google.com/go/pubsub v1.45.3 h1:prYj8EEAAAwkp6WNoGTE4ahe0DgHoyJd5Pbop931zow= +cloud.google.com/go/pubsub v1.45.3/go.mod h1:cGyloK/hXC4at7smAtxFnXprKEFTqmMXNNd9w+bd94Q= cloud.google.com/go/pubsublite v1.5.0/go.mod h1:xapqNQ1CuLfGi23Yda/9l4bBCKz/wC3KIJ5gKcxveZg= cloud.google.com/go/pubsublite v1.6.0/go.mod h1:1eFCS0U11xlOuMFV/0iBqw3zP12kddMeCbj/F3FSj9k= cloud.google.com/go/pubsublite 
v1.7.0/go.mod h1:8hVMwRXfDfvGm3fahVbtDbiLePT3gpoiJYJY+vxWxVM= @@ -540,8 +542,8 @@ cloud.google.com/go/shell v1.6.0/go.mod h1:oHO8QACS90luWgxP3N9iZVuEiSF84zNyLytb+ cloud.google.com/go/spanner v1.41.0/go.mod h1:MLYDBJR/dY4Wt7ZaMIQ7rXOTLjYrmxLE/5ve9vFfWos= cloud.google.com/go/spanner v1.44.0/go.mod h1:G8XIgYdOK+Fbcpbs7p2fiprDw4CaZX63whnSMLVBxjk= cloud.google.com/go/spanner v1.45.0/go.mod h1:FIws5LowYz8YAE1J8fOS7DJup8ff7xJeetWEo5REA2M= -cloud.google.com/go/spanner v1.67.0 h1:h8xfobxh5lQu4qJVMPH+wSiyU+ZM6ZTxRNqGeu9iIVA= -cloud.google.com/go/spanner v1.67.0/go.mod h1:Um+TNmxfcCHqNCKid4rmAMvoe/Iu1vdz6UfxJ9GPxRQ= +cloud.google.com/go/spanner v1.73.0 h1:0bab8QDn6MNj9lNK6XyGAVFhMlhMU2waePPa6GZNoi8= +cloud.google.com/go/spanner v1.73.0/go.mod h1:mw98ua5ggQXVWwp83yjwggqEmW9t8rjs9Po1ohcUGW4= cloud.google.com/go/speech v1.6.0/go.mod h1:79tcr4FHCimOp56lwC01xnt/WPJZc4v3gzyT7FoBkCM= cloud.google.com/go/speech v1.7.0/go.mod h1:KptqL+BAQIhMsj1kOP2la5DSEEerPDuOP/2mmkhHhZQ= cloud.google.com/go/speech v1.8.0/go.mod h1:9bYIl1/tjsAnMgKGHKmBZzXKEkGgtU+MpdDPTE9f7y0= @@ -559,8 +561,8 @@ cloud.google.com/go/storage v1.23.0/go.mod h1:vOEEDNFnciUMhBeT6hsJIn3ieU5cFRmzeL cloud.google.com/go/storage v1.27.0/go.mod h1:x9DOL8TK/ygDUMieqwfhdpQryTeEkhGKMi80i/iqR2s= cloud.google.com/go/storage v1.28.1/go.mod h1:Qnisd4CqDdo6BGs2AD5LLnEsmSQ80wQ5ogcBBKhU86Y= cloud.google.com/go/storage v1.29.0/go.mod h1:4puEjyTKnku6gfKoTfNOU/W+a9JyuVNxjpS5GBrB8h4= -cloud.google.com/go/storage v1.43.0 h1:CcxnSohZwizt4LCzQHWvBf1/kvtHUn7gk9QERXPyXFs= -cloud.google.com/go/storage v1.43.0/go.mod h1:ajvxEa7WmZS1PxvKRq4bq0tFT3vMd502JwstCcYv0Q0= +cloud.google.com/go/storage v1.47.0 h1:ajqgt30fnOMmLfWfu1PWcb+V9Dxz6n+9WKjdNg5R4HM= +cloud.google.com/go/storage v1.47.0/go.mod h1:Ks0vP374w0PW6jOUameJbapbQKXqkjGd/OJRp2fb9IQ= cloud.google.com/go/storagetransfer v1.5.0/go.mod h1:dxNzUopWy7RQevYFHewchb29POFv3/AaBgnhqzqiK0w= cloud.google.com/go/storagetransfer v1.6.0/go.mod h1:y77xm4CQV/ZhFZH75PLEXY0ROiS7Gh6pSKrM8dJyg6I= 
cloud.google.com/go/storagetransfer v1.7.0/go.mod h1:8Giuj1QNb1kfLAiWM1bN6dHzfdlDAVC9rv9abHot2W4= @@ -580,6 +582,8 @@ cloud.google.com/go/trace v1.3.0/go.mod h1:FFUE83d9Ca57C+K8rDl/Ih8LwOzWIV1krKgxg cloud.google.com/go/trace v1.4.0/go.mod h1:UG0v8UBqzusp+z63o7FK74SdFE+AXpCLdFb1rshXG+Y= cloud.google.com/go/trace v1.8.0/go.mod h1:zH7vcsbAhklH8hWFig58HvxcxyQbaIqMarMg9hn5ECA= cloud.google.com/go/trace v1.9.0/go.mod h1:lOQqpE5IaWY0Ixg7/r2SjixMuc6lfTFeO4QGM4dQWOk= +cloud.google.com/go/trace v1.11.2 h1:4ZmaBdL8Ng/ajrgKqY5jfvzqMXbrDcBsUGXOT9aqTtI= +cloud.google.com/go/trace v1.11.2/go.mod h1:bn7OwXd4pd5rFuAnTrzBuoZ4ax2XQeG3qNgYmfCy0Io= cloud.google.com/go/translate v1.3.0/go.mod h1:gzMUwRjvOqj5i69y/LYLd8RrNQk+hOmIXTi9+nb3Djs= cloud.google.com/go/translate v1.4.0/go.mod h1:06Dn/ppvLD6WvA5Rhdp029IX2Mi3Mn7fpMRLPvXT5Wg= cloud.google.com/go/translate v1.5.0/go.mod h1:29YDSYveqqpA1CQFD7NQuP49xymq17RXNaUDdc0mNu0= @@ -649,6 +653,14 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.0 h1:oVLqHXhnYtUwM89y9T1fXGaK9wTkXHgNp8/ZNMQzUxE= github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.0/go.mod h1:dppbR7CwXD4pgtV9t3wD1812RaLDcBjtblcDF5f1vI0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.24.1 h1:pB2F2JKCj1Znmp2rwxxt1J0Fg0wezTMgWYk5Mpbi1kg= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.24.1/go.mod h1:itPGVDKf9cC/ov4MdvJ2QZ0khw4bfoo9jzwTJlaxy2k= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1 h1:UQ0AhxogsIRZDkElkblfnwjc3IaltCm2HUMvezQaL7s= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1/go.mod h1:jyqM3eLpJ3IbIFDTKVz2rF9T/xWGW0rIriGwnz8l9Tk= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.48.1 
h1:oTX4vsorBZo/Zdum6OKPA4o7544hm6smoRv1QjpTwGo= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.48.1/go.mod h1:0wEl7vrAD8mehJyohS9HZy+WyEOaQO2mJx86Cvh93kM= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 h1:8nn+rsCvTq9axyEh382S0PFLBeaFwNsT43IrPWzctRU= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1/go.mod h1:viRWSEhtMZqz1rhwmOVKkWl6SwmVowfL9O2YR5gI2PE= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= @@ -677,56 +689,56 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= -github.com/aws/aws-sdk-go-v2 v1.31.0 h1:3V05LbxTSItI5kUqNwhJrrrY1BAXxXt0sN0l72QmG5U= -github.com/aws/aws-sdk-go-v2 v1.31.0/go.mod h1:ztolYtaEUtdpf9Wftr31CJfLVjOnD/CVRkKOOYgF8hA= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5 h1:xDAuZTn4IMm8o1LnBZvmrL8JA1io4o3YWNXgohbf20g= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.5/go.mod h1:wYSv6iDS621sEFLfKvpPE2ugjTuGlAG7iROg0hLOkfc= +github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= +github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= github.com/aws/aws-sdk-go-v2/config 
v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= -github.com/aws/aws-sdk-go-v2/config v1.27.39 h1:FCylu78eTGzW1ynHcongXK9YHtoXD5AiiUqq3YfJYjU= -github.com/aws/aws-sdk-go-v2/config v1.27.39/go.mod h1:wczj2hbyskP4LjMKBEZwPRO1shXY+GsQleab+ZXT2ik= +github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= +github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko= github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc= -github.com/aws/aws-sdk-go-v2/credentials v1.17.37 h1:G2aOH01yW8X373JK419THj5QVqu9vKEwxSEsGxihoW0= -github.com/aws/aws-sdk-go-v2/credentials v1.17.37/go.mod h1:0ecCjlb7htYCptRD45lXJ6aJDQac6D2NlKGpZqyTG6A= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 h1:C/d03NAmh8C4BZXhuRNboF/DqhBkBCeDiJDcaqIT5pA= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14/go.mod h1:7I0Ju7p9mCIdlrfS+JCgqcYD0VXz/N4yozsox+0o078= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.25 h1:HkpHeZMM39sGtMHVYG1buAg93vhj5d7F81y6G0OAbGc= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.25/go.mod h1:j3Vz04ZjaWA6kygOsZRpmWe4CyGqfqq2u3unDTU0QGA= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.18 h1:kYQ3H1u0ANr9KEKlGs/jTLrBFPo8P8NaH/w7A01NeeM= -github.com/aws/aws-sdk-go-v2/internal/configsources 
v1.3.18/go.mod h1:r506HmK5JDUh9+Mw4CfGJGSSoqIiLCndAuqXuhbv67Y= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18 h1:Z7IdFUONvTcvS7YuhtVxN99v2cCoHRXOS4mTr0B/pUc= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.18/go.mod h1:DkKMmksZVVyat+Y+r1dEOgJEfUeA7UngIHWeKsi0yNc= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 h1:iLdpkYZ4cXIQMO7ud+cqMWR1xK5ESbt1rvN77tRi1BY= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43/go.mod h1:OgbsKPAswXDd5kxnR4vZov69p3oYjbvUyIRBAAV0y9o= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE= github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.18 h1:OWYvKL53l1rbsUmW7bQyJVsYU/Ii3bbAAQIIFNbM0Tk= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.18/go.mod h1:CUx0G1v3wG6l01tUB+j7Y8kclA8NSqK4ef0YG79a4cg= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.25 h1:r67ps7oHCYnflpgDy2LZU0MAQtQbYIOqNNnqGO6xQkE= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.25/go.mod h1:GrGY+Q4fIokYLtjCVB/aFfCVL6hhGUFl8inD18fDalE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5 h1:QFASJGfT8wMXtuP3D5CRmMjARHv9ZmzFUMJznHDOY3w= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.5/go.mod 
h1:QdZ3OmoIjSX+8D1OPAzPxDfjXASbBMDsz9qvtyIhtik= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.20 h1:rTWjG6AvWekO2B1LHeM3ktU7MqyX9rzWQ7hgzneZW7E= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.20/go.mod h1:RGW2DDpVc8hu6Y6yG8G5CHVmVOAn1oV8rNKOHRJyswg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.6 h1:HCpPsWqmYQieU7SS6E9HXfdAMSud0pteVXieJmcpIRI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.6/go.mod h1:ngUiVRCco++u+soRRVBIvBZxSMMvOVMXA4PJ36JLfSw= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20 h1:Xbwbmk44URTiHNx6PNo0ujDE6ERlsCKJD3u1zfnzAPg= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.20/go.mod h1:oAfOFzUB14ltPZj1rWwRc3d/6OgD76R8KlvU3EqM9Fg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.18 h1:eb+tFOIl9ZsUe2259/BKPeniKuz4/02zZFH/i4Nf8Rg= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.18/go.mod h1:GVCC2IJNJTmdlyEsSmofEy7EfJncP7DNnXDzRjJ5Keg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.6 h1:BbGDtTi0T1DYlmjBiCr/le3wzhA37O8QTC5/Ab8+EXk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.6/go.mod h1:hLMJt7Q8ePgViKupeymbqI0la+t9/iYFBjxQCFwuAwI= github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod 
h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI= -github.com/aws/aws-sdk-go-v2/service/s3 v1.63.3 h1:3zt8qqznMuAZWDTDpcwv9Xr11M/lVj2FsRR7oYBt0OA= -github.com/aws/aws-sdk-go-v2/service/s3 v1.63.3/go.mod h1:NLTqRLe3pUNu3nTEHI6XlHLKYmc8fbHUdMxAB6+s41Q= +github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0 h1:nyuzXooUNJexRT0Oy0UQY6AhOzxPxhtt4DcBIHyCnmw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.71.0/go.mod h1:sT/iQz8JK3u/5gZkT+Hmr7GzVZehUMkRZpOaAwYXeGY= github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM= -github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 h1:rs4JCczF805+FDv2tRhZ1NU0RB2H6ryAvsWPanAr72Y= -github.com/aws/aws-sdk-go-v2/service/sso v1.23.3/go.mod h1:XRlMvmad0ZNL+75C5FYdMvbbLkd6qiqz6foR1nA1PXY= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 h1:S7EPdMVZod8BGKQQPTBK+FcX9g7bKR7c4+HxWqHP7Vg= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDXIzAB9GAwVSzFzSy97uZ3IsHo4E= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY= github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg= -github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= -github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= 
-github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= -github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= @@ -757,8 +769,8 @@ github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20220314180256-7f1daf1720fc/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20240822171458-6449f94b4d59 h1:fLZ97KE86ELjEYJCEUVzmbhfzDxHHGwBrDVMd4XL6Bs= -github.com/cncf/xds/go v0.0.0-20240822171458-6449f94b4d59/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 h1:QVw89YDxXxEe+l8gU8ETbOasdwEV+avkR75ZzsVV9WI= +github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/colinmarc/hdfs/v2 v2.1.1/go.mod h1:M3x+k8UKKmxtFu++uAZ0OtDU8jR3jnaZIAc6yK4Ue0c= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= @@ -841,6 +853,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gogo/protobuf v1.3.2 
h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-cz/devslog v0.0.11 h1:v4Yb9o0ZpuZ/D8ZrtVw1f9q5XrjnkxwHF1XmWwO8IHg= +github.com/golang-cz/devslog v0.0.11/go.mod h1:bSe5bm0A7Nyfqtijf1OMNgVJHlWEuVSXnkuASiE1vV8= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= @@ -964,8 +978,8 @@ github.com/googleapis/gax-go/v2 v2.5.1/go.mod h1:h6B0KMMFNtI2ddbGJn3T3ZbwkeT6yqE github.com/googleapis/gax-go/v2 v2.6.0/go.mod h1:1mjbznJAPHFpesgE5ucqfYEscaz5kMdcIDwU/6+DDoY= github.com/googleapis/gax-go/v2 v2.7.0/go.mod h1:TEop28CZZQ2y+c0VxMUmu1lV+fQx57QpBWsYpwqHJx8= github.com/googleapis/gax-go/v2 v2.7.1/go.mod h1:4orTrqY6hXxxaUL4LHIPl6lGo8vAE38/qKbhSAKP6QI= -github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s= -github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= +github.com/googleapis/gax-go/v2 v2.14.0 h1:f+jMrjBPl+DL9nI4IQzLUxMq7XrAqFYB7hBPqMNIe8o= +github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk= github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= @@ -1015,8 +1029,8 @@ github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= 
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= @@ -1071,12 +1085,12 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/nats-io/jwt/v2 v2.5.8 h1:uvdSzwWiEGWGXf+0Q+70qv6AQdvcvxrv9hPM0RiPamE= github.com/nats-io/jwt/v2 v2.5.8/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= -github.com/nats-io/nats-server/v2 v2.10.18 h1:tRdZmBuWKVAFYtayqlBB2BuCHNGAQPvoQIXOKwU3WSM= -github.com/nats-io/nats-server/v2 v2.10.18/go.mod h1:97Qyg7YydD8blKlR8yBsUlPlWyZKjA7Bp5cl3MUE9K8= +github.com/nats-io/nats-server/v2 v2.10.23 h1:jvfb9cEi5h8UG6HkZgJGdn9f1UPaX3Dohk0PohEekJI= +github.com/nats-io/nats-server/v2 v2.10.23/go.mod h1:hMFnpDT2XUXsvHglABlFl/uroQCCOcW6X/0esW6GpBk= github.com/nats-io/nats.go v1.37.0 h1:07rauXbVnnJvv1gfIyghFEo6lUcYRY0WXc3x7x0vUxE= github.com/nats-io/nats.go v1.37.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= -github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= -github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= +github.com/nats-io/nkeys v0.4.8 h1:+wee30071y3vCZAYRsnrmIPaOe47A/SkK/UBDPdIV70= +github.com/nats-io/nkeys v0.4.8/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= 
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/ncw/swift v1.0.52/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= @@ -1115,8 +1129,9 @@ github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qq github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -1164,8 +1179,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/testcontainers/testcontainers-go v0.33.0 h1:zJS9PfXYT5O0ZFXM2xxXfk4J5UMw/kRiISng037Gxdw= github.com/testcontainers/testcontainers-go v0.33.0/go.mod h1:W80YpTa8D5C3Yy16icheD01UTDu+LmXIA2Keo+jWtT8= -github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g= -github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= +github.com/tetratelabs/wazero v1.8.1 h1:NrcgVbWfkWvVc4UtT4LRLDf91PsOzDzefMdwhLfA550= +github.com/tetratelabs/wazero v1.8.1/go.mod 
h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= @@ -1212,6 +1227,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/contrib/detectors/gcp v1.29.0 h1:TiaiXB4DpGD3sdzNlYQxruQngn5Apwzi1X0DRhuGvDQ= +go.opentelemetry.io/contrib/detectors/gcp v1.29.0/go.mod h1:GW2aWZNwR2ZxDLdv8OyC2G8zkRoQBuURgV7RPQgcPoU= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= @@ -1222,6 +1239,8 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.29.0 h1:WDdP9acbMYjbKIyJUhTvtzj601sVJOqgWdUxSdR/Ysc= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.29.0/go.mod h1:BLbf7zbNIONBLPwvFnwNHGj4zge8uTCM/UPIVW1Mq2I= go.opentelemetry.io/otel/metric v1.29.0 
h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= go.opentelemetry.io/otel/sdk v1.29.0 h1:vkqKjk7gwhS8VaWb0POZKmIEDimRCMsopNYnriHyryo= @@ -1247,8 +1266,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.30.0 h1:RwoQn3GkWiMkzlX562cLB7OxWvjH1L8xutO2WoJcRoY= +golang.org/x/crypto v0.30.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1369,8 +1388,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= +golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= golang.org/x/oauth2 
v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1400,8 +1419,8 @@ golang.org/x/oauth2 v0.4.0/go.mod h1:RznEsdpjGAINPTOF0UH/t+xJ75L18YO3Ho6Pyn+uRec golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= -golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= -golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1418,8 +1437,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod 
h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1507,8 +1526,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= @@ -1517,8 +1536,8 @@ golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= -golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text 
v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1535,16 +1554,16 @@ golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools 
v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1688,8 +1707,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.199.0 h1:aWUXClp+VFJmqE0JPvpZOK3LDQMyFKYIow4etYd9qxs= -google.golang.org/api v0.199.0/go.mod h1:ohG4qSztDJmZdjK/Ar6MhbAmb/Rpi4JHOqagsh90K28= +google.golang.org/api v0.210.0 h1:HMNffZ57OoZCRYSbdWVRoqOa8V8NIHLL0CzdBPLztWk= +google.golang.org/api v0.210.0/go.mod h1:B9XDZGnx2NtyjzVkOVTGrFSAVZgPcbedzKg/gTLwqBs= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1829,12 +1848,12 @@ google.golang.org/genproto v0.0.0-20230323212658-478b75c54725/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230330154414-c0448cd141ea/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 h1:BulPr26Jqjnd4eYDVe+YvyR7Yc2vJGkO5/0UxD0/jZU= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:hL97c3SYopEHblzpxRL4lSs523++l8DYxGM1FQiYmb4= -google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 h1:hjSy6tcFQZ171igDaN5QHOw2n6vx40juYbC/x67CEhc= 
-google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:qpvKtACPCQhAdu3PyQgV4l3LMXZEtft7y8QcarRsp9I= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 h1:ToEetK57OidYuqD4Q5w+vfEnPvPpuTwedCNVohYJfNk= +google.golang.org/genproto v0.0.0-20241118233622-e639e219e697/go.mod h1:JJrvXBWRZaFMxBufik1a4RpFw4HhgVtBBWQeQgUj2cc= +google.golang.org/genproto/googleapis/api v0.0.0-20241113202542-65e8d215514f h1:M65LEviCfuZTfrfzwwEoxVtgvfkFkBUbFnRbxCXuXhU= +google.golang.org/genproto/googleapis/api v0.0.0-20241113202542-65e8d215514f/go.mod h1:Yo94eF2nj7igQt+TiJ49KxjIH8ndLYPZMIRSiRcEbg0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697 h1:LWZqQOEjDyONlF1H6afSWpAL/znlREo2tHfLoe+8LMA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241118233622-e639e219e697/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1876,9 +1895,11 @@ google.golang.org/grpc v1.52.3/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5v google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw= google.golang.org/grpc v1.54.0/go.mod h1:PUSEXI6iWghWaB6lXM4knEgpJNu2qUcKfDtNci3EC2g= google.golang.org/grpc v1.56.3/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= -google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw= -google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= +google.golang.org/grpc v1.67.1 
h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= +google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a h1:UIpYSuWdWHSzjwcAFRLjKcPXFZVVLXGEM23W+NWqipw= +google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a/go.mod h1:9i1T9n4ZinTUZGgzENMi8MDDgbGC5mqTS75JAv6xN3A= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1896,8 +1917,8 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= +google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/sdks/go/BUILD.md b/sdks/go/BUILD.md index 9834c8ddee89..e2606f3597ee 100644 --- 
a/sdks/go/BUILD.md +++ b/sdks/go/BUILD.md @@ -62,9 +62,11 @@ To develop the SDK, it should be sufficient to clone the repository, make changes and execute tests from within the module directory (`/sdks/...`). Go users can just `go get` the code directly. For example: -``` + +```bash go get github.com/apache/beam/sdks/v2/go/pkg/beam ``` + Developers must invoke Go for cross-compilation manually, if desired. If you make changes to .proto files, you will need to rebuild the generated code. @@ -72,10 +74,13 @@ Consult `pkg/beam/model/PROTOBUF.md`. If you make changes to .tmpl files, then add the specialize tool to your path. You can install specialize using: -``` + +```bash go get github.com/apache/beam/sdks/v2/go/cmd/specialize ``` + Add it to your path: -``` + +```bash export PATH=$PATH:$GOROOT/bin:$GOPATH/bin ``` diff --git a/sdks/go/README.md b/sdks/go/README.md index 7734d58d9eb9..bcfba2742590 100644 --- a/sdks/go/README.md +++ b/sdks/go/README.md @@ -33,7 +33,7 @@ The examples are normal Go programs and are most easily run directly. They are parameterized by Go flags. For example, to run wordcount on the Go direct runner do: -``` +```bash $ pwd [...]/sdks/go $ go run examples/wordcount/wordcount.go --output=/tmp/result.txt @@ -70,7 +70,7 @@ Edges: 1: Impulse [] -> [Out: []uint8 -> {1: []uint8/GW/bytes}] The debugging output is currently quite verbose and likely to change. 
The output is a local file in this case: -``` +```bash $ head /tmp/result.txt while: 2 darkling: 1 @@ -86,13 +86,13 @@ purse: 6 To run wordcount on dataflow runner do: -``` -$ go run wordcount.go --runner=dataflow --project= --region= --staging_location=/staging --worker_harness_container_image= --output=/output +```bash +$ go run wordcount.go --runner=dataflow --project= --region= --staging_location=/staging --worker_harness_container_image= --output=/output ``` The output is a GCS file in this case: -``` +```bash $ gsutil cat /output* | head Blanket: 1 blot: 1 @@ -106,7 +106,6 @@ sport: 3 Crown'd: 1 ``` - See [BUILD.md](./BUILD.md) for how to build Go code in general. See [container documentation](https://beam.apache.org/documentation/runtime/environments/#building-container-images) for how to build and push the Go SDK harness container image. @@ -117,9 +116,10 @@ Please use the [`sdk-go`](https://github.com/apache/beam/issues?q=is%3Aopen+is%3 ## Contributing to the Go SDK ### New to developing Go? -https://tour.golang.org : The Go Tour gives you the basics of the language, interactively no installation required. -https://github.com/campoy/go-tooling-workshop is a great start on learning good (optional) development tools for Go. + : The Go Tour gives you the basics of the language, interactively no installation required. + + is a great start on learning good (optional) development tools for Go. ### Developing Go Beam SDK on Github @@ -130,11 +130,10 @@ Executing all unit tests for the SDK is possible from the `\sdks\go` To test your change as Jenkins would execute it from a PR, from the beam root directory, run: - * `./gradlew :sdks:go:goTest` executes the unit tests. - * `./gradlew :sdks:go:test:prismValidatesRunner` validates the SDK against the Go Prism runner as a stand alone binary, with containers. - * `./gradlew :sdks:go:test:ulrValidatesRunner` validates the SDK against the Portable Python runner. 
- * `./gradlew :sdks:go:test:flinkValidatesRunner` validates the SDK against the Flink runner. - -Follow the [contribution guide](https://beam.apache.org/contribute/contribution-guide/#code) to create branches, and submit pull requests as normal. +* `./gradlew :sdks:go:goTest` executes the unit tests. +* `./gradlew :sdks:go:test:prismValidatesRunner` validates the SDK against the Go Prism runner as a stand alone binary, with containers. +* `./gradlew :sdks:go:test:ulrValidatesRunner` validates the SDK against the Portable Python runner. +* `./gradlew :sdks:go:test:flinkValidatesRunner` validates the SDK against the Flink runner. +Follow the [contribution guide](https://beam.apache.org/contribute/contribution-guide/#code) to create branches, and submit pull requests as normal. diff --git a/sdks/go/cmd/prism/prism.go b/sdks/go/cmd/prism/prism.go index 39c19df00dc3..5e3f42a9e5a5 100644 --- a/sdks/go/cmd/prism/prism.go +++ b/sdks/go/cmd/prism/prism.go @@ -22,9 +22,14 @@ import ( "flag" "fmt" "log" + "log/slog" + "os" + "strings" + "time" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism" + "github.com/golang-cz/devslog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -37,10 +42,59 @@ var ( idleShutdownTimeout = flag.Duration("idle_shutdown_timeout", -1, "duration that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Defaults to never shutting down.") ) +// Logging flags +var ( + logKind = flag.String("log_kind", "dev", + "Determines the format of prism's logging to std err: valid values are `dev', 'json', or 'text'. Default is `dev`.") + logLevelFlag = flag.String("log_level", "info", + "Sets the minimum log level of Prism. Valid options are 'debug', 'info','warn', and 'error'. Default is 'info'. 
Debug adds prism source lines.") +) + +var logLevel = new(slog.LevelVar) + func main() { flag.Parse() ctx, cancel := context.WithCancelCause(context.Background()) + var logHandler slog.Handler + loggerOutput := os.Stderr + handlerOpts := &slog.HandlerOptions{ + Level: logLevel, + } + switch strings.ToLower(*logLevelFlag) { + case "debug": + logLevel.Set(slog.LevelDebug) + handlerOpts.AddSource = true + case "info": + logLevel.Set(slog.LevelInfo) + case "warn": + logLevel.Set(slog.LevelWarn) + case "error": + logLevel.Set(slog.LevelError) + default: + log.Fatalf("Invalid value for log_level: %v, must be 'debug', 'info', 'warn', or 'error'", *logKind) + } + switch strings.ToLower(*logKind) { + case "dev": + logHandler = + devslog.NewHandler(loggerOutput, &devslog.Options{ + TimeFormat: "[" + time.RFC3339Nano + "]", + StringerFormatter: true, + HandlerOptions: handlerOpts, + StringIndentation: false, + NewLineAfterLog: true, + MaxErrorStackTrace: 3, + }) + case "json": + logHandler = slog.NewJSONHandler(loggerOutput, handlerOpts) + case "text": + logHandler = slog.NewTextHandler(loggerOutput, handlerOpts) + default: + log.Fatalf("Invalid value for log_kind: %v, must be 'dev', 'json', or 'text'", *logKind) + } + + slog.SetDefault(slog.New(logHandler)) + cli, err := makeJobClient(ctx, prism.Options{ Port: *jobPort, diff --git a/sdks/go/container/tools/buffered_logging.go b/sdks/go/container/tools/buffered_logging.go index 445d19fabfdc..a7b84e56af3a 100644 --- a/sdks/go/container/tools/buffered_logging.go +++ b/sdks/go/container/tools/buffered_logging.go @@ -18,13 +18,15 @@ package tools import ( "context" "log" - "math" "os" "strings" "time" ) -const initialLogSize int = 255 +const ( + initialLogSize int = 255 + defaultFlushInterval time.Duration = 15 * time.Second +) // BufferedLogger is a wrapper around the FnAPI logging client meant to be used // in place of stdout and stderr in bootloader subprocesses. 
Not intended for @@ -41,7 +43,7 @@ type BufferedLogger struct { // NewBufferedLogger returns a new BufferedLogger type by reference. func NewBufferedLogger(logger *Logger) *BufferedLogger { - return &BufferedLogger{logger: logger, lastFlush: time.Now(), flushInterval: time.Duration(math.MaxInt64), periodicFlushContext: context.Background(), now: time.Now} + return &BufferedLogger{logger: logger, lastFlush: time.Now(), flushInterval: defaultFlushInterval, periodicFlushContext: context.Background(), now: time.Now} } // NewBufferedLoggerWithFlushInterval returns a new BufferedLogger type by reference. This type will diff --git a/sdks/go/examples/stringsplit/stringsplit.go b/sdks/go/examples/stringsplit/stringsplit.go index 266cdd99fb37..76140075b625 100644 --- a/sdks/go/examples/stringsplit/stringsplit.go +++ b/sdks/go/examples/stringsplit/stringsplit.go @@ -21,7 +21,7 @@ // 1. From a command line, navigate to the top-level beam/ directory and run // the Flink job server: // -// ./gradlew :runners:flink:1.18:job-server:runShadow -Djob-host=localhost -Dflink-master=local +// ./gradlew :runners:flink:1.19:job-server:runShadow -Djob-host=localhost -Dflink-master=local // // 2. The job server is ready to receive jobs once it outputs a log like the // following: `JobService started on localhost:8099`. Take note of the endpoint diff --git a/sdks/go/examples/wasm/README.md b/sdks/go/examples/wasm/README.md index 84d30a3c6a63..103bef88642b 100644 --- a/sdks/go/examples/wasm/README.md +++ b/sdks/go/examples/wasm/README.md @@ -68,7 +68,7 @@ cd $BEAM_HOME Expected output should include the following, from which you acquire the latest flink runner version. ```shell -'flink_versions: 1.15,1.16,1.17,1.18' +'flink_versions: 1.17,1.18,1.19' ``` #### 2. Set to the latest flink runner version i.e. 
1.16 diff --git a/sdks/go/pkg/beam/core/core.go b/sdks/go/pkg/beam/core/core.go index e1b660e99ac6..1b478f483077 100644 --- a/sdks/go/pkg/beam/core/core.go +++ b/sdks/go/pkg/beam/core/core.go @@ -27,7 +27,7 @@ const ( // SdkName is the human readable name of the SDK for UserAgents. SdkName = "Apache Beam SDK for Go" // SdkVersion is the current version of the SDK. - SdkVersion = "2.61.0.dev" + SdkVersion = "2.62.0.dev" // DefaultDockerImage represents the associated image for this release. DefaultDockerImage = "apache/beam_go_sdk:" + SdkVersion diff --git a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go index c71ead208364..06bb727178fc 100644 --- a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go +++ b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go @@ -19,12 +19,12 @@ import ( "bytes" "fmt" "log" + "log/slog" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" - "golang.org/x/exp/slog" ) // FromMonitoringInfos extracts metrics from monitored states and diff --git a/sdks/go/pkg/beam/runners/prism/internal/coders.go b/sdks/go/pkg/beam/runners/prism/internal/coders.go index eb8abe16ecf8..ffea90e79065 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/coders.go +++ b/sdks/go/pkg/beam/runners/prism/internal/coders.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -28,7 +29,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index eaaf7f831712..7b8689f95112 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -18,12 +18,12 @@ package engine import ( "bytes" "fmt" + "log/slog" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" - "golang.org/x/exp/slog" ) // StateData is a "union" between Bag state and MultiMap state to increase common code. diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index f7229853e4d3..1739efdb742a 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "io" + "log/slog" "sort" "strings" "sync" @@ -36,7 +37,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" ) type element struct { @@ -1159,7 +1159,7 @@ func (ss *stageState) AddPending(newPending []element) int { } ss.pendingByKeys[string(e.keyBytes)] = dnt } - dnt.elements.Push(e) + heap.Push(&dnt.elements, e) if e.IsTimer() { if lastSet, ok := dnt.timers[timerKey{family: e.family, tag: e.tag, window: e.window}]; ok { @@ -1576,6 +1576,8 @@ func (ss *stageState) updateWatermarks(em *ElementManager) set[string] { // They'll never be read in again. for _, wins := range ss.sideInputs { for win := range wins { + // TODO(#https://github.com/apache/beam/issues/31438): + // Adjust with AllowedLateness // Clear out anything we've already used. 
if win.MaxTimestamp() < newOut { delete(wins, win) @@ -1584,7 +1586,8 @@ func (ss *stageState) updateWatermarks(em *ElementManager) set[string] { } for _, wins := range ss.state { for win := range wins { - // Clear out anything we've already used. + // TODO(#https://github.com/apache/beam/issues/31438): + // Adjust with AllowedLateness if win.MaxTimestamp() < newOut { delete(wins, win) } @@ -1607,7 +1610,7 @@ func (ss *stageState) bundleReady(em *ElementManager, emNow mtime.Time) (mtime.T inputW := ss.input _, upstreamW := ss.UpstreamWatermark() if inputW == upstreamW { - slog.Debug("bundleReady: insufficient upstream watermark", + slog.Debug("bundleReady: unchanged upstream watermark", slog.String("stage", ss.ID), slog.Group("watermark", slog.Any("upstream", upstreamW), diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go index add7f769a702..2f960a04f0cb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/environments.go +++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "os" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" @@ -27,7 +28,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/proto" @@ -42,7 +42,7 @@ import ( // TODO move environment handling to the worker package. func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) error { - logger := slog.With(slog.String("envID", wk.Env)) + logger := j.Logger.With(slog.String("envID", wk.Env)) // TODO fix broken abstraction. // We're starting a worker pool here, because that's the loopback environment. 
// It's sort of a mess, largely because of loopback, which has @@ -56,7 +56,7 @@ func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *wor } go func() { externalEnvironment(ctx, ep, wk) - slog.Debug("environment stopped", slog.String("job", j.String())) + logger.Debug("environment stopped", slog.String("job", j.String())) }() return nil case urns.EnvDocker: @@ -129,6 +129,8 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock credEnv := fmt.Sprintf("%v=%v", gcloudCredsEnv, dockerGcloudCredsFile) envs = append(envs, credEnv) } + } else { + logger.Debug("local GCP credentials environment variable not found") } if _, _, err := cli.ImageInspectWithRaw(ctx, dp.GetContainerImage()); err != nil { // We don't have a local image, so we should pull it. @@ -140,6 +142,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock logger.Warn("unable to pull image and it's not local", "error", err) } } + logger.Debug("creating container", "envs", envs, "mounts", mounts) ccr, err := cli.ContainerCreate(ctx, &container.Config{ Image: dp.GetContainerImage(), @@ -169,17 +172,32 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock return fmt.Errorf("unable to start container image %v with docker for env %v, err: %w", dp.GetContainerImage(), wk.Env, err) } + logger.Debug("container started") + // Start goroutine to wait on container state. go func() { defer cli.Close() defer wk.Stop() + defer func() { + logger.Debug("container stopped") + }() - statusCh, errCh := cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning) + bgctx := context.Background() + statusCh, errCh := cli.ContainerWait(bgctx, containerID, container.WaitConditionNotRunning) select { case <-ctx.Done(): - // Can't use command context, since it's already canceled here. 
- err := cli.ContainerKill(context.Background(), containerID, "") + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { + logger.Error("error fetching container logs error on context cancellation", "error", err) + } + if rc != nil { + defer rc.Close() + var buf bytes.Buffer + stdcopy.StdCopy(&buf, &buf, rc) + logger.Info("container being killed", slog.Any("cause", context.Cause(ctx)), slog.Any("containerLog", buf)) + } + // Can't use command context, since it's already canceled here. + if err := cli.ContainerKill(bgctx, containerID, ""); err != nil { logger.Error("docker container kill error", "error", err) } case err := <-errCh: @@ -189,7 +207,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock case resp := <-statusCh: logger.Info("docker container has self terminated", "status_code", resp.StatusCode) - rc, err := cli.ContainerLogs(ctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { logger.Error("docker container logs error", "error", err) } diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index d7605f34f5f2..614edee47721 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "log/slog" "sort" "sync/atomic" "time" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) @@ -311,7 +311,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j 
*jobservic return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err) } stages[stage.ID] = stage - slog.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) + j.Logger.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) outputs := maps.Keys(stage.OutputsToCoders) sort.Strings(outputs) em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs) @@ -322,9 +322,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers) } default: - err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) - slog.Error("Execute", err) - return err + return fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) } } @@ -344,11 +342,13 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic for { select { case <-ctx.Done(): - return context.Cause(ctx) + err := context.Cause(ctx) + j.Logger.Debug("context canceled", slog.Any("cause", err)) + return err case rb, ok := <-bundles: if !ok { err := eg.Wait() - slog.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err)) + j.Logger.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err), slog.Any("topo", topo)) return err } eg.Go(func() error { diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go index 8590fd0d4ced..be9d39ad02b7 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "io" + "log/slog" "reflect" "sort" @@ -31,7 +32,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" 
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go index 99b786d45980..e42e3e7ca666 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go @@ -20,9 +20,9 @@ import ( "context" "fmt" "io" + "log/slog" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) @@ -77,7 +77,7 @@ func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingSer case *jobpb.ArtifactResponseWrapper_ResolveArtifactResponse: err := fmt.Errorf("unexpected ResolveArtifactResponse to GetArtifact: %v", in.GetResponse()) - slog.Error("GetArtifact failure", err) + slog.Error("GetArtifact failure", slog.Any("error", err)) return err } } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index 1407feafe325..deef259a99d1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -27,6 +27,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "sort" "strings" "sync" @@ -37,7 +38,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/types/known/structpb" ) @@ -88,6 +88,8 @@ type Job struct { // Context used to terminate this job. RootCtx context.Context CancelFn context.CancelCauseFunc + // Logger for this job. 
+ Logger *slog.Logger metrics metricsStore } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index b957b99ca63d..a2840760bf7a 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -19,6 +19,7 @@ import ( "context" "errors" "fmt" + "log/slog" "sync" "sync/atomic" @@ -27,7 +28,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -92,6 +92,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * cancelFn(err) terminalOnceWrap() }, + Logger: s.logger, // TODO substitute with a configured logger. artifactEndpoint: s.Endpoint(), } // Stop the idle timer when a new job appears. 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go index 03d5b0a98369..bbbdfd1eba4f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "hash/maphash" + "log/slog" "math" "sort" "sync" @@ -28,7 +29,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/constraints" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -589,7 +589,7 @@ func (m *metricsStore) AddShortIDs(resp *fnpb.MonitoringInfosMetadataResponse) { urn := mi.GetUrn() ops, ok := mUrn2Ops[urn] if !ok { - slog.Debug("unknown metrics urn", slog.String("urn", urn)) + slog.Debug("unknown metrics urn", slog.String("shortID", short), slog.String("urn", urn), slog.String("type", mi.Type)) continue } key := ops.keyFn(urn, mi.GetLabels()) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go index 320159f54c06..bdfe2aff2dd4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -18,6 +18,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "math" "net" "os" @@ -27,7 +28,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/grpc" ) @@ -53,6 +53,7 @@ type Server struct { terminatedJobCount uint32 // Use with atomics. idleTimeout time.Duration cancelFn context.CancelCauseFunc + logger *slog.Logger // execute defines how a job is executed. 
execute func(*Job) @@ -71,8 +72,9 @@ func NewServer(port int, execute func(*Job)) *Server { lis: lis, jobs: make(map[string]*Job), execute: execute, + logger: slog.Default(), // TODO substitute with a configured logger. } - slog.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) + s.logger.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) opts := []grpc.ServerOption{ grpc.MaxRecvMsgSize(math.MaxInt32), } diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index 7de32f85b7ee..dceaa9ab8fcb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -17,6 +17,7 @@ package internal import ( "fmt" + "log/slog" "sort" "strings" @@ -26,7 +27,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go index 1be3d3e70841..650932f525c8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go @@ -18,6 +18,7 @@ package internal_test import ( "context" "fmt" + "log/slog" "net" "net/http" "net/rpc" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/stats" - "golang.org/x/exp/slog" ) // separate_test.go retains structures and tests to ensure the runner can @@ -286,7 +286,7 @@ func (ws *Watchers) Check(args *Args, unblocked *bool) error { w.mu.Lock() *unblocked = w.sentinelCount >= w.sentinelCap w.mu.Unlock() - slog.Debug("sentinel target for watcher%d is %d/%d. 
unblocked=%v", args.WatcherID, w.sentinelCount, w.sentinelCap, *unblocked) + slog.Debug("sentinel watcher status", slog.Int("watcher", args.WatcherID), slog.Int("sentinelCount", w.sentinelCount), slog.Int("sentinelCap", w.sentinelCap), slog.Bool("unblocked", *unblocked)) return nil } @@ -360,7 +360,7 @@ func (fn *sepHarnessBase) setup() error { sepClientOnce.Do(func() { client, err := rpc.DialHTTP("tcp", fn.LocalService) if err != nil { - slog.Error("failed to dial sentinels server", err, slog.String("endpoint", fn.LocalService)) + slog.Error("failed to dial sentinels server", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(fmt.Sprintf("dialing sentinels server %v: %v", fn.LocalService, err)) } sepClient = client @@ -385,7 +385,7 @@ func (fn *sepHarnessBase) setup() error { var unblock bool err := sepClient.Call("Watchers.Check", &Args{WatcherID: id}, &unblock) if err != nil { - slog.Error("Watchers.Check: sentinels server error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Check: sentinels server error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic("sentinel server error") } if unblock { @@ -406,7 +406,7 @@ func (fn *sepHarnessBase) block() { var ignored bool err := sepClient.Call("Watchers.Block", &Args{WatcherID: fn.WatcherID}, &ignored) if err != nil { - slog.Error("Watchers.Block error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Block error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(err) } c := sepWaitMap[fn.WatcherID] @@ -423,7 +423,7 @@ func (fn *sepHarnessBase) delay() bool { var delay bool err := sepClient.Call("Watchers.Delay", &Args{WatcherID: fn.WatcherID}, &delay) if err != nil { - slog.Error("Watchers.Delay error", err) + slog.Error("Watchers.Delay error", slog.Any("error", err)) panic(err) } return delay diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go 
index f33754b2ca0a..9f00c22789b6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "runtime/debug" "sync/atomic" "time" @@ -33,7 +34,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) @@ -361,7 +361,7 @@ func portFor(wInCid string, wk *worker.W) []byte { } sourcePortBytes, err := proto.Marshal(sourcePort) if err != nil { - slog.Error("bad port", err, slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) + slog.Error("bad port", slog.Any("error", err), slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) } return sourcePortBytes } diff --git a/sdks/go/pkg/beam/runners/prism/internal/web/web.go b/sdks/go/pkg/beam/runners/prism/internal/web/web.go index 9fabe22cee3a..b14778e4462c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/web/web.go +++ b/sdks/go/pkg/beam/runners/prism/internal/web/web.go @@ -26,6 +26,7 @@ import ( "fmt" "html/template" "io" + "log/slog" "net/http" "sort" "strings" @@ -40,7 +41,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 3ccafdb81e9a..83ad1bda9841 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -19,12 +19,12 @@ import ( "bytes" "context" "fmt" + "log/slog" "sync/atomic" 
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" - "golang.org/x/exp/slog" ) // SideInputKey is for data lookups for a given bundle. @@ -126,7 +126,7 @@ func (b *B) ProcessOn(ctx context.Context, wk *W) <-chan struct{} { slog.Debug("processing", "bundle", b, "worker", wk) // Tell the SDK to start processing the bundle. - wk.InstReqs <- &fnpb.InstructionRequest{ + req := &fnpb.InstructionRequest{ InstructionId: b.InstID, Request: &fnpb.InstructionRequest_ProcessBundle{ ProcessBundle: &fnpb.ProcessBundleRequest{ @@ -134,6 +134,18 @@ func (b *B) ProcessOn(ctx context.Context, wk *W) <-chan struct{} { }, }, } + select { + case <-wk.StoppedChan: + // The worker was stopped before req was sent. + // Quit to avoid sending on a closed channel. + outCap := b.OutputCount + len(b.HasTimers) + for i := 0; i < outCap; i++ { + b.DataOrTimerDone() + } + return b.DataWait + case wk.InstReqs <- req: + // desired outcome + } // TODO: make batching decisions on the maxium to send per elements block, to reduce processing time overhead. 
for _, block := range b.Input { @@ -163,10 +175,13 @@ func (b *B) ProcessOn(ctx context.Context, wk *W) <-chan struct{} { } select { - case wk.DataReqs <- elms: + case <-wk.StoppedChan: + b.DataOrTimerDone() + return b.DataWait case <-ctx.Done(): b.DataOrTimerDone() return b.DataWait + case wk.DataReqs <- elms: } } @@ -181,6 +196,12 @@ func (b *B) ProcessOn(ctx context.Context, wk *W) <-chan struct{} { }) } select { + case <-wk.StoppedChan: + b.DataOrTimerDone() + return b.DataWait + case <-ctx.Done(): + b.DataOrTimerDone() + return b.DataWait case wk.DataReqs <- &fnpb.Elements{ Timers: timers, Data: []*fnpb.Elements_Data{ @@ -191,9 +212,6 @@ func (b *B) ProcessOn(ctx context.Context, wk *W) <-chan struct{} { }, }, }: - case <-ctx.Done(): - b.DataOrTimerDone() - return b.DataWait } return b.DataWait diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index f9ec03793488..c2c988aa097f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -22,10 +22,9 @@ import ( "context" "fmt" "io" + "log/slog" "math" "net" - "strconv" - "strings" "sync" "sync/atomic" "time" @@ -39,7 +38,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -70,6 +68,7 @@ type W struct { // These are the ID sources inst uint64 connected, stopped atomic.Bool + StoppedChan chan struct{} // Channel to Broadcast stopped state. 
InstReqs chan *fnpb.InstructionRequest DataReqs chan *fnpb.Elements @@ -98,8 +97,9 @@ func New(id, env string) *W { lis: lis, server: grpc.NewServer(opts...), - InstReqs: make(chan *fnpb.InstructionRequest, 10), - DataReqs: make(chan *fnpb.Elements, 10), + InstReqs: make(chan *fnpb.InstructionRequest, 10), + DataReqs: make(chan *fnpb.Elements, 10), + StoppedChan: make(chan struct{}), activeInstructions: make(map[string]controlResponder), Descriptors: make(map[string]*fnpb.ProcessBundleDescriptor), @@ -134,12 +134,26 @@ func (wk *W) LogValue() slog.Value { ) } +// shutdown safely closes channels, and can be called in the event of SDK crashes. +// +// Splitting this logic from the GRPC server Stop is necessary, since a worker +// crash would be handled in a streaming RPC context, which will block GRPC +// stop calls. +func (wk *W) shutdown() { + // If this is the first call to "stop" this worker, also close the channels. + if wk.stopped.CompareAndSwap(false, true) { + slog.Debug("shutdown", "worker", wk, "firstTime", true) + close(wk.StoppedChan) + close(wk.InstReqs) + close(wk.DataReqs) + } else { + slog.Debug("shutdown", "worker", wk, "firstTime", false) + } +} + // Stop the GRPC server. func (wk *W) Stop() { - slog.Debug("stopping", "worker", wk) - wk.stopped.Store(true) - close(wk.InstReqs) - close(wk.DataReqs) + wk.shutdown() // Give the SDK side 5 seconds to gracefully stop, before // hard stopping all RPCs. @@ -203,30 +217,46 @@ func (wk *W) Logging(stream fnpb.BeamFnLogging_LoggingServer) error { case codes.Canceled: return nil default: - slog.Error("logging.Recv", err, "worker", wk) + slog.Error("logging.Recv", slog.Any("error", err), slog.Any("worker", wk)) return err } } for _, l := range in.GetLogEntries() { - if l.Severity >= minsev { - // TODO: Connect to the associated Job for this worker instead of - // logging locally for SDK side logging. 
- file := l.GetLogLocation() - i := strings.LastIndex(file, ":") - line, _ := strconv.Atoi(file[i+1:]) - if i > 0 { - file = file[:i] - } + // TODO base this on a per pipeline logging setting. + if l.Severity < minsev { + continue + } + + // Ideally we'd be writing these to per-pipeline files, but for now re-log them on the Prism process. + // We indicate they're from the SDK, and which worker, keeping the same log severity. + // SDK specific and worker specific fields are in separate groups for legibility. - slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), l.GetMessage(), - slog.Any(slog.SourceKey, &slog.Source{ - File: file, - Line: line, - }), - slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), - slog.Any("worker", wk), - ) + attrs := []any{ + slog.String("transformID", l.GetTransformId()), // TODO: pull the unique name from the pipeline graph. + slog.String("location", l.GetLogLocation()), + slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), + slog.String(slog.MessageKey, l.GetMessage()), + } + if fs := l.GetCustomData().GetFields(); len(fs) > 0 { + var grp []any + for n, v := range l.GetCustomData().GetFields() { + var attr slog.Attr + switch v.Kind.(type) { + case *structpb.Value_BoolValue: + attr = slog.Bool(n, v.GetBoolValue()) + case *structpb.Value_NumberValue: + attr = slog.Float64(n, v.GetNumberValue()) + case *structpb.Value_StringValue: + attr = slog.String(n, v.GetStringValue()) + default: + attr = slog.Any(n, v.AsInterface()) + } + grp = append(grp, attr) + } + attrs = append(attrs, slog.Group("customData", grp...)) } + + slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), "log from SDK worker", slog.Any("worker", wk), slog.Group("sdk", attrs...)) } } } @@ -298,7 +328,7 @@ func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { if b, ok := wk.activeInstructions[resp.GetInstructionId()]; ok { b.Respond(resp) } else { - slog.Debug("ctrl.Recv: %v", resp) + slog.Debug("ctrl.Recv", slog.Any("response", resp)) } 
wk.mu.Unlock() } @@ -317,17 +347,21 @@ func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { case <-ctrl.Context().Done(): wk.mu.Lock() // Fail extant instructions - slog.Debug("SDK Disconnected", "worker", wk, "ctx_error", ctrl.Context().Err(), "outstanding_instructions", len(wk.activeInstructions)) + err := context.Cause(ctrl.Context()) + slog.Debug("SDK Disconnected", "worker", wk, "ctx_error", err, "outstanding_instructions", len(wk.activeInstructions)) - msg := fmt.Sprintf("SDK worker disconnected: %v, %v active instructions", wk.String(), len(wk.activeInstructions)) + msg := fmt.Sprintf("SDK worker disconnected: %v, %v active instructions, error: %v", wk.String(), len(wk.activeInstructions), err) for instID, b := range wk.activeInstructions { b.Respond(&fnpb.InstructionResponse{ InstructionId: instID, Error: msg, }) } + // Soft shutdown to prevent GRPC shutdown from being blocked by this + // streaming call. + wk.shutdown() wk.mu.Unlock() - return context.Cause(ctrl.Context()) + return err case err := <-done: if err != nil { slog.Warn("Control done", "error", err, "worker", wk) @@ -355,7 +389,7 @@ func (wk *W) Data(data fnpb.BeamFnData_DataServer) error { case codes.Canceled: return default: - slog.Error("data.Recv failed", err, "worker", wk) + slog.Error("data.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -434,7 +468,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case codes.Canceled: return default: - slog.Error("state.Recv failed", err, "worker", wk) + slog.Error("state.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -584,7 +618,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { }() for resp := range responses { if err := state.Send(resp); err != nil { - slog.Error("state.Send error", err) + slog.Error("state.Send", slog.Any("error", err)) } } return nil @@ -625,9 +659,22 @@ func (wk *W) sendInstruction(ctx context.Context, req 
*fnpb.InstructionRequest) if wk.Stopped() { return nil } - wk.InstReqs <- req + select { + case <-wk.StoppedChan: + return &fnpb.InstructionResponse{ + InstructionId: progInst, + Error: "worker stopped before send", + } + case wk.InstReqs <- req: + // desired outcome + } select { + case <-wk.StoppedChan: + return &fnpb.InstructionResponse{ + InstructionId: progInst, + Error: "worker stopped before receive", + } case <-ctx.Done(): return &fnpb.InstructionResponse{ InstructionId: progInst, diff --git a/sdks/go/test/integration/io/xlang/debezium/debezium_test.go b/sdks/go/test/integration/io/xlang/debezium/debezium_test.go index 24c2b513b2b2..208a062f9436 100644 --- a/sdks/go/test/integration/io/xlang/debezium/debezium_test.go +++ b/sdks/go/test/integration/io/xlang/debezium/debezium_test.go @@ -34,7 +34,7 @@ import ( ) const ( - debeziumImage = "debezium/example-postgres:latest" + debeziumImage = "quay.io/debezium/example-postgres:latest" debeziumPort = "5432/tcp" maxRetries = 5 ) diff --git a/sdks/java/container/Dockerfile-distroless b/sdks/java/container/Dockerfile-distroless new file mode 100644 index 000000000000..328c4dc6a7b3 --- /dev/null +++ b/sdks/java/container/Dockerfile-distroless @@ -0,0 +1,42 @@ +############################################################################### +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################### + +# ARG BEAM_BASE is the Beam SDK container image built using sdks/java/container/Dockerfile. +ARG BEAM_BASE + +# ARG DISTROLESS_BASE is the distroless container image URL. For available distroless Java images, +# see https://github.com/GoogleContainerTools/distroless/tree/main?tab=readme-ov-file#what-images-are-available. +# Only Java versions 17 and 21 are supported. +ARG DISTROLESS_BASE +FROM ${BEAM_BASE} AS base +ARG TARGETARCH +ENV LANG C.UTF-8 + +LABEL Author="Apache Beam " + +RUN if [ -z "${TARGETARCH}" ]; then echo "fatal: TARGETARCH not set; run as docker buildx build or use --build-arg=TARGETARCH=amd64|arm64" >&2; exit 1; fi + +FROM ${DISTROLESS_BASE}:latest-${TARGETARCH} AS distroless + +COPY --from=base /opt /opt + +# Along with the LANG environment variable above, prevents internally discovered failing bugs related to Dataflow Flex +# template character encodings.
+COPY --from=base /usr/lib/locale /usr/lib/locale + +ENTRYPOINT ["/opt/apache/beam/boot"] diff --git a/sdks/java/container/license_scripts/dep_urls_java.yaml b/sdks/java/container/license_scripts/dep_urls_java.yaml index 797a2e3c2f78..781a0decda78 100644 --- a/sdks/java/container/license_scripts/dep_urls_java.yaml +++ b/sdks/java/container/license_scripts/dep_urls_java.yaml @@ -46,7 +46,7 @@ jaxen: '1.1.6': type: "3-Clause BSD" libraries-bom: - '26.45.0': + '26.49.0': license: "https://raw.githubusercontent.com/GoogleCloudPlatform/cloud-opensource-java/master/LICENSE" type: "Apache License 2.0" paranamer: diff --git a/sdks/java/core/build.gradle b/sdks/java/core/build.gradle index e150c22de62d..7423cb7c6b8e 100644 --- a/sdks/java/core/build.gradle +++ b/sdks/java/core/build.gradle @@ -73,7 +73,9 @@ dependencies { antlr library.java.antlr // antlr is used to generate code from sdks/java/core/src/main/antlr/ permitUnusedDeclared library.java.antlr + permitUsedUndeclared library.java.antlr_runtime // Required to load constants from the model, e.g. 
max timestamp for global window + provided library.java.google_api_services_dataflow shadow project(path: ":model:pipeline", configuration: "shadow") shadow project(path: ":model:fn-execution", configuration: "shadow") shadow project(path: ":model:job-management", configuration: "shadow") @@ -81,7 +83,6 @@ dependencies { shadow library.java.vendored_grpc_1_60_1 shadow library.java.vendored_guava_32_1_2_jre shadow library.java.byte_buddy - shadow library.java.antlr_runtime shadow library.java.commons_compress shadow library.java.commons_lang3 testImplementation library.java.mockito_inline diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java index f739a797af80..453c1cd79a42 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.fn; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.harness.JvmInitializer; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.common.ReflectHelpers; @@ -25,6 +26,8 @@ /** Helpers for executing {@link JvmInitializer} implementations. */ public class JvmInitializers { + private static final AtomicBoolean initialized = new AtomicBoolean(false); + /** * Finds all registered implementations of JvmInitializer and executes their {@code onStartup} * methods. Should be called in worker harness implementations at the very beginning of their main @@ -50,10 +53,23 @@ public static void runBeforeProcessing(PipelineOptions options) { // We load the logger in the method to minimize the amount of class loading that happens // during class initialization. 
Logger logger = LoggerFactory.getLogger(JvmInitializers.class); - for (JvmInitializer initializer : ReflectHelpers.loadServicesOrdered(JvmInitializer.class)) { - logger.info("Running JvmInitializer#beforeProcessing for {}", initializer); - initializer.beforeProcessing(options); - logger.info("Completed JvmInitializer#beforeProcessing for {}", initializer); + + try { + for (JvmInitializer initializer : ReflectHelpers.loadServicesOrdered(JvmInitializer.class)) { + logger.info("Running JvmInitializer#beforeProcessing for {}", initializer); + initializer.beforeProcessing(options); + logger.info("Completed JvmInitializer#beforeProcessing for {}", initializer); + } + initialized.compareAndSet(false, true); + } catch (Error e) { + if (initialized.get()) { + logger.warn( + "Error at JvmInitializer#beforeProcessing. This error is suppressed after " + + "previous success runs. It is expected on Embedded environment", + e); + } else { + throw e; + } } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java index e946022c4e36..aa0dea80b0a1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.fn.data; +import java.time.Duration; import java.util.HashSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Consumer; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.pipeline.v1.Endpoints; @@ -30,6 +32,8 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; 
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; @@ -49,13 +53,20 @@ */ public class BeamFnDataGrpcMultiplexer implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcMultiplexer.class); + private static final Duration POISONED_INSTRUCTION_ID_CACHE_TIMEOUT = Duration.ofMinutes(20); private final Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor; private final StreamObserver inboundObserver; private final StreamObserver outboundObserver; - private final ConcurrentMap< + private final ConcurrentHashMap< /*instructionId=*/ String, CompletableFuture>> receivers; - private final ConcurrentMap erroredInstructionIds; + private final Cache poisonedInstructionIds; + + private static class PoisonedException extends RuntimeException { + public PoisonedException() { + super("Instruction poisoned"); + } + }; public BeamFnDataGrpcMultiplexer( Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor, @@ -64,7 +75,8 @@ public BeamFnDataGrpcMultiplexer( baseOutboundObserverFactory) { this.apiServiceDescriptor = apiServiceDescriptor; this.receivers = new ConcurrentHashMap<>(); - this.erroredInstructionIds = new ConcurrentHashMap<>(); + this.poisonedInstructionIds = + CacheBuilder.newBuilder().expireAfterWrite(POISONED_INSTRUCTION_ID_CACHE_TIMEOUT).build(); this.inboundObserver = new InboundObserver(); this.outboundObserver = outboundObserverFactory.outboundObserverFor(baseOutboundObserverFactory, inboundObserver); @@ -87,11 +99,6 @@ public StreamObserver 
getOutboundObserver() { return outboundObserver; } - private CompletableFuture> receiverFuture( - String instructionId) { - return receivers.computeIfAbsent(instructionId, (unused) -> new CompletableFuture<>()); - } - /** * Registers a consumer for the specified instruction id. * @@ -99,17 +106,63 @@ private CompletableFuture> receiverF * instruction ids ensuring that the receiver will only see {@link BeamFnApi.Elements} with a * single instruction id. * - *

The caller must {@link #unregisterConsumer unregister the consumer} when they no longer wish - * to receive messages. + *

The caller must either {@link #unregisterConsumer unregister the consumer} when all messages + * have been processed or {@link #poisonInstructionId(String) poison the instruction} if messages + * for the instruction should be dropped. */ public void registerConsumer( String instructionId, CloseableFnDataReceiver receiver) { - receiverFuture(instructionId).complete(receiver); + receivers.compute( + instructionId, + (unused, existing) -> { + if (existing != null) { + if (!existing.complete(receiver)) { + throw new IllegalArgumentException("Instruction id was registered twice"); + } + return existing; + } + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new IllegalArgumentException("Instruction id was poisoned"); + } + return CompletableFuture.completedFuture(receiver); + }); } - /** Unregisters a consumer. */ + /** Unregisters a previously registered consumer. */ public void unregisterConsumer(String instructionId) { - receivers.remove(instructionId); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null && !receiverFuture.isDone()) { + // The future must have been inserted by the inbound observer since registerConsumer completes + // the future. + throw new IllegalArgumentException("Unregistering consumer which was not registered."); + } + } + + /** + * Poisons an instruction id. + * + *

Any records for the instruction on the inbound observer will be dropped for the next {@link + * #POISONED_INSTRUCTION_ID_CACHE_TIMEOUT}. + */ + public void poisonInstructionId(String instructionId) { + poisonedInstructionIds.put(instructionId, Boolean.TRUE); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null) { + // Completing exceptionally has no effect if the future was already notified. In that case + // whatever registered the receiver needs to handle cancelling it. + receiverFuture.completeExceptionally(new PoisonedException()); + if (!receiverFuture.isCompletedExceptionally()) { + try { + receiverFuture.get().close(); + } catch (Exception e) { + LOG.warn("Unexpected error closing existing observer"); + } + } + } } @VisibleForTesting @@ -210,27 +263,42 @@ public void onNext(BeamFnApi.Elements value) { } private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.Elements value) { - if (erroredInstructionIds.containsKey(instructionId)) { - LOG.debug("Ignoring inbound data for failed instruction {}", instructionId); - return; - } - CompletableFuture> consumerFuture = - receiverFuture(instructionId); - if (!consumerFuture.isDone()) { - LOG.debug( - "Received data for instruction {} without consumer ready. " - + "Waiting for consumer to be registered.", - instructionId); - } CloseableFnDataReceiver consumer; try { - consumer = consumerFuture.get(); - + CompletableFuture> consumerFuture = + receivers.computeIfAbsent( + instructionId, + (unused) -> { + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new PoisonedException(); + } + LOG.debug( + "Received data for instruction {} without consumer ready. " + + "Waiting for consumer to be registered.", + instructionId); + return new CompletableFuture<>(); + }); + // The consumer may not be registered until the bundle processor is fully constructed so we + // conservatively set + // a high timeout. 
Poisoning will prevent this from occurring for consumers that will not be + // registered. + consumer = consumerFuture.get(3, TimeUnit.HOURS); /* * TODO: On failure we should fail any bundles that were impacted eagerly * instead of relying on the Runner harness to do all the failure handling. */ - } catch (ExecutionException | InterruptedException e) { + } catch (TimeoutException e) { + LOG.error( + "Timed out waiting to observe consumer data stream for instruction {}", + instructionId, + e); + outboundObserver.onError(e); + return; + } catch (ExecutionException | InterruptedException | PoisonedException e) { + if (e instanceof PoisonedException || e.getCause() instanceof PoisonedException) { + LOG.debug("Received data for poisoned instruction {}. Dropping input.", instructionId); + return; + } LOG.error( "Client interrupted during handling of data for instruction {}", instructionId, e); outboundObserver.onError(e); @@ -240,10 +308,11 @@ private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.E outboundObserver.onError(e); return; } + try { consumer.accept(value); } catch (Exception e) { - erroredInstructionIds.put(instructionId, true); + poisonInstructionId(instructionId); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java index b0d29e2295a8..2c6b61e62121 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java @@ -202,6 +202,12 @@ public WeightedList decodeFromChunkBoundaryToChunkBoundary() { T next = next(); rvals.add(next); } + // We don't support seeking backwards so release the memory of the last + // page if it is completed.
+ if (inbound.currentStream.available() == 0) { + inbound.position = 0; + inbound.currentStream = EMPTY_STREAM; + } // Uses the size of the ByteString as an approximation for the heap size occupied by the // page, considering an overhead of {@link BYTES_LIST_ELEMENT_OVERHEAD} for each element. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java index b7523ee12b56..7eb04519555b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSink.java @@ -687,11 +687,25 @@ protected final List, ResourceId>> finalizeDestinati distinctFilenames.get(finalFilename)); distinctFilenames.put(finalFilename, result); outputFilenames.add(KV.of(result, finalFilename)); - FileSystems.reportSinkLineage(finalFilename); } + reportSinkLineage(outputFilenames); return outputFilenames; } + /** + * Report sink Lineage. Report every file if number of files no more than 100, otherwise only + * report at directory level. 
+ */ + private void reportSinkLineage(List, ResourceId>> outputFilenames) { + if (outputFilenames.size() <= 100) { + for (KV, ResourceId> kv : outputFilenames) { + FileSystems.reportSinkLineage(kv.getValue()); + } + } else { + FileSystems.reportSinkLineage(outputFilenames.get(0).getValue().getCurrentDirectory()); + } + } + private Collection> createMissingEmptyShards( @Nullable DestinationT dest, @Nullable Integer numShards, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java index 7ddfde441aed..8d6e52c64a52 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java @@ -26,10 +26,12 @@ import java.nio.channels.ReadableByteChannel; import java.nio.channels.SeekableByteChannel; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.ListIterator; import java.util.NoSuchElementException; import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.io.FileSystem.LineageLevel; import org.apache.beam.sdk.io.fs.EmptyMatchTreatment; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; @@ -297,9 +299,10 @@ public final List> split( System.currentTimeMillis() - startTime, expandedFiles.size(), splitResults.size()); + + reportSourceLineage(expandedFiles); return splitResults; } else { - FileSystems.reportSourceLineage(getSingleFileMetadata().resourceId()); if (isSplittable()) { @SuppressWarnings("unchecked") List> splits = @@ -315,6 +318,37 @@ public final List> split( } } + /** + * Report source Lineage. Due to the size limit of Beam metrics, report full file name or only dir + * depend on the number of files. + * + *

- Number of files<=100, report full file paths; + * + *

- Number of directories<=100, report directory names (one level up); + * + *

- Otherwise, report top level only. + */ + private static void reportSourceLineage(List expandedFiles) { + if (expandedFiles.size() <= 100) { + for (Metadata metadata : expandedFiles) { + FileSystems.reportSourceLineage(metadata.resourceId()); + } + } else { + HashSet uniqueDirs = new HashSet<>(); + for (Metadata metadata : expandedFiles) { + ResourceId dir = metadata.resourceId().getCurrentDirectory(); + uniqueDirs.add(dir); + if (uniqueDirs.size() > 100) { + FileSystems.reportSourceLineage(dir, LineageLevel.TOP_LEVEL); + return; + } + } + for (ResourceId uniqueDir : uniqueDirs) { + FileSystems.reportSourceLineage(uniqueDir); + } + } + } + /** * Determines whether a file represented by this source is can be split into bundles. * diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystem.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystem.java index 11314a318b25..73caa7284e98 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystem.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystem.java @@ -157,10 +157,20 @@ protected abstract void rename( */ protected abstract String getScheme(); + public enum LineageLevel { + FILE, + TOP_LEVEL + } + + /** Report {@link Lineage} metrics for resource id at file level. */ + protected void reportLineage(ResourceIdT resourceId, Lineage lineage) { + reportLineage(resourceId, lineage, LineageLevel.FILE); + } + /** - * Report {@link Lineage} metrics for resource id. + * Report {@link Lineage} metrics for resource id to a given level. * *

Unless override by FileSystem implementations, default to no-op. */ - protected void reportLineage(ResourceIdT unusedId, Lineage unusedLineage) {} + protected void reportLineage(ResourceIdT unusedId, Lineage unusedLineage, LineageLevel level) {} } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java index a4ca9b80dce3..7e2940a2c35b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java @@ -39,6 +39,7 @@ import java.util.regex.Pattern; import javax.annotation.Nonnull; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.FileSystem.LineageLevel; import org.apache.beam.sdk.io.fs.CreateOptions; import org.apache.beam.sdk.io.fs.CreateOptions.StandardCreateOptions; import org.apache.beam.sdk.io.fs.EmptyMatchTreatment; @@ -50,6 +51,7 @@ import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.metrics.Lineage; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.KV; @@ -398,12 +400,36 @@ public ResourceId apply(@Nonnull Metadata input) { /** Report source {@link Lineage} metrics for resource id. */ public static void reportSourceLineage(ResourceId resourceId) { - getFileSystemInternal(resourceId.getScheme()).reportLineage(resourceId, Lineage.getSources()); + reportSourceLineage(resourceId, LineageLevel.FILE); } /** Report sink {@link Lineage} metrics for resource id. 
*/ public static void reportSinkLineage(ResourceId resourceId) { - getFileSystemInternal(resourceId.getScheme()).reportLineage(resourceId, Lineage.getSinks()); + reportSinkLineage(resourceId, LineageLevel.FILE); + } + + /** + * Report source {@link Lineage} metrics for resource id at given level. + * + *

Internal API, no backward compatibility guaranteed. + */ + public static void reportSourceLineage(ResourceId resourceId, LineageLevel level) { + reportLineage(resourceId, Lineage.getSources(), level); + } + + /** + * Report sink {@link Lineage} metrics for resource id at given level. + * + *

Internal API, no backward compatibility guaranteed. + */ + public static void reportSinkLineage(ResourceId resourceId, LineageLevel level) { + reportLineage(resourceId, Lineage.getSinks(), level); + } + + /** Report {@link Lineage} metrics for resource id at given level to given Lineage container. */ + private static void reportLineage(ResourceId resourceId, Lineage lineage, LineageLevel level) { + FileSystem fileSystem = getFileSystemInternal(resourceId.getScheme()); + fileSystem.reportLineage(resourceId, lineage, level); } private static class FilterResult { @@ -540,14 +566,19 @@ static FileSystem getFileSystemInternal(String scheme) { * *

It will be used in {@link FileSystemRegistrar FileSystemRegistrars} for all schemes. * - *

This is expected only to be used by runners after {@code Pipeline.run}, or in tests. + *

Outside of workers where Beam FileSystem API is used (e.g. test methods, user code executed + * during pipeline submission), consider using {@link #registerFileSystemsOnce} if initializing + * FileSystems for supported schemes is the main goal. */ @Internal public static void setDefaultPipelineOptions(PipelineOptions options) { - checkNotNull(options, "options"); + checkNotNull(options, "options cannot be null"); long id = options.getOptionsId(); int nextRevision = options.revision(); + // entry to set other PipelineOption determined flags + Metrics.setDefaultPipelineOptions(options); + while (true) { KV revision = FILESYSTEM_REVISION.get(); // only update file systems if the pipeline changed or the options revision increased @@ -568,6 +599,23 @@ public static void setDefaultPipelineOptions(PipelineOptions options) { } } + /** + * Register file systems once if never done before. + * + *

This method executes {@link #setDefaultPipelineOptions} only if it has never been run, + * otherwise it returns immediately. + * + *

It is internally used by test setup to avoid repeated filesystem registrations (involves + * expensive ServiceLoader calls) when there are multiple pipeline and PipelineOptions object + * initialized, which is commonly seen in test execution. + */ + @Internal + public static synchronized void registerFileSystemsOnce(PipelineOptions options) { + if (FILESYSTEM_REVISION.get() == null) { + setDefaultPipelineOptions(options); + } + } + @VisibleForTesting static Map verifySchemesAreUnique( PipelineOptions options, Set registrars) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSourceTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSourceTransform.java index bbac337f2d0f..843deb5cab32 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSourceTransform.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/ReadAllViaFileBasedSourceTransform.java @@ -19,7 +19,9 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; +import java.util.HashSet; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.FileSystem.LineageLevel; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.io.range.OffsetRange; @@ -30,6 +32,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.checkerframework.checker.nullness.qual.Nullable; public abstract class ReadAllViaFileBasedSourceTransform extends PTransform, PCollection> { @@ -81,6 +84,9 @@ public static class SplitIntoRangesFn extends DoFn> { private final long desiredBundleSizeBytes; + // track unique resourceId met. 
Access it only inside reportSourceLineage + private transient @Nullable HashSet uniqueIds; + public SplitIntoRangesFn(long desiredBundleSizeBytes) { this.desiredBundleSizeBytes = desiredBundleSizeBytes; } @@ -88,6 +94,7 @@ public SplitIntoRangesFn(long desiredBundleSizeBytes) { @ProcessElement public void process(ProcessContext c) { MatchResult.Metadata metadata = c.element().getMetadata(); + reportSourceLineage(metadata.resourceId()); if (!metadata.isReadSeekEfficient()) { c.output(KV.of(c.element(), new OffsetRange(0, metadata.sizeBytes()))); return; @@ -97,6 +104,31 @@ public void process(ProcessContext c) { c.output(KV.of(c.element(), range)); } } + + /** + * Report source Lineage. Due to the size limit of Beam metrics, report full file name or only + * top level depend on the number of files. + * + *

- Number of files<=100, report full file paths; + * + *

- Otherwise, report top level only. + */ + @SuppressWarnings("nullness") // only called in processElement, guaranteed to be non-null + private void reportSourceLineage(ResourceId resourceId) { + if (uniqueIds == null) { + uniqueIds = new HashSet<>(); + } else if (uniqueIds.isEmpty()) { + // already at capacity + FileSystems.reportSourceLineage(resourceId, LineageLevel.TOP_LEVEL); + return; + } + uniqueIds.add(resourceId); + FileSystems.reportSourceLineage(resourceId, LineageLevel.FILE); + if (uniqueIds.size() >= 100) { + // avoid reference leak + uniqueIds.clear(); + } + } } public abstract static class AbstractReadFileRangesFn @@ -140,7 +172,6 @@ public void process(ProcessContext c) throws IOException { throw e; } } - FileSystems.reportSourceLineage(resourceId); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java index a0b2e3b34678..7e8252d4fb3f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java @@ -19,6 +19,7 @@ import java.io.Serializable; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.metrics.Metrics.MetricsFlag; /** Implementation of {@link Counter} that delegates to the instance for the current context. */ @Internal @@ -70,6 +71,9 @@ public void inc() { /** Increment the counter by the given amount. */ @Override public void inc(long n) { + if (MetricsFlag.counterDisabled()) { + return; + } MetricsContainer container = this.processWideContainer ? 
MetricsEnvironment.getProcessWideContainer() diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java index a963015e98a7..6c8179006640 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java @@ -18,6 +18,13 @@ package org.apache.beam.sdk.metrics; import java.io.Serializable; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * The Metrics is a utility class for producing various kinds of metrics for reporting @@ -50,9 +57,59 @@ * example off how to query metrics. */ public class Metrics { + private static final Logger LOG = LoggerFactory.getLogger(Metrics.class); private Metrics() {} + static class MetricsFlag { + private static final AtomicReference<@Nullable MetricsFlag> INSTANCE = new AtomicReference<>(); + final boolean counterDisabled; + final boolean stringSetDisabled; + + private MetricsFlag(boolean counterDisabled, boolean stringSetDisabled) { + this.counterDisabled = counterDisabled; + this.stringSetDisabled = stringSetDisabled; + } + + static boolean counterDisabled() { + MetricsFlag flag = INSTANCE.get(); + return flag != null && flag.counterDisabled; + } + + static boolean stringSetDisabled() { + MetricsFlag flag = INSTANCE.get(); + return flag != null && flag.stringSetDisabled; + } + } + + /** + * Initialize metrics flags if not already done so. + * + *

Should be called by worker at worker harness initialization. Should not be called by user + * code (and it does not have an effect as the initialization completed before). + */ + @Internal + public static void setDefaultPipelineOptions(PipelineOptions options) { + MetricsFlag flag = MetricsFlag.INSTANCE.get(); + if (flag == null) { + ExperimentalOptions exp = options.as(ExperimentalOptions.class); + boolean counterDisabled = ExperimentalOptions.hasExperiment(exp, "disableCounterMetrics"); + if (counterDisabled) { + LOG.info("Counter metrics are disabled."); + } + boolean stringSetDisabled = ExperimentalOptions.hasExperiment(exp, "disableStringSetMetrics"); + if (stringSetDisabled) { + LOG.info("StringSet metrics are disabled"); + } + MetricsFlag.INSTANCE.compareAndSet(null, new MetricsFlag(counterDisabled, stringSetDisabled)); + } + } + + @Internal + static void resetDefaultPipelineOptions() { + MetricsFlag.INSTANCE.set(null); + } + /** * Create a metric that can be incremented and decremented, and is aggregated by taking the sum. 
*/ @@ -174,6 +231,9 @@ private DelegatingStringSet(MetricName name) { @Override public void add(String value) { + if (MetricsFlag.stringSetDisabled()) { + return; + } MetricsContainer container = MetricsEnvironment.getCurrentContainer(); if (container != null) { container.getStringSet(name).add(value); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java index 6e5843f533db..78ea34503e54 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java @@ -20,6 +20,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.fasterxml.jackson.annotation.JsonCreator; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -386,4 +387,21 @@ static List getConfiguredLoggerFromOptions(SdkHarnessOptions loggingOpti } return configuredLoggers; } + + @Hidden + @Description( + "Timeout used for cache of bundle processors. Defaults to a minute for batch and an hour for streaming.") + @Default.InstanceFactory(BundleProcessorCacheTimeoutFactory.class) + Duration getBundleProcessorCacheTimeout(); + + void setBundleProcessorCacheTimeout(Duration duration); + + class BundleProcessorCacheTimeoutFactory implements DefaultValueFactory { + @Override + public Duration create(PipelineOptions options) { + return options.as(StreamingOptions.class).isStreaming() + ? 
Duration.ofHours(1) + : Duration.ofMinutes(1); + } + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index 5ccfe39b92af..f35782c2b9a2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -19,7 +19,6 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; @@ -32,13 +31,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** A {@link SchemaProvider} for AutoValue classes. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) public class AutoValueSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on AutoValue getters. */ @VisibleForTesting @@ -49,7 +45,11 @@ public static class AbstractGetterTypeSupplier implements FieldValueTypeSupplier public List get(TypeDescriptor typeDescriptor) { // If the generated class is passed in, we want to look at the base class to find the getters. 
- TypeDescriptor targetTypeDescriptor = AutoValueUtils.getBaseAutoValueClass(typeDescriptor); + TypeDescriptor targetTypeDescriptor = + Preconditions.checkNotNull( + AutoValueUtils.getBaseAutoValueClass(typeDescriptor), + "unable to determine base AutoValue class for type {}", + typeDescriptor); List methods = ReflectUtils.getMethods(targetTypeDescriptor.getRawType()).stream() @@ -62,9 +62,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -89,8 +89,8 @@ private static void validateFieldNumbers(List types) } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java index 8725833bc1da..6e244fefb263 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java @@ -20,6 +20,9 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.initialization.qual.UnknownInitialization; +import 
org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -32,24 +35,25 @@ * significant for larger schemas) on each lookup. This wrapper caches the value returned by the * inner factory, so the schema comparison only need happen on the first lookup. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) -public class CachingFactory implements Factory { +public class CachingFactory implements Factory { private transient @Nullable ConcurrentHashMap, CreatedT> cache = null; - private final Factory innerFactory; + private final @NotOnlyInitialized Factory innerFactory; - public CachingFactory(Factory innerFactory) { + public CachingFactory(@UnknownInitialization Factory innerFactory) { this.innerFactory = innerFactory; } - @Override - public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + private ConcurrentHashMap, CreatedT> getCache() { if (cache == null) { cache = new ConcurrentHashMap<>(); } + return cache; + } + + @Override + public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + ConcurrentHashMap, CreatedT> cache = getCache(); CreatedT cached = cache.get(typeDescriptor); if (cached != null) { return cached; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java index fb98db8e8343..63ab56dc7609 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java @@ -19,6 +19,7 @@ import java.io.Serializable; import org.apache.beam.sdk.annotations.Internal; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -29,7 +30,7 @@ *

Implementations of this interface are generated at runtime to map object fields to Row fields. */ @Internal -public interface FieldValueGetter extends Serializable { +public interface FieldValueGetter extends Serializable { @Nullable ValueT get(ObjectT object); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 750709192c08..43aac6a5e20c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -27,7 +27,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.Optional; import java.util.stream.Stream; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -40,10 +42,7 @@ /** Represents type information for a Java type that will be used to infer a Schema type. */ @AutoValue -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@Internal public abstract class FieldValueTypeInformation implements Serializable { /** Optionally returns the field index. 
*/ public abstract @Nullable Integer getNumber(); @@ -125,8 +124,13 @@ public static FieldValueTypeInformation forOneOf( .build(); } - public static FieldValueTypeInformation forField(Field field, int index) { - TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); + public static FieldValueTypeInformation forField( + @Nullable TypeDescriptor typeDescriptor, Field field, int index) { + TypeDescriptor type = + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(field.getGenericType())) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(field.getGenericType())); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -134,9 +138,9 @@ public static FieldValueTypeInformation forField(Field field, int index) { .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field)) - .setMapKeyType(getMapKeyType(field)) - .setMapValueType(getMapValueType(field)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -185,6 +189,11 @@ public static String getNameOverride( } public static FieldValueTypeInformation forGetter(Method method, int index) { + return forGetter(null, method, index); + } + + public static FieldValueTypeInformation forGetter( + @Nullable TypeDescriptor typeDescriptor, Method method, int index) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -194,7 +203,12 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { throw new RuntimeException("Getter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericReturnType()); + TypeDescriptor type = + 
Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericReturnType())) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(method.getGenericReturnType())); + boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -253,10 +267,20 @@ private static boolean isNullableAnnotation(Annotation annotation) { } public static FieldValueTypeInformation forSetter(Method method) { - return forSetter(method, "set"); + return forSetter(null, method); } public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { + return forSetter(null, method, setterPrefix); + } + + public static FieldValueTypeInformation forSetter( + @Nullable TypeDescriptor typeDescriptor, Method method) { + return forSetter(typeDescriptor, method, "set"); + } + + public static FieldValueTypeInformation forSetter( + @Nullable TypeDescriptor typeDescriptor, Method method, String setterPrefix) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -264,7 +288,11 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr throw new RuntimeException("Setter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericParameterTypes()[0]); + TypeDescriptor type = + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericParameterTypes()[0])) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(method.getGenericParameterTypes()[0])); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -283,10 +311,6 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation 
getIterableComponentType(Field field) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType())); - } - static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { // TODO: Figure out nullable elements. TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); @@ -306,23 +330,13 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { .build(); } - // If the Field is a map type, returns the key type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType())); - } - + // If the type is a map type, returns the key type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapKeyType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 0); } - // If the Field is a map type, returns the value type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapValueType(Field field) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1); - } - + // If the type is a map type, returns the value type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapValueType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 1); @@ -330,10 +344,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { // If the Field is a map type, returns the key or value type (0 is key type, 1 is value). // Otherwise returns a null reference. 
- @SuppressWarnings("unchecked") private static @Nullable FieldValueTypeInformation getMapType( TypeDescriptor valueType, int index) { - TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); if (mapType == null) { return null; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index ce5be71933b8..4e431bb45207 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -17,13 +17,12 @@ */ package org.apache.beam.sdk.schemas; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; @@ -32,10 +31,13 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** 
@@ -46,10 +48,7 @@ * methods which receive {@link TypeDescriptor}s instead of ordinary {@link Class}es as * arguments, which permits to support generic type signatures during schema inference */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) @Deprecated public abstract class GetterBasedSchemaProvider implements SchemaProvider { @@ -67,9 +66,9 @@ public abstract class GetterBasedSchemaProvider implements SchemaProvider { * override it if you want to use the richer type signature contained in the {@link * TypeDescriptor} not subject to the type erasure. */ - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { - return fieldValueGetters(targetTypeDescriptor.getRawType(), schema); + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { + return (List) fieldValueGetters(targetTypeDescriptor.getRawType(), schema); } /** @@ -112,9 +111,10 @@ public SchemaUserTypeCreator schemaTypeCreator( return schemaTypeCreator(targetTypeDescriptor.getRawType(), schema); } - private class ToRowWithValueGetters implements SerializableFunction { + private class ToRowWithValueGetters + implements SerializableFunction { private final Schema schema; - private final Factory> getterFactory; + private final Factory>> getterFactory; public ToRowWithValueGetters(Schema schema) { this.schema = schema; @@ -122,7 +122,12 @@ public ToRowWithValueGetters(Schema schema) { // schema, return a caching factory that caches the first value seen for each class. This // prevents having to lookup the getter list each time createGetters is called. 
this.getterFactory = - RowValueGettersFactory.of(GetterBasedSchemaProvider.this::fieldValueGetters); + RowValueGettersFactory.of( + (Factory>>) + (typeDescriptor, schema1) -> + (List) + GetterBasedSchemaProvider.this.fieldValueGetters( + typeDescriptor, schema1)); } @Override @@ -160,13 +165,15 @@ public SerializableFunction toRowFunction(TypeDescriptor typeDesc // important to capture the schema once here, so all invocations of the toRowFunction see the // same version of the schema. If schemaFor were to be called inside the lambda below, different // workers would see different versions of the schema. - Schema schema = schemaFor(typeDescriptor); + @NonNull + Schema schema = + Verify.verifyNotNull( + schemaFor(typeDescriptor), "can't create a ToRowFunction with null schema"); return new ToRowWithValueGetters<>(schema); } @Override - @SuppressWarnings("unchecked") public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { return new FromRowUsingCreator<>(typeDescriptor, this); } @@ -181,23 +188,27 @@ public boolean equals(@Nullable Object obj) { return obj != null && this.getClass() == obj.getClass(); } - private static class RowValueGettersFactory implements Factory> { - private final Factory> gettersFactory; - private final Factory> cachingGettersFactory; + private static class RowValueGettersFactory + implements Factory>> { + private final Factory>> gettersFactory; + private final @NotOnlyInitialized Factory>> + cachingGettersFactory; - static Factory> of(Factory> gettersFactory) { - return new RowValueGettersFactory(gettersFactory).cachingGettersFactory; + static Factory>> of( + Factory>> gettersFactory) { + return new RowValueGettersFactory<>(gettersFactory).cachingGettersFactory; } - RowValueGettersFactory(Factory> gettersFactory) { + RowValueGettersFactory(Factory>> gettersFactory) { this.gettersFactory = gettersFactory; this.cachingGettersFactory = new CachingFactory<>(this); } @Override - public List create(TypeDescriptor typeDescriptor, 
Schema schema) { - List getters = gettersFactory.create(typeDescriptor, schema); - List rowGetters = new ArrayList<>(getters.size()); + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { + List> getters = gettersFactory.create(typeDescriptor, schema); + List> rowGetters = new ArrayList<>(getters.size()); for (int i = 0; i < getters.size(); i++) { rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType())); } @@ -209,71 +220,80 @@ static boolean needsConversion(FieldType type) { return typeName.equals(TypeName.ROW) || typeName.isLogicalType() || ((typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) - && needsConversion(type.getCollectionElementType())) + && needsConversion(Verify.verifyNotNull(type.getCollectionElementType()))) || (typeName.equals(TypeName.MAP) - && (needsConversion(type.getMapKeyType()) - || needsConversion(type.getMapValueType()))); + && (needsConversion(Verify.verifyNotNull(type.getMapKeyType())) + || needsConversion(Verify.verifyNotNull(type.getMapValueType())))); } - FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { + FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { TypeName typeName = type.getTypeName(); if (!needsConversion(type)) { return base; } if (typeName.equals(TypeName.ROW)) { - return new GetRow(base, type.getRowSchema(), cachingGettersFactory); + return new GetRow(base, Verify.verifyNotNull(type.getRowSchema()), cachingGettersFactory); } else if (typeName.equals(TypeName.ARRAY)) { - FieldType elementType = type.getCollectionElementType(); + FieldType elementType = Verify.verifyNotNull(type.getCollectionElementType()); return elementType.getTypeName().equals(TypeName.ROW) ? 
new GetEagerCollection(base, converter(elementType)) : new GetCollection(base, converter(elementType)); } else if (typeName.equals(TypeName.ITERABLE)) { - return new GetIterable(base, converter(type.getCollectionElementType())); + return new GetIterable( + base, converter(Verify.verifyNotNull(type.getCollectionElementType()))); } else if (typeName.equals(TypeName.MAP)) { - return new GetMap(base, converter(type.getMapKeyType()), converter(type.getMapValueType())); + return new GetMap( + base, + converter(Verify.verifyNotNull(type.getMapKeyType())), + converter(Verify.verifyNotNull(type.getMapValueType()))); } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = type.getLogicalType(OneOfType.class); Schema oneOfSchema = oneOfType.getOneOfSchema(); Map values = oneOfType.getCaseEnumType().getValuesMap(); - Map converters = Maps.newHashMapWithExpectedSize(values.size()); + Map> converters = + Maps.newHashMapWithExpectedSize(values.size()); for (Map.Entry kv : values.entrySet()) { FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType(); - FieldValueGetter converter = converter(fieldType); + FieldValueGetter converter = converter(fieldType); converters.put(kv.getValue(), converter); } return new GetOneOf(base, converters, oneOfType); } else if (typeName.isLogicalType()) { - return new GetLogicalInputType(base, type.getLogicalType()); + return new GetLogicalInputType(base, Verify.verifyNotNull(type.getLogicalType())); } return base; } - FieldValueGetter converter(FieldType type) { + FieldValueGetter converter(FieldType type) { return rowValueGetter(IDENTITY, type); } - static class GetRow extends Converter { + static class GetRow + extends Converter { final Schema schema; - final Factory> factory; + final Factory>> factory; - GetRow(FieldValueGetter getter, Schema schema, Factory> factory) { + GetRow( + FieldValueGetter getter, + Schema schema, + Factory>> factory) { super(getter); this.schema = schema; this.factory = factory; } @Override 
- Object convert(Object value) { + Object convert(V value) { return Row.withSchema(schema).withFieldValueGetters(factory, value); } } - static class GetEagerCollection extends Converter { + static class GetEagerCollection extends Converter { final FieldValueGetter converter; - GetEagerCollection(FieldValueGetter getter, FieldValueGetter converter) { + GetEagerCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @@ -288,15 +308,16 @@ Object convert(Collection collection) { } } - static class GetCollection extends Converter { + static class GetCollection extends Converter { final FieldValueGetter converter; - GetCollection(FieldValueGetter getter, FieldValueGetter converter) { + GetCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Collection collection) { if (collection instanceof List) { // For performance reasons if the input is a list, make sure that we produce a list. 
@@ -309,45 +330,51 @@ Object convert(Collection collection) { } } - static class GetIterable extends Converter { + static class GetIterable extends Converter { final FieldValueGetter converter; - GetIterable(FieldValueGetter getter, FieldValueGetter converter) { + GetIterable(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Iterable value) { return Iterables.transform(value, converter::get); } } - static class GetMap extends Converter> { - final FieldValueGetter keyConverter; - final FieldValueGetter valueConverter; + static class GetMap + extends Converter> { + final FieldValueGetter<@NonNull K1, K2> keyConverter; + final FieldValueGetter<@NonNull V1, V2> valueConverter; GetMap( - FieldValueGetter getter, FieldValueGetter keyConverter, FieldValueGetter valueConverter) { + FieldValueGetter> getter, + FieldValueGetter<@NonNull K1, K2> keyConverter, + FieldValueGetter<@NonNull V1, V2> valueConverter) { super(getter); this.keyConverter = keyConverter; this.valueConverter = valueConverter; } @Override - Object convert(Map value) { - Map returnMap = Maps.newHashMapWithExpectedSize(value.size()); - for (Map.Entry entry : value.entrySet()) { - returnMap.put(keyConverter.get(entry.getKey()), valueConverter.get(entry.getValue())); + Map<@Nullable K2, @Nullable V2> convert(Map<@Nullable K1, @Nullable V1> value) { + Map<@Nullable K2, @Nullable V2> returnMap = Maps.newHashMapWithExpectedSize(value.size()); + for (Map.Entry<@Nullable K1, @Nullable V1> entry : value.entrySet()) { + returnMap.put( + Optional.ofNullable(entry.getKey()).map(keyConverter::get).orElse(null), + Optional.ofNullable(entry.getValue()).map(valueConverter::get).orElse(null)); } return returnMap; } } - static class GetLogicalInputType extends Converter { + static class GetLogicalInputType extends Converter { final LogicalType logicalType; - GetLogicalInputType(FieldValueGetter getter, LogicalType 
logicalType) { + GetLogicalInputType(FieldValueGetter getter, LogicalType logicalType) { super(getter); this.logicalType = logicalType; } @@ -359,12 +386,14 @@ Object convert(Object value) { } } - static class GetOneOf extends Converter { + static class GetOneOf extends Converter { final OneOfType oneOfType; - final Map converters; + final Map> converters; GetOneOf( - FieldValueGetter getter, Map converters, OneOfType oneOfType) { + FieldValueGetter getter, + Map> converters, + OneOfType oneOfType) { super(getter); this.converters = converters; this.oneOfType = oneOfType; @@ -373,24 +402,31 @@ static class GetOneOf extends Converter { @Override Object convert(OneOfType.Value value) { EnumerationType.Value caseType = value.getCaseType(); - FieldValueGetter converter = converters.get(caseType.getValue()); - checkState(converter != null, "Missing OneOf converter for case %s.", caseType); + + @NonNull + FieldValueGetter<@NonNull Object, Object> converter = + Verify.verifyNotNull( + converters.get(caseType.getValue()), + "Missing OneOf converter for case %s.", + caseType); + return oneOfType.createValue(caseType, converter.get(value.getValue())); } } - abstract static class Converter implements FieldValueGetter { - final FieldValueGetter getter; + abstract static class Converter + implements FieldValueGetter { + final FieldValueGetter getter; - public Converter(FieldValueGetter getter) { + public Converter(FieldValueGetter getter) { this.getter = getter; } - abstract Object convert(T value); + abstract Object convert(ValueT value); @Override - public @Nullable Object get(Object object) { - T value = (T) getter.get(object); + public @Nullable Object get(ObjectT object) { + ValueT value = getter.get(object); if (value == null) { return null; } @@ -398,7 +434,7 @@ public Converter(FieldValueGetter getter) { } @Override - public @Nullable Object getRaw(Object object) { + public @Nullable Object getRaw(ObjectT object) { return getter.getRaw(object); } @@ -408,16 +444,16 @@ 
public String name() { } } - private static final FieldValueGetter IDENTITY = - new FieldValueGetter() { + private static final FieldValueGetter<@NonNull Object, Object> IDENTITY = + new FieldValueGetter<@NonNull Object, Object>() { @Override - public @Nullable Object get(Object object) { + public Object get(@NonNull Object object) { return object; } @Override public String name() { - return null; + return "IDENTITY"; } }; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java index de31f9947c36..e7214d8f663a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java @@ -19,6 +19,7 @@ import java.util.List; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A newer version of {@link GetterBasedSchemaProvider}, which works with {@link TypeDescriptor}s, @@ -28,12 +29,12 @@ public abstract class GetterBasedSchemaProviderV2 extends GetterBasedSchemaProvider { @Override public List fieldValueGetters(Class targetClass, Schema schema) { - return fieldValueGetters(TypeDescriptor.of(targetClass), schema); + return (List) fieldValueGetters(TypeDescriptor.of(targetClass), schema); } @Override - public abstract List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema); + public abstract List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema); @Override public List fieldValueTypeInformations( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index a9cf01c52057..14adf2f6603e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -19,7 +19,6 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; @@ -34,6 +33,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -49,10 +49,7 @@ *

TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class JavaBeanSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on getter methods. */ @VisibleForTesting @@ -68,9 +65,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -114,29 +111,32 @@ public List get(TypeDescriptor typeDescriptor) { return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(FieldValueTypeInformation::forSetter) + .map(m -> FieldValueTypeInformation.forSetter(typeDescriptor, m)) .map( t -> { - if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { + Method m = + Preconditions.checkNotNull( + t.getMethod(), JavaBeanUtils.SETTER_WITH_NULL_METHOD_ERROR); + if (m.getAnnotation(SchemaFieldNumber.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldNumber can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaFieldName.class) != null) { + if (m.getAnnotation(SchemaFieldName.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldName can only be used on getters in Java Beans. 
Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaCaseFormat.class) != null) { + if (m.getAnnotation(SchemaCaseFormat.class) != null) { throw new RuntimeException( String.format( "@SchemaCaseFormat can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } return t; }) @@ -172,8 +172,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index 21f07c47b47f..9a8eef2bf2c8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -21,20 +21,22 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.DefaultTypeConversionsFactory; import org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier; +import org.apache.beam.sdk.schemas.utils.JavaBeanUtils; import org.apache.beam.sdk.schemas.utils.POJOUtils; import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for Java POJO objects. @@ -49,7 +51,6 @@ *

TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({"nullness", "rawtypes"}) public class JavaFieldSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on public fields. */ @VisibleForTesting @@ -64,9 +65,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(fields.size()); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i)); + types.add(FieldValueTypeInformation.forField(typeDescriptor, fields.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); // If there are no creators registered, then make sure none of the schema fields are final, @@ -75,7 +76,9 @@ public List get(TypeDescriptor typeDescriptor) { && ReflectUtils.getAnnotatedConstructor(typeDescriptor.getRawType()) == null) { Optional finalField = types.stream() - .map(FieldValueTypeInformation::getField) + .flatMap( + fvti -> + Optional.ofNullable(fvti.getField()).map(Stream::of).orElse(Stream.empty())) .filter(f -> Modifier.isFinal(f.getModifiers())) .findAny(); if (finalField.isPresent()) { @@ -115,8 +118,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return POJOUtils.getGetters( targetTypeDescriptor, schema, @@ -149,7 +152,7 @@ public SchemaUserTypeCreator schemaTypeCreator( ReflectUtils.getAnnotatedConstructor(targetTypeDescriptor.getRawType()); if (constructor != null) { return POJOUtils.getConstructorCreator( - targetTypeDescriptor, + (TypeDescriptor) targetTypeDescriptor, constructor, schema, 
JavaFieldTypeSupplier.INSTANCE, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 5af59356b174..02607d91b079 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -90,6 +90,7 @@ public String toString() { return Arrays.toString(array); } } + // A mapping between field names an indices. private final BiMap fieldIndices; @@ -830,10 +831,11 @@ public static FieldType iterable(FieldType elementType) { public static FieldType map(FieldType keyType, FieldType valueType) { if (FieldType.BYTES.equals(keyType)) { LOG.warn( - "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as intended. " - + "Since arrays do not override equals() or hashCode, comparisons will be done on reference equality only. " - + "ByteBuffers, when used as keys, present similar challenges because Row stores ByteBuffer as a byte array. " - + "Consider using a different type of key for more consistent and predictable behavior."); + "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as" + + " intended. Since arrays do not override equals() or hashCode, comparisons will" + + " be done on reference equality only. ByteBuffers, when used as keys, present" + + " similar challenges because Row stores ByteBuffer as a byte array. Consider" + + " using a different type of key for more consistent and predictable behavior."); } return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) @@ -1443,7 +1445,7 @@ private static Schema fromFields(List fields) { } /** Return the list of all field names. 
*/ - public List getFieldNames() { + public List<@NonNull String> getFieldNames() { return getFields().stream().map(Schema.Field::getName).collect(Collectors.toList()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index d7fddd8abfed..300dce61e2ea 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -27,6 +27,7 @@ import java.lang.reflect.Parameter; import java.lang.reflect.Type; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -62,21 +63,25 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for managing AutoValue schemas. 
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class AutoValueUtils { - public static TypeDescriptor getBaseAutoValueClass(TypeDescriptor typeDescriptor) { + public static @Nullable TypeDescriptor getBaseAutoValueClass( + TypeDescriptor typeDescriptor) { // AutoValue extensions may be nested - while (typeDescriptor != null && typeDescriptor.getRawType().getName().contains("AutoValue_")) { - typeDescriptor = TypeDescriptor.of(typeDescriptor.getRawType().getSuperclass()); + @Nullable TypeDescriptor baseTypeDescriptor = typeDescriptor; + while (baseTypeDescriptor != null + && baseTypeDescriptor.getRawType().getName().contains("AutoValue_")) { + baseTypeDescriptor = + Optional.ofNullable(baseTypeDescriptor.getRawType().getSuperclass()) + .map(TypeDescriptor::of) + .orElse(null); } - return typeDescriptor; + return baseTypeDescriptor; } private static TypeDescriptor getAutoValueGenerated(TypeDescriptor typeDescriptor) { @@ -154,7 +159,11 @@ private static boolean matchConstructor( getterTypes.stream() .collect( Collectors.toMap( - f -> ReflectUtils.stripGetterPrefix(f.getMethod().getName()), + f -> + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + f.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()), Function.identity())); boolean valid = true; @@ -196,18 +205,23 @@ private static boolean matchConstructor( return null; } - Map setterTypes = - ReflectUtils.getMethods(builderClass).stream() - .filter(ReflectUtils::isSetter) - .map(FieldValueTypeInformation::forSetter) - .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); + Map setterTypes = new HashMap<>(); + + ReflectUtils.getMethods(builderClass).stream() + .filter(ReflectUtils::isSetter) + .map(m -> FieldValueTypeInformation.forSetter(TypeDescriptor.of(builderClass), m)) + .forEach(fv -> setterTypes.putIfAbsent(fv.getName(), fv)); List setterMethods = 
Lists.newArrayList(); // The builder methods to call in order. List schemaTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); for (FieldValueTypeInformation type : schemaTypes) { - String autoValueFieldName = ReflectUtils.stripGetterPrefix(type.getMethod().getName()); + String autoValueFieldName = + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + type.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()); FieldValueTypeInformation setterType = setterTypes.get(autoValueFieldName); if (setterType == null) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index c2b33c2d2315..5297eb113a97 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -20,7 +20,9 @@ import static org.apache.beam.sdk.util.ByteBuddyUtils.getClassLoadingStrategy; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import java.io.Serializable; import java.lang.reflect.Constructor; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; @@ -33,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.SortedMap; import net.bytebuddy.ByteBuddy; @@ -41,6 +44,7 @@ import net.bytebuddy.asm.AsmVisitorWrapper; import net.bytebuddy.description.method.MethodDescription.ForLoadedConstructor; import net.bytebuddy.description.method.MethodDescription.ForLoadedMethod; +import net.bytebuddy.description.type.PackageDescription; import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.description.type.TypeDescription.ForLoadedType; import 
net.bytebuddy.dynamic.DynamicType; @@ -77,6 +81,8 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeParameter; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -84,6 +90,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Primitives; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ClassUtils; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTimeZone; import org.joda.time.Instant; @@ -94,8 +101,6 @@ @Internal @SuppressWarnings({ "keyfor", - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" }) public class ByteBuddyUtils { private static final ForLoadedType ARRAYS_TYPE = new ForLoadedType(Arrays.class); @@ -146,7 +151,11 @@ protected String name(TypeDescription superClass) { // If the target class is in a prohibited package (java.*) then leave the original package // alone. String realPackage = - overridePackage(targetPackage) ? targetPackage : superClass.getPackage().getName(); + overridePackage(targetPackage) + ? targetPackage + : Optional.ofNullable(superClass.getPackage()) + .map(PackageDescription::getName) + .orElse(""); return realPackage + className + "$" + SUFFIX + "$" + randomString.nextString(); } @@ -201,25 +210,27 @@ static class ShortCircuitReturnNull extends IfNullElse { // Create a new FieldValueGetter subclass. 
@SuppressWarnings("unchecked") - public static DynamicType.Builder subclassGetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassGetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic getterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueGetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder>) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType); } // Create a new FieldValueSetter subclass. @SuppressWarnings("unchecked") - public static DynamicType.Builder subclassSetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassSetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic setterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueSetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType); } @@ -251,9 +262,11 @@ public TypeConversion createSetterConversions(StackManipulati // Base class used below to convert types. @SuppressWarnings("unchecked") public abstract static class TypeConversion { - public T convert(TypeDescriptor typeDescriptor) { + public T convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.isArray() - && !typeDescriptor.getComponentType().getRawType().equals(byte.class)) { + && !Preconditions.checkNotNull(typeDescriptor.getComponentType()) + .getRawType() + .equals(byte.class)) { // Byte arrays are special, so leave those alone. 
return convertArray(typeDescriptor); } else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { @@ -338,25 +351,32 @@ protected ConvertType(boolean returnRawTypes) { @Override protected Type convertArray(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(type.getComponentType()); + TypeDescriptor ret = + createCollectionType(Preconditions.checkNotNull(type.getComponentType())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertCollection(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertList(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createIterableType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? 
ret.getRawType() : ret.getType(); } @@ -398,8 +418,9 @@ protected Type convertDefault(TypeDescriptor type) { @SuppressWarnings("unchecked") private TypeDescriptor> createCollectionType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -407,8 +428,9 @@ private TypeDescriptor> createCollectionType( @SuppressWarnings("unchecked") private TypeDescriptor> createIterableType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -420,7 +442,7 @@ private TypeDescriptor> createIterableType( // This function // generates a subclass of Function that can be used to recursively transform each element of the // container. - static Class createCollectionTransformFunction( + static Class createCollectionTransformFunction( Type fromType, Type toType, Function convertElement) { // Generate a TypeDescription for the class we want to generate. 
TypeDescription.Generic functionGenericType = @@ -428,8 +450,8 @@ static Class createCollectionTransformFunction( Function.class, Primitives.wrap((Class) fromType), Primitives.wrap((Class) toType)) .build(); - DynamicType.Builder builder = - (DynamicType.Builder) + DynamicType.Builder> builder = + (DynamicType.Builder) BYTE_BUDDY .with(new InjectPackageStrategy((Class) fromType)) .subclass(functionGenericType) @@ -463,9 +485,11 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) { .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), + ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), getClassLoadingStrategy( - ((Class) fromType).getClassLoader() == null ? Function.class : (Class) fromType)) + ((Class) fromType).getClassLoader() == null + ? Function.class + : (Class) fromType)) .getLoaded(); } @@ -508,7 +532,7 @@ public static TransformingMap getTransformingMa return new TransformingMap<>(sourceMap, keyFunction, valueFunction); } - public static class TransformingMap implements Map { + public static class TransformingMap implements Map, Serializable { private final Map delegateMap; public TransformingMap( @@ -547,17 +571,17 @@ public boolean containsValue(Object value) { } @Override - public V2 get(Object key) { + public @Nullable V2 get(Object key) { return delegateMap.get(key); } @Override - public V2 put(K2 key, V2 value) { + public @Nullable V2 put(K2 key, V2 value) { return delegateMap.put(key, value); } @Override - public V2 remove(Object key) { + public @Nullable V2 remove(Object key) { return delegateMap.remove(key); } @@ -635,12 +659,12 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isComponentTypePrimitive ? 
Arrays.asList(ArrayUtils.toObject(value)) // : Arrays.asList(value); - TypeDescriptor componentType = type.getComponentType(); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); ForLoadedType loadedArrayType = new ForLoadedType(type.getRawType()); StackManipulation readArrayValue = readValue; // Row always expects to get an Iterable back for array types. Wrap this array into a // List using Arrays.asList before returning. - if (loadedArrayType.getComponentType().isPrimitive()) { + if (Preconditions.checkNotNull(loadedArrayType.getComponentType()).isPrimitive()) { // Arrays.asList doesn't take primitive arrays, so convert first using ArrayUtils.toObject. readArrayValue = new Compound( @@ -668,7 +692,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Generate a SerializableFunction to convert the element-type objects. StackManipulation stackManipulation; - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); @@ -687,10 +711,11 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -707,9 +732,10 @@ protected 
StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -726,9 +752,10 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -745,8 +772,8 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); - final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType); Type convertedValueType = 
getFactory().createTypeConversion(true).convert(valueType); @@ -970,16 +997,18 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isPrimitive ? toArray : ArrayUtils.toPrimitive(toArray); ForLoadedType loadedType = new ForLoadedType(type.getRawType()); + TypeDescription loadedTypeComponentType = Verify.verifyNotNull(loadedType.getComponentType()); + // The type of the array containing the (possibly) boxed values. TypeDescription arrayType = - TypeDescription.Generic.Builder.rawType(loadedType.getComponentType().asBoxed()) + TypeDescription.Generic.Builder.rawType(loadedTypeComponentType.asBoxed()) .asArray() .build() .asErasure(); - Type rowElementType = - getFactory().createTypeConversion(false).convert(type.getComponentType()); - final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(type.getComponentType()); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); + Type rowElementType = getFactory().createTypeConversion(false).convert(componentType); + final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(componentType); StackManipulation readTransformedValue = readValue; if (!arrayElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -999,7 +1028,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Call Collection.toArray(T[[]) to extract the array. Push new T[0] on the stack // before // calling toArray. - ArrayFactory.forType(loadedType.getComponentType().asBoxed().asGenericType()) + ArrayFactory.forType(loadedTypeComponentType.asBoxed().asGenericType()) .withValues(Collections.emptyList()), MethodInvocation.invoke( COLLECTION_TYPE @@ -1016,7 +1045,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Cast the result to T[]. TypeCasting.to(arrayType)); - if (loadedType.getComponentType().isPrimitive()) { + if (loadedTypeComponentType.isPrimitive()) { // The array we extract will be an array of objects. 
If the pojo field is an array of // primitive types, we need to then convert to an array of unboxed objects. stackManipulation = @@ -1035,11 +1064,9 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor iterableElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(iterableElementType); if (!iterableElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = new ForLoadedType( @@ -1057,11 +1084,9 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor collectionElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(collectionElementType); if (!collectionElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -1080,11 +1105,9 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor collectionElementType = + 
Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(collectionElementType); StackManipulation readTrasformedValue = readValue; if (!collectionElementType.hasUnresolvedParameters()) { @@ -1112,12 +1135,12 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - Type rowKeyType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0)); - final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0); - Type rowValueType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1)); - final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 0)); + final TypeDescriptor valueElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 1)); + Type rowKeyType = getFactory().createTypeConversion(false).convert(keyElementType); + Type rowValueType = getFactory().createTypeConversion(false).convert(valueElementType); StackManipulation readTrasformedValue = readValue; if (!keyElementType.hasUnresolvedParameters() @@ -1332,12 +1355,12 @@ protected StackManipulation convertDefault(TypeDescriptor type) { * constructor. 
*/ static class ConstructorCreateInstruction extends InvokeUserCreateInstruction { - private final Constructor constructor; + private final Constructor constructor; ConstructorCreateInstruction( List fields, - Class targetClass, - Constructor constructor, + Class targetClass, + Constructor constructor, TypeConversionsFactory typeConversionsFactory) { super( fields, @@ -1375,7 +1398,7 @@ static class StaticFactoryMethodInstruction extends InvokeUserCreateInstruction StaticFactoryMethodInstruction( List fields, - Class targetClass, + Class targetClass, Method creator, TypeConversionsFactory typeConversionsFactory) { super( @@ -1399,14 +1422,14 @@ protected StackManipulation afterPushingParameters() { static class InvokeUserCreateInstruction implements Implementation { protected final List fields; - protected final Class targetClass; + protected final Class targetClass; protected final List parameters; protected final Map fieldMapping; private final TypeConversionsFactory typeConversionsFactory; protected InvokeUserCreateInstruction( List fields, - Class targetClass, + Class targetClass, List parameters, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; @@ -1424,11 +1447,15 @@ protected InvokeUserCreateInstruction( // actual Java field or method names. 
FieldValueTypeInformation fieldValue = checkNotNull(fields.get(i)); fieldsByLogicalName.put(fieldValue.getName(), i); - if (fieldValue.getField() != null) { - fieldsByJavaClassMember.put(fieldValue.getField().getName(), i); - } else if (fieldValue.getMethod() != null) { - String name = ReflectUtils.stripGetterPrefix(fieldValue.getMethod().getName()); - fieldsByJavaClassMember.put(name, i); + Field field = fieldValue.getField(); + if (field != null) { + fieldsByJavaClassMember.put(field.getName(), i); + } else { + Method method = fieldValue.getMethod(); + if (method != null) { + String name = ReflectUtils.stripGetterPrefix(method.getName()); + fieldsByJavaClassMember.put(name, i); + } } } @@ -1482,7 +1509,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { StackManipulation readParameter = new StackManipulation.Compound( MethodVariableAccess.REFERENCE.loadFrom(1), - IntegerConstant.forValue(fieldMapping.get(i)), + IntegerConstant.forValue(Preconditions.checkNotNull(fieldMapping.get(i))), ArrayAccess.REFERENCE.load(), TypeCasting.to(convertedType)); stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 911f79f6eeed..ee4868ddb2b6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -22,9 +22,11 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -54,14 +56,22 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; 
import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; /** A set of utilities to generate getter and setter classes for JavaBean objects. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class JavaBeanUtils { + + private static final String X_WITH_NULL_METHOD_ERROR_FMT = + "a %s FieldValueTypeInformation object has a null method field"; + public static final String GETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "getter"); + public static final String SETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "setter"); + /** Create a {@link Schema} for a Java Bean class. */ public static Schema schemaFromJavaBeanClass( TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -69,7 +79,9 @@ public static Schema schemaFromJavaBeanClass( } private static final String CONSTRUCTOR_HELP_STRING = - "In order to infer a Schema from a Java Bean, it must have a constructor annotated with @SchemaCreate, or it must have a compatible setter for every getter used as a Schema field."; + "In order to infer a Schema from a Java Bean, it must have a constructor annotated with" + + " @SchemaCreate, or it must have a compatible setter for every getter used as a Schema" + + " field."; // Make sure that there are matching setters and getters. 
public static void validateJavaBean( @@ -88,23 +100,26 @@ public static void validateJavaBean( for (FieldValueTypeInformation type : getters) { FieldValueTypeInformation setterType = setterMap.get(type.getName()); + Method m = Preconditions.checkNotNull(type.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); if (setterType == null) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a getter for field '%s', but does not contain a matching setter. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a getter for field '%s', but does not contain a matching" + + " setter. %s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.getType().equals(setterType.getType())) { throw new RuntimeException( String.format( "Java Bean '%s' contains a setter for field '%s' that has a mismatching type. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.isNullable() == setterType.isNullable()) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable attribute. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable" + + " attribute. %s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } } } @@ -126,36 +141,41 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map, List> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); /** * Return the list of {@link FieldValueGetter}s for a Java Bean class * *

The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { - return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> JavaBeanUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } - public static FieldValueGetter createGetter( - FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + public static + FieldValueGetter createGetter( + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementGetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -163,9 +183,8 @@ public static FieldValueGetter createGetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - 
getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -178,10 +197,11 @@ public static FieldValueGetter createGetter( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation typeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(typeInformation.getName())) @@ -215,12 +235,14 @@ public static List getSetters( }); } - public static FieldValueSetter createSetter( + public static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementSetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -228,9 +250,8 @@ public static FieldValueSetter createSetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + 
getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -243,10 +264,11 @@ public static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation fieldValueTypeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation fieldValueTypeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(fieldValueTypeInformation.getName())) @@ -358,6 +380,11 @@ public static SchemaUserTypeCreator createStaticCreator( } } + public static > Comparator comparingNullFirst( + Function keyExtractor) { + return Comparator.comparing(keyExtractor, Comparator.nullsFirst(Comparator.naturalOrder())); + } + // Implements a method to read a public getter out of an object. private static class InvokeGetterInstruction implements Implementation { private final FieldValueTypeInformation typeInformation; @@ -386,7 +413,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Method param is offset 1 (offset 0 is the this parameter). MethodVariableAccess.REFERENCE.loadFrom(1), // Invoke the getter - MethodInvocation.invoke(new ForLoadedMethod(typeInformation.getMethod()))); + MethodInvocation.invoke( + new ForLoadedMethod( + Preconditions.checkNotNull( + typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR)))); StackManipulation stackManipulation = new StackManipulation.Compound( @@ -428,7 +458,9 @@ public ByteCodeAppender appender(final Target implementationTarget) { // The instruction to read the field. 
StackManipulation readField = MethodVariableAccess.REFERENCE.loadFrom(2); - Method method = fieldValueTypeInformation.getMethod(); + Method method = + Preconditions.checkNotNull( + fieldValueTypeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); boolean setterMethodReturnsVoid = method.getReturnType().equals(Void.TYPE); // Read the object onto the stack. StackManipulation stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 571b9c690900..8e33d321a1c6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -62,8 +62,9 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.NonNull; /** A set of utilities to generate getter and setter classes for POJOs. */ @SuppressWarnings({ @@ -94,38 +95,40 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the getters ordered by their position in the schema. 
- return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - List getters = - types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - if (getters.size() != schema.getFieldCount()) { - throw new RuntimeException( - "Was not able to generate getters for schema: " - + schema - + " class: " - + typeDescriptor); - } - return getters; - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + List> getters = + types.stream() + .>map( + t -> POJOUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + if (getters.size() != schema.getFieldCount()) { + throw new RuntimeException( + "Was not able to generate getters for schema: " + + schema + + " class: " + + typeDescriptor); + } + return (List) getters; + }); } // The list of constructors for a class is cached, so we only create the classes the first time // getConstructor is called. - public static final Map CACHED_CREATORS = + public static final Map, SchemaUserTypeCreator> CACHED_CREATORS = Maps.newConcurrentMap(); public static SchemaUserTypeCreator getSetFieldCreator( @@ -150,7 +153,9 @@ private static SchemaUserTypeCreator createSetFieldCreator( TypeConversionsFactory typeConversionsFactory) { // Get the list of class fields ordered by schema. 
List fields = - types.stream().map(FieldValueTypeInformation::getField).collect(Collectors.toList()); + types.stream() + .map(type -> Preconditions.checkNotNull(type.getField())) + .collect(Collectors.toList()); try { DynamicType.Builder builder = BYTE_BUDDY @@ -175,14 +180,16 @@ private static SchemaUserTypeCreator createSetFieldCreator( | InvocationTargetException e) { throw new RuntimeException( String.format( - "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must have a zero-argument constructor, or a constructor annotated with @SchemaCreate.", + "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must" + + " have a zero-argument constructor, or a constructor annotated with" + + " @SchemaCreate.", clazz, schema)); } } - public static SchemaUserTypeCreator getConstructorCreator( - TypeDescriptor typeDescriptor, - Constructor constructor, + public static SchemaUserTypeCreator getConstructorCreator( + TypeDescriptor typeDescriptor, + Constructor constructor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { @@ -191,13 +198,13 @@ public static SchemaUserTypeCreator getConstructorCreator( c -> { List types = fieldValueTypeSupplier.get(typeDescriptor, schema); - return createConstructorCreator( + return POJOUtils.createConstructorCreator( typeDescriptor.getRawType(), constructor, schema, types, typeConversionsFactory); }); } public static SchemaUserTypeCreator createConstructorCreator( - Class clazz, + Class clazz, Constructor constructor, Schema schema, List types, @@ -291,11 +298,10 @@ public static SchemaUserTypeCreator createStaticCreator( * } * */ - @SuppressWarnings("unchecked") - static @Nullable FieldValueGetter createGetter( + static FieldValueGetter<@NonNull ObjectT, ValueT> createGetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - 
DynamicType.Builder builder = + Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, field.getDeclaringClass(), @@ -322,11 +328,12 @@ public static SchemaUserTypeCreator createStaticCreator( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - Field field, - String name, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + Field field, + String name, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) @@ -337,24 +344,25 @@ private static DynamicType.Builder implementGetterMethods( // The list of setters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_SETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_SETTERS = Maps.newConcurrentMap(); - public static List getSetters( - TypeDescriptor typeDescriptor, + public static List> getSetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the setters, ordered by their position in the schema. 
- return CACHED_SETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createSetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_SETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> createSetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } /** @@ -376,8 +384,8 @@ public static List getSetters( @SuppressWarnings("unchecked") private static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - DynamicType.Builder builder = + Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, field.getDeclaringClass(), @@ -403,10 +411,11 @@ private static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - Field field, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + Field field, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) @@ -505,11 +514,11 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Implements a method to construct an object. 
static class SetFieldCreateInstruction implements Implementation { private final List fields; - private final Class pojoClass; + private final Class pojoClass; private final TypeConversionsFactory typeConversionsFactory; SetFieldCreateInstruction( - List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { + List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; this.pojoClass = pojoClass; this.typeConversionsFactory = typeConversionsFactory; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index 4349a04c28ad..423fea4c3845 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -32,7 +32,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; @@ -88,14 +87,23 @@ public static List getMethods(Class clazz) { return DECLARED_METHODS.computeIfAbsent( clazz, c -> { - return Arrays.stream(c.getDeclaredMethods()) - .filter( - m -> !m.isBridge()) // Covariant overloads insert bridge functions, which we must - // ignore. - .filter(m -> !Modifier.isPrivate(m.getModifiers())) - .filter(m -> !Modifier.isProtected(m.getModifiers())) - .filter(m -> !Modifier.isStatic(m.getModifiers())) - .collect(Collectors.toList()); + List methods = Lists.newArrayList(); + do { + if (c.getPackage() != null && c.getPackage().getName().startsWith("java.")) { + break; // skip java built-in classes + } + Arrays.stream(c.getDeclaredMethods()) + .filter( + m -> + !m.isBridge()) // Covariant overloads insert bridge functions, which we must + // ignore. 
+ .filter(m -> !Modifier.isPrivate(m.getModifiers())) + .filter(m -> !Modifier.isProtected(m.getModifiers())) + .filter(m -> !Modifier.isStatic(m.getModifiers())) + .forEach(methods::add); + c = c.getSuperclass(); + } while (c != null); + return methods; }); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java index ed61f7f3d6f2..328bf19c466c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java @@ -373,7 +373,7 @@ public PipelineResult runWithAdditionalOptionArgs(List additionalArgs) { } newOptions.setStableUniqueNames(CheckEnabled.ERROR); - FileSystems.setDefaultPipelineOptions(options); + FileSystems.registerFileSystemsOnce(options); return run(newOptions); } catch (IOException e) { throw new RuntimeException( @@ -515,7 +515,7 @@ public static PipelineOptions testingPipelineOptions() { } options.setStableUniqueNames(CheckEnabled.ERROR); - FileSystems.setDefaultPipelineOptions(options); + FileSystems.registerFileSystemsOnce(options); return options; } catch (IOException e) { throw new RuntimeException( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index 8da3cf71af9f..159f92cd5e87 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableLikeCoder; import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; @@ -82,6 +83,81 @@ 
public static Iterables iterables() { return new Iterables<>(); } + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with a given {@link + * PCollection} resulting in a {@link PCollection} containing all the elements of both {@link + * PCollection}s as its output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and + * {@code other} and then applying {@link #pCollections()}, but has the advantage that it can be + * more easily used inline. + * + *

Both {@code PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param other the other PCollection to flatten with the input + * @param the type of the elements in the input and output {@code PCollection}s. + */ + public static PTransform, PCollection> with(PCollection other) { + return new FlattenWithPCollection<>(other); + } + + /** Implementation of {@link #with(PCollection)}. */ + private static class FlattenWithPCollection + extends PTransform, PCollection> { + // We only need to access this at pipeline construction time. + private final transient PCollection other; + + public FlattenWithPCollection(PCollection other) { + this.other = other; + } + + @Override + public PCollection expand(PCollection input) { + return PCollectionList.of(input).and(other).apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + } + + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with the output of + * another {@link PTransform} resulting in a {@link PCollection} containing all the elements of + * both the input {@link PCollection}s and the output of the given {@link PTransform} as its + * output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and the + * output of {@code other} and then applying {@link #pCollections()}, but has the advantage that + * it can be more easily used inline. + * + *

Both {@code PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param the type of the elements in the input and output {@code PCollection}s. + * @param other a PTransform whose output should be flattened with the input + */ + public static PTransform, PCollection> with( + PTransform> other) { + return new PTransform, PCollection>() { + @Override + public PCollection expand(PCollection input) { + return PCollectionList.of(input) + .and(input.getPipeline().apply(other)) + .apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + }; + } + /** * A {@link PTransform} that flattens a {@link PCollectionList} into a {@link PCollection} * containing all the elements of all the {@link PCollection}s in its input. Implements {@link diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java new file mode 100644 index 000000000000..492a1cc84f74 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.util.function.Consumer; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; + +/** + * A PTransform that returns its input, but also applies its input to an auxiliary PTransform, akin + * to the shell {@code tee} command, which is named after the T-splitter used in plumbing. + * + *

This can be useful to write out or otherwise process an intermediate transform without + * breaking the linear flow of a chain of transforms, e.g. + * + *


+ * {@literal PCollection} input = ... ;
+ * {@literal PCollection} result =
+ *     {@literal input.apply(...)}
+ *     ...
+ *     {@literal input.apply(Tee.of(someSideTransform))}
+ *     ...
+ *     {@literal input.apply(...)};
+ * 
+ * + * @param the element type of the input PCollection + */ +public class Tee extends PTransform, PCollection> { + private final PTransform, ?> consumer; + + /** + * Returns a new Tee PTransform that will apply an auxilary transform to the input as well as pass + * it on. + * + * @param consumer An additional PTransform that should process the input PCollection. Its output + * will be ignored. + * @param the type of the elements in the input {@code PCollection}. + */ + public static Tee of(PTransform, ?> consumer) { + return new Tee<>(consumer); + } + + /** + * Returns a new Tee PTransform that will apply an auxilary transform to the input as well as pass + * it on. + * + * @param consumer An arbitrary {@link Consumer} that will be wrapped in a PTransform and applied + * to the input. Its output will be ignored. + * @param the type of the elements in the input {@code PCollection}. + */ + public static Tee of(Consumer> consumer) { + return of( + new PTransform, PCollectionTuple>() { + @Override + public PCollectionTuple expand(PCollection input) { + consumer.accept(input); + return PCollectionTuple.empty(input.getPipeline()); + } + }); + } + + private Tee(PTransform, ?> consumer) { + this.consumer = consumer; + } + + @Override + public PCollection expand(PCollection input) { + input.apply(consumer); + return input; + } + + @Override + protected String getKindString() { + return "Tee(" + consumer.getName() + ")"; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/HistogramData.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/HistogramData.java index 65ccda06be65..c1ac4bcfba23 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/HistogramData.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/HistogramData.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.util; +import com.google.api.services.dataflow.model.DataflowHistogramValue; import com.google.auto.value.AutoValue; import 
com.google.auto.value.extension.memoized.Memoized; import java.io.Serializable; @@ -24,6 +25,8 @@ import java.util.Arrays; import java.util.Objects; import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.IntMath; import org.checkerframework.checker.nullness.qual.Nullable; @@ -74,6 +77,41 @@ public HistogramData(BucketType bucketType) { this.sumOfSquaredDeviations = 0; } + /** + * Create a histogram from DataflowHistogramValue proto. + * + * @param histogramProto DataflowHistogramValue proto used to populate stats for the histogram. + */ + public HistogramData(DataflowHistogramValue histogramProto) { + int numBuckets; + if (histogramProto.getBucketOptions().getLinear() != null) { + double start = histogramProto.getBucketOptions().getLinear().getStart(); + double width = histogramProto.getBucketOptions().getLinear().getWidth(); + numBuckets = histogramProto.getBucketOptions().getLinear().getNumberOfBuckets(); + this.bucketType = LinearBuckets.of(start, width, numBuckets); + this.buckets = new long[bucketType.getNumBuckets()]; + + int idx = 0; + for (long val : histogramProto.getBucketCounts()) { + this.buckets[idx] = val; + this.numBoundedBucketRecords += val; + idx++; + } + } else { + // Assume it's a exponential histogram if its not linear + int scale = histogramProto.getBucketOptions().getExponential().getScale(); + numBuckets = histogramProto.getBucketOptions().getExponential().getNumberOfBuckets(); + this.bucketType = ExponentialBuckets.of(scale, numBuckets); + this.buckets = new long[bucketType.getNumBuckets()]; + int idx = 0; + for (long val : histogramProto.getBucketCounts()) { + this.buckets[idx] = val; + this.numBoundedBucketRecords += val; + idx++; + } + } + } + public 
BucketType getBucketType() { return this.bucketType; } @@ -293,6 +331,10 @@ public synchronized long getTopBucketCount() { return numTopRecords; } + public synchronized long[] getBucketCount() { + return buckets; + } + public synchronized double getTopBucketMean() { return numTopRecords == 0 ? 0 : topRecordsSum / numTopRecords; } @@ -573,6 +615,42 @@ public double getRangeTo() { // Note: equals() and hashCode() are implemented by the AutoValue. } + /** Used for testing unsupported Bucket formats. */ + @AutoValue + @Internal + @VisibleForTesting + public abstract static class UnsupportedBuckets implements BucketType { + + public static UnsupportedBuckets of() { + return new AutoValue_HistogramData_UnsupportedBuckets(0); + } + + @Override + public int getBucketIndex(double value) { + return 0; + } + + @Override + public double getBucketSize(int index) { + return 0; + } + + @Override + public double getAccumulatedBucketSize(int index) { + return 0; + } + + @Override + public double getRangeFrom() { + return 0; + } + + @Override + public double getRangeTo() { + return 0; + } + } + @Override public synchronized boolean equals(@Nullable Object object) { if (object instanceof HistogramData) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java index cd38da100a79..0999f2ad0771 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java @@ -45,9 +45,6 @@ *
  • Return {@link CompletableFuture} only to the producer of a future value. * */ -@SuppressWarnings({ - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) public class MoreFutures { /** @@ -99,22 +96,18 @@ public static boolean isCancelled(CompletionStage future) { */ public static CompletionStage supplyAsync( ThrowingSupplier supplier, ExecutorService executorService) { - CompletableFuture result = new CompletableFuture<>(); - - CompletionStage wrapper = - CompletableFuture.runAsync( - () -> { - try { - result.complete(supplier.get()); - } catch (InterruptedException e) { - result.completeExceptionally(e); - Thread.currentThread().interrupt(); - } catch (Throwable t) { - result.completeExceptionally(t); - } - }, - executorService); - return wrapper.thenCompose(nothing -> result); + return CompletableFuture.supplyAsync( + () -> { + try { + return supplier.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CompletionException(e); + } catch (Throwable t) { + throw new CompletionException(t); + } + }, + executorService); } /** @@ -132,23 +125,18 @@ public static CompletionStage supplyAsync(ThrowingSupplier supplier) { */ public static CompletionStage runAsync( ThrowingRunnable runnable, ExecutorService executorService) { - CompletableFuture result = new CompletableFuture<>(); - - CompletionStage wrapper = - CompletableFuture.runAsync( - () -> { - try { - runnable.run(); - result.complete(null); - } catch (InterruptedException e) { - result.completeExceptionally(e); - Thread.currentThread().interrupt(); - } catch (Throwable t) { - result.completeExceptionally(t); - } - }, - executorService); - return wrapper.thenCompose(nothing -> result); + return CompletableFuture.runAsync( + () -> { + try { + runnable.run(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CompletionException(e); + } catch (Throwable t) { + throw new CompletionException(t); + } + }, + executorService); 
} /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java index aeb76492bb6d..c2d945bbaac1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java @@ -44,6 +44,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSortedSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; +import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for working with with {@link Class Classes} and {@link Method Methods}. */ @SuppressWarnings({"nullness", "keyfor"}) // TODO(https://github.com/apache/beam/issues/20497) @@ -216,7 +217,7 @@ public static Iterable loadServicesOrdered(Class iface) { * which by default would use the proposed {@code ClassLoader}, which can be null. The fallback is * as follows: context ClassLoader, class ClassLoader and finally the system ClassLoader. 
*/ - public static ClassLoader findClassLoader(final ClassLoader proposed) { + public static ClassLoader findClassLoader(@Nullable final ClassLoader proposed) { ClassLoader classLoader = proposed; if (classLoader == null) { classLoader = ReflectHelpers.class.getClassLoader(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslation.java index 499c7fd21f51..7129854d44cc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslation.java @@ -83,7 +83,7 @@ private static GroupIntoBatchesPayload getPayloadFromParameters( return RunnerApi.GroupIntoBatchesPayload.newBuilder() .setBatchSize(params.getBatchSize()) .setBatchSizeBytes(params.getBatchSizeBytes()) - .setMaxBufferingDurationMillis(params.getMaxBufferingDuration().getStandardSeconds() * 1000) + .setMaxBufferingDurationMillis(params.getMaxBufferingDuration().getMillis()) .build(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java index cd6ab7dd414a..de1717f0a45f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java @@ -43,6 +43,9 @@ public class PipelineOptionsTranslation { new ObjectMapper() .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader())); + public static final String PIPELINE_OPTIONS_URN_PREFIX = "beam:option:"; + public static final String PIPELINE_OPTIONS_URN_SUFFIX = ":v1"; + /** Converts the provided {@link PipelineOptions} to a {@link Struct}. 
*/ public static Struct toProto(PipelineOptions options) { Struct.Builder builder = Struct.newBuilder(); @@ -65,9 +68,9 @@ public static Struct toProto(PipelineOptions options) { while (optionsEntries.hasNext()) { Map.Entry entry = optionsEntries.next(); optionsUsingUrns.put( - "beam:option:" + PIPELINE_OPTIONS_URN_PREFIX + CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, entry.getKey()) - + ":v1", + + PIPELINE_OPTIONS_URN_SUFFIX, entry.getValue()); } @@ -92,7 +95,9 @@ public static PipelineOptions fromProto(Struct protoOptions) { mapWithoutUrns.put( CaseFormat.LOWER_UNDERSCORE.to( CaseFormat.LOWER_CAMEL, - optionKey.substring("beam:option:".length(), optionKey.length() - ":v1".length())), + optionKey.substring( + PIPELINE_OPTIONS_URN_PREFIX.length(), + optionKey.length() - PIPELINE_OPTIONS_URN_SUFFIX.length())), optionValue); } return MAPPER.readValue( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index ee3852d70bbe..591a83600561 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; import org.apache.beam.sdk.values.RowUtils.RowPosition; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.ReadableDateTime; @@ -771,6 +772,7 @@ public FieldValueBuilder withFieldValue( checkState(values.isEmpty()); return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); } + /** * Sets field values using the field names. Nested values can be set using the field selection * syntax. 
@@ -836,10 +838,10 @@ public int nextFieldId() { } @Internal - public Row withFieldValueGetters( - Factory> fieldValueGetterFactory, Object getterTarget) { + public <@NonNull T> Row withFieldValueGetters( + Factory>> fieldValueGetterFactory, T getterTarget) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget); } public Row build() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index 9731507fb0f6..35e0ac20d3f7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -42,13 +42,13 @@ * the appropriate fields from the POJO. */ @SuppressWarnings("rawtypes") -public class RowWithGetters extends Row { - private final Object getterTarget; - private final List getters; +public class RowWithGetters extends Row { + private final T getterTarget; + private final List> getters; private @Nullable Map cache = null; RowWithGetters( - Schema schema, Factory> getterFactory, Object getterTarget) { + Schema schema, Factory>> getterFactory, T getterTarget) { super(schema); this.getterTarget = getterTarget; this.getters = getterFactory.create(TypeDescriptor.of(getterTarget.getClass()), schema); @@ -56,7 +56,7 @@ public class RowWithGetters extends Row { @Override @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) - public @Nullable T getValue(int fieldIdx) { + public W getValue(int fieldIdx) { Field field = getSchema().getField(fieldIdx); boolean cacheField = cacheFieldType(field); @@ -64,7 +64,7 @@ public class RowWithGetters extends Row { cache = new TreeMap<>(); } - Object fieldValue; + @Nullable Object fieldValue; if (cacheField) { if (cache == null) { cache = new 
TreeMap<>(); @@ -72,15 +72,12 @@ public class RowWithGetters extends Row { fieldValue = cache.computeIfAbsent( fieldIdx, - new Function() { + new Function() { @Override - public Object apply(Integer idx) { - FieldValueGetter getter = getters.get(idx); + public @Nullable Object apply(Integer idx) { + FieldValueGetter getter = getters.get(idx); checkStateNotNull(getter); - @SuppressWarnings("nullness") - @NonNull - Object value = getter.get(getterTarget); - return value; + return getter.get(getterTarget); } }); } else { @@ -90,7 +87,7 @@ public Object apply(Integer idx) { if (fieldValue == null && !field.getType().getNullable()) { throw new RuntimeException("Null value set on non-nullable field " + field); } - return (T) fieldValue; + return (W) fieldValue; } private boolean cacheFieldType(Field field) { @@ -116,7 +113,7 @@ public int getFieldCount() { return rawValues; } - public List getGetters() { + public List> getGetters() { return getters; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java index 3a7a0d5a8935..37580824b558 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java @@ -280,6 +280,7 @@ public void testFailedProcessingCausesAdditionalInboundDataToBeIgnored() throws DESCRIPTOR, OutboundObserverFactory.clientDirect(), inboundObserver -> TestStreams.withOnNext(outboundValues::add).build()); + final AtomicBoolean closed = new AtomicBoolean(); multiplexer.registerConsumer( DATA_INSTRUCTION_ID, new CloseableFnDataReceiver() { @@ -290,7 +291,7 @@ public void flush() throws Exception { @Override public void close() throws Exception { - fail("Unexpected call"); + closed.set(true); } @Override @@ -320,6 +321,7 @@ public void accept(BeamFnApi.Elements input) throws 
Exception { dataInboundValues, Matchers.contains( BeamFnApi.Elements.newBuilder().addData(data.setTransformId("A").build()).build())); + assertTrue(closed.get()); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java index 750d43a4f9ae..662c4f52628a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java @@ -24,7 +24,9 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.hasItem; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,6 +41,7 @@ import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.UsesAttemptedMetrics; @@ -245,6 +248,24 @@ public void testCounterToCell() { counter.dec(5L); verify(mockCounter).inc(-5); } + + @Test + public void testMetricsFlag() { + Metrics.resetDefaultPipelineOptions(); + assertFalse(Metrics.MetricsFlag.counterDisabled()); + assertFalse(Metrics.MetricsFlag.stringSetDisabled()); + PipelineOptions options = + PipelineOptionsFactory.fromArgs("--experiments=disableCounterMetrics").create(); + Metrics.setDefaultPipelineOptions(options); + assertTrue(Metrics.MetricsFlag.counterDisabled()); + assertFalse(Metrics.MetricsFlag.stringSetDisabled()); + Metrics.resetDefaultPipelineOptions(); + options = PipelineOptionsFactory.fromArgs("--experiments=disableStringSetMetrics").create(); + 
Metrics.setDefaultPipelineOptions(options); + assertFalse(Metrics.MetricsFlag.counterDisabled()); + assertTrue(Metrics.MetricsFlag.stringSetDisabled()); + Metrics.resetDefaultPipelineOptions(); + } } /** Tests for committed metrics. */ diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java new file mode 100644 index 000000000000..26e3278df025 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.schemas; + +import static org.junit.Assert.assertEquals; + +import java.util.Map; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.junit.Test; + +public class FieldValueTypeInformationTest { + public static class GenericClass { + public T t; + + public GenericClass(T t) { + this.t = t; + } + + public T getT() { + return t; + } + + public void setT(T t) { + this.t = t; + } + } + + private final TypeDescriptor>> typeDescriptor = + new TypeDescriptor>>() {}; + private final TypeDescriptor> expectedFieldTypeDescriptor = + new TypeDescriptor>() {}; + + @Test + public void testForGetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forGetter( + typeDescriptor, GenericClass.class.getMethod("getT"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForField() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forField(typeDescriptor, GenericClass.class.getField("t"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForSetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forSetter( + typeDescriptor, GenericClass.class.getMethod("setT", Object.class)); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java index 021e39b84849..7e9cf9a894b9 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.PrimitiveMapBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBean; import 
org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.junit.Test; @@ -142,11 +143,11 @@ public void testGeneratedSimpleGetters() { simpleBean.setBigDecimal(new BigDecimal(42)); simpleBean.setStringBuilder(new StringBuilder("stringBuilder")); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, SIMPLE_BEAN_SCHEMA, - new JavaBeanSchema.GetterTypeSupplier(), + new GetterTypeSupplier(), new DefaultTypeConversionsFactory()); assertEquals(12, getters.size()); assertEquals("str", getters.get(0).name()); @@ -220,7 +221,7 @@ public void testGeneratedSimpleBoxedGetters() { bean.setaLong(44L); bean.setaBoolean(true); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, BEAN_WITH_BOXED_FIELDS_SCHEMA, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java index 723353ed8d15..378cdc06805f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java @@ -52,6 +52,7 @@ import org.apache.beam.sdk.schemas.utils.TestPOJOs.PrimitiveMapPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.SimplePOJO; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Test; @@ -158,7 +159,7 @@ public void testGeneratedSimpleGetters() { new BigDecimal(42), new StringBuilder("stringBuilder")); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -184,7 +185,7 @@ public void testGeneratedSimpleGetters() { @Test public void testGeneratedSimpleSetters() { SimplePOJO simplePojo = new SimplePOJO(); - List setters = + List> 
setters = POJOUtils.getSetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -223,7 +224,7 @@ public void testGeneratedSimpleSetters() { public void testGeneratedSimpleBoxedGetters() { POJOWithBoxedFields pojo = new POJOWithBoxedFields((byte) 41, (short) 42, 43, 44L, true); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -239,7 +240,7 @@ public void testGeneratedSimpleBoxedGetters() { @Test public void testGeneratedSimpleBoxedSetters() { POJOWithBoxedFields pojo = new POJOWithBoxedFields(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -262,7 +263,7 @@ public void testGeneratedSimpleBoxedSetters() { @Test public void testGeneratedByteBufferSetters() { POJOWithByteArray pojo = new POJOWithByteArray(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BYTE_ARRAY_SCHEMA, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java index 282a41bed0dc..7a02d95a5046 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java @@ -402,6 +402,32 @@ public void testFlattenWithDifferentInputAndOutputCoders2() { ///////////////////////////////////////////////////////////////////////////// + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPCollection() { + PCollection output = + p.apply(Create.of(LINES)) + .apply("FlattenWithLines1", Flatten.with(p.apply("Create1", Create.of(LINES)))) + .apply("FlattenWithLines2", Flatten.with(p.apply("Create2", Create.of(LINES2)))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPTransform() { + 
PCollection output = + p.apply(Create.of(LINES)) + .apply("Create1", Flatten.with(Create.of(LINES))) + .apply("Create2", Flatten.with(Create.of(LINES2))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + ///////////////////////////////////////////////////////////////////////////// + @Test @Category(NeedsRunner.class) public void testEqualWindowFnPropagation() { @@ -470,6 +496,7 @@ public void testIncompatibleWindowFnPropagationFailure() { public void testFlattenGetName() { Assert.assertEquals("Flatten.Iterables", Flatten.iterables().getName()); Assert.assertEquals("Flatten.PCollections", Flatten.pCollections().getName()); + Assert.assertEquals("Flatten.With", Flatten.with((PCollection) null).getName()); } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java new file mode 100644 index 000000000000..ee3a00c46caa --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.transforms; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimaps; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for Tee. */ +@RunWith(JUnit4.class) +public class TeeTest { + + @Rule public final transient TestPipeline p = TestPipeline.create(); + + @Test + @Category(NeedsRunner.class) + public void testTee() { + List elements = Arrays.asList("a", "b", "c"); + CollectToMemory collector = new CollectToMemory<>(); + PCollection output = p.apply(Create.of(elements)).apply(Tee.of(collector)); + + PAssert.that(output).containsInAnyOrder(elements); + p.run().waitUntilFinish(); + + // Here we assert that this "sink" had the correct side effects. 
+ assertThat(collector.get(), containsInAnyOrder(elements.toArray(new String[3]))); + } + + private static class CollectToMemory extends PTransform, PCollection> { + + private static final Multimap ALL_ELEMENTS = + Multimaps.synchronizedMultimap(HashMultimap.create()); + + UUID uuid = UUID.randomUUID(); + + @Override + public PCollection expand(PCollection input) { + return input.apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + ALL_ELEMENTS.put(uuid, c.element()); + } + })); + } + + @SuppressWarnings("unchecked") + public Collection get() { + return (Collection) ALL_ELEMENTS.get(uuid); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java index 296a53f48e80..fd178f8e7649 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Objects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -226,5 +227,19 @@ public long getNum() { public String getStr() { return this.str; } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Pojo)) { + return false; + } + Pojo pojo = (Pojo) o; + return num == pojo.num && Objects.equals(str, pojo.str); + } + + @Override + public int hashCode() { + return Objects.hash(num, str); + } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslationTest.java index cb2054e09144..65f571467ca4 100644 --- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslationTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/GroupIntoBatchesTranslationTest.java @@ -71,6 +71,7 @@ public static Iterable> transform() { ImmutableSet.of( // gib -> gib, // gib -> gib.withMaxBufferingDuration(Duration.ZERO), // + gib -> gib.withMaxBufferingDuration(Duration.millis(200)), // gib -> gib.withMaxBufferingDuration(Duration.standardSeconds(10))); return Sets.cartesianProduct( @@ -150,7 +151,7 @@ private void verifyPayload( assertThat(payload.getBatchSize(), equalTo(params.getBatchSize())); assertThat(payload.getBatchSizeBytes(), equalTo(params.getBatchSizeBytes())); assertThat( - payload.getMaxBufferingDurationMillis(), - equalTo(params.getMaxBufferingDuration().getStandardSeconds() * 1000)); + Duration.millis(payload.getMaxBufferingDurationMillis()), + equalTo(params.getMaxBufferingDuration())); } } diff --git a/sdks/java/expansion-service/build.gradle b/sdks/java/expansion-service/build.gradle index 4dd8c8968ed9..a25583870acf 100644 --- a/sdks/java/expansion-service/build.gradle +++ b/sdks/java/expansion-service/build.gradle @@ -57,3 +57,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } + +compileJava { + outputs.upToDateWhen { false } +} \ No newline at end of file diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index 770da14fa1cf..9c5b5a0ad136 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -60,7 +60,6 @@ import org.apache.beam.sdk.options.PipelineOptions; import 
org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.PortablePipelineOptions; -import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -143,7 +142,7 @@ public static class ExternalTransformRegistrarLoader public Map knownTransforms() { Map providers = new HashMap<>(); - // First check and register ExternalTransformBuilder in serviceloader style, converting + // First check and register ExternalTransformBuilder in ServiceLoader style, converting // to TransformProvider after validation. Map registeredBuilders = loadTransformBuilders(); for (Map.Entry registeredBuilder : @@ -535,7 +534,7 @@ private static void invokeSetter(ConfigT config, @Nullable Object valu } private @MonotonicNonNull Map registeredTransforms; - private final PipelineOptions pipelineOptions; + private final PipelineOptions commandLineOptions; private final @Nullable String loopbackAddress; public ExpansionService() { @@ -551,7 +550,7 @@ public ExpansionService(PipelineOptions opts) { } public ExpansionService(PipelineOptions opts, @Nullable String loopbackAddress) { - this.pipelineOptions = opts; + this.commandLineOptions = opts; this.loopbackAddress = loopbackAddress; } @@ -587,12 +586,15 @@ private Map loadRegisteredTransforms() { request.getTransform().getSpec().getUrn()); LOG.debug("Full transform: {}", request.getTransform()); Set existingTransformIds = request.getComponents().getTransformsMap().keySet(); - Pipeline pipeline = - createPipeline(PipelineOptionsTranslation.fromProto(request.getPipelineOptions())); + + PipelineOptions pipelineOptionsFromRequest = + PipelineOptionsTranslation.fromProto(request.getPipelineOptions()); + Pipeline pipeline = createPipeline(pipelineOptionsFromRequest); + boolean isUseDeprecatedRead = - ExperimentalOptions.hasExperiment(pipelineOptions, "use_deprecated_read") + 
ExperimentalOptions.hasExperiment(commandLineOptions, "use_deprecated_read") || ExperimentalOptions.hasExperiment( - pipelineOptions, "beam_fn_api_use_deprecated_read"); + commandLineOptions, "beam_fn_api_use_deprecated_read"); if (!isUseDeprecatedRead) { ExperimentalOptions.addExperiment( pipeline.getOptions().as(ExperimentalOptions.class), "beam_fn_api"); @@ -629,7 +631,7 @@ private Map loadRegisteredTransforms() { if (transformProvider == null) { if (getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP).equals(urn)) { AllowList allowList = - pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); + commandLineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); assert allowList != null; transformProvider = new JavaClassLookupTransformProvider(allowList); } else if (getUrn(SCHEMA_TRANSFORM).equals(urn)) { @@ -671,7 +673,7 @@ private Map loadRegisteredTransforms() { RunnerApi.Environment defaultEnvironment = Environments.createOrGetDefaultEnvironment( pipeline.getOptions().as(PortablePipelineOptions.class)); - if (pipelineOptions.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { + if (commandLineOptions.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { PortablePipelineOptions externalOptions = PipelineOptionsFactory.create().as(PortablePipelineOptions.class); externalOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_EXTERNAL); @@ -723,35 +725,34 @@ private Map loadRegisteredTransforms() { } protected Pipeline createPipeline(PipelineOptions requestOptions) { - // TODO: [https://github.com/apache/beam/issues/21064]: implement proper validation - PipelineOptions effectiveOpts = PipelineOptionsFactory.create(); - PortablePipelineOptions portableOptions = effectiveOpts.as(PortablePipelineOptions.class); - PortablePipelineOptions specifiedOptions = pipelineOptions.as(PortablePipelineOptions.class); - Optional.ofNullable(specifiedOptions.getDefaultEnvironmentType()) - 
.ifPresent(portableOptions::setDefaultEnvironmentType); - Optional.ofNullable(specifiedOptions.getDefaultEnvironmentConfig()) - .ifPresent(portableOptions::setDefaultEnvironmentConfig); - List filesToStage = specifiedOptions.getFilesToStage(); + // We expect the ExpansionRequest to contain a valid set of options to be used for this + // expansion. + // Additionally, we override selected options using options values set via command line or + // ExpansionService wide overrides. + + PortablePipelineOptions requestPortablePipelineOptions = + requestOptions.as(PortablePipelineOptions.class); + PortablePipelineOptions commandLinePortablePipelineOptions = + commandLineOptions.as(PortablePipelineOptions.class); + Optional.ofNullable(commandLinePortablePipelineOptions.getDefaultEnvironmentType()) + .ifPresent(requestPortablePipelineOptions::setDefaultEnvironmentType); + Optional.ofNullable(commandLinePortablePipelineOptions.getDefaultEnvironmentConfig()) + .ifPresent(requestPortablePipelineOptions::setDefaultEnvironmentConfig); + List filesToStage = commandLinePortablePipelineOptions.getFilesToStage(); if (filesToStage != null) { - effectiveOpts.as(PortablePipelineOptions.class).setFilesToStage(filesToStage); + requestPortablePipelineOptions + .as(PortablePipelineOptions.class) + .setFilesToStage(filesToStage); } - effectiveOpts + requestPortablePipelineOptions .as(ExperimentalOptions.class) - .setExperiments(pipelineOptions.as(ExperimentalOptions.class).getExperiments()); - effectiveOpts.setRunner(NotRunnableRunner.class); - effectiveOpts + .setExperiments(commandLineOptions.as(ExperimentalOptions.class).getExperiments()); + requestPortablePipelineOptions.setRunner(NotRunnableRunner.class); + requestPortablePipelineOptions .as(ExpansionServiceOptions.class) .setExpansionServiceConfig( - pipelineOptions.as(ExpansionServiceOptions.class).getExpansionServiceConfig()); - // TODO(https://github.com/apache/beam/issues/20090): Figure out the correct subset of options - // to 
propagate. - if (requestOptions.as(StreamingOptions.class).getUpdateCompatibilityVersion() != null) { - effectiveOpts - .as(StreamingOptions.class) - .setUpdateCompatibilityVersion( - requestOptions.as(StreamingOptions.class).getUpdateCompatibilityVersion()); - } - return Pipeline.create(effectiveOpts); + commandLineOptions.as(ExpansionServiceOptions.class).getExpansionServiceConfig()); + return Pipeline.create(requestOptions); } @Override diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java index 1c8d515d5c85..9ee0c2c1797b 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.expansion.service; +import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_PREFIX; +import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_SUFFIX; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.contains; @@ -49,6 +51,8 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -58,15 +62,20 @@ import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.PTransform; 
import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Struct; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Value; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Resources; import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.Matchers; +import org.junit.Assert; import org.junit.Test; /** Tests for {@link ExpansionService}. */ @@ -76,6 +85,7 @@ public class ExpansionServiceTest { private static final String TEST_URN = "test:beam:transforms:count"; + private static final String TEST_OPTIONS_URN = "test:beam:transforms:test_options"; private static final String TEST_NAME = "TestName"; @@ -98,9 +108,59 @@ public class ExpansionServiceTest { @AutoService(ExpansionService.ExpansionServiceRegistrar.class) public static class TestTransformRegistrar implements ExpansionService.ExpansionServiceRegistrar { + static final String EXPECTED_STRING_VALUE = "abcde"; + static final Boolean EXPECTED_BOOLEAN_VALUE = true; + static final Integer EXPECTED_INTEGER_VALUE = 12345; + @Override public Map knownTransforms() { - return ImmutableMap.of(TEST_URN, (spec, options) -> Count.perElement()); + return ImmutableMap.of( + TEST_URN, (spec, options) -> Count.perElement(), + TEST_OPTIONS_URN, + (spec, options) -> + new TestOptionsTransform( + EXPECTED_STRING_VALUE, EXPECTED_BOOLEAN_VALUE, EXPECTED_INTEGER_VALUE)); + } + } + + public interface TestOptions extends PipelineOptions { + String getStringOption(); + + void setStringOption(String value); + + 
Boolean getBooleanOption(); + + void setBooleanOption(Boolean value); + + Integer getIntegerOption(); + + void setIntegerOption(Integer value); + } + + public static class TestOptionsTransform + extends PTransform, PCollection> { + String expectedStringValue; + + Boolean expectedBooleanValue; + + Integer expectedIntegerValue; + + public TestOptionsTransform( + String expectedStringValue, Boolean expectedBooleanValue, Integer expectedIntegerValue) { + this.expectedStringValue = expectedStringValue; + this.expectedBooleanValue = expectedBooleanValue; + this.expectedIntegerValue = expectedIntegerValue; + } + + @Override + public PCollection expand(PCollection input) { + TestOptions testOption = input.getPipeline().getOptions().as(TestOptions.class); + + Assert.assertEquals(expectedStringValue, testOption.getStringOption()); + Assert.assertEquals(expectedBooleanValue, testOption.getBooleanOption()); + Assert.assertEquals(expectedIntegerValue, testOption.getIntegerOption()); + + return input; } } @@ -146,6 +206,58 @@ public void testConstruct() { } } + @Test + public void testConstructWithPipelineOptions() { + PipelineOptionsFactory.register(TestOptions.class); + Pipeline p = Pipeline.create(); + p.apply(Impulse.create()); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + String inputPcollId = + Iterables.getOnlyElement( + Iterables.getOnlyElement(pipelineProto.getComponents().getTransformsMap().values()) + .getOutputsMap() + .values()); + + Struct optionsStruct = + Struct.newBuilder() + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "string_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setStringValue(TestTransformRegistrar.EXPECTED_STRING_VALUE) + .build()) + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "boolean_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setBoolValue(TestTransformRegistrar.EXPECTED_BOOLEAN_VALUE) + .build()) + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "integer_option" + 
PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setNumberValue(TestTransformRegistrar.EXPECTED_INTEGER_VALUE) + .build()) + .build(); + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setPipelineOptions(optionsStruct) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(TEST_OPTIONS_URN)) + .putInputs("input", inputPcollId)) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + + // Verify it has the right input. + assertThat(expandedTransform.getInputsMap().values(), contains(inputPcollId)); + + // Verify it has the right output. + assertThat(expandedTransform.getOutputsMap().keySet(), contains("output")); + } + @Test public void testConstructGenerateSequenceWithRegistration() { ExternalTransforms.ExternalConfigurationPayload payload = diff --git a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java index 4b6538157fd0..78ba610ad4d1 100644 --- a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java +++ b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java @@ -276,11 +276,11 @@ public static class RecordBatchRowIterator implements Iterator, AutoCloseab new ArrowValueConverterVisitor(); private final Schema schema; private final VectorSchemaRoot vectorSchemaRoot; - private final Factory> fieldValueGetters; + private final Factory>> fieldValueGetters; private Integer currRowIndex; private static class FieldVectorListValueGetterFactory - implements 
Factory> { + implements Factory>> { private final List fieldVectors; static FieldVectorListValueGetterFactory of(List fieldVectors) { @@ -292,7 +292,8 @@ private FieldVectorListValueGetterFactory(List fieldVectors) { } @Override - public List create(TypeDescriptor typeDescriptor, Schema schema) { + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { return this.fieldVectors.stream() .map( (fieldVector) -> { diff --git a/sdks/java/extensions/avro/build.gradle b/sdks/java/extensions/avro/build.gradle index 8ff0612a0eab..6631779e609c 100644 --- a/sdks/java/extensions/avro/build.gradle +++ b/sdks/java/extensions/avro/build.gradle @@ -128,6 +128,7 @@ avroVersions.each { k, v -> description = "Runs Avro extension tests with Avro version $v" outputs.upToDateWhen { false } classpath = sourceSets."avro$k".runtimeClasspath + systemProperty "beam.target.avro.version", v include '**/*.class' exclude '**/AvroIOTest$NeedsRunnerTests$*.class' diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java index e75647a2ccfa..203bcccbf562 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.schemas.SchemaProvider; import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for AVRO generated SpecificRecords and POJOs. 
@@ -44,8 +45,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return AvroUtils.getGetters(targetTypeDescriptor, schema); } diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 1b1c45969307..da7daf605d89 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -94,11 +94,13 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.EnsuresNonNullIf; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; import org.joda.time.Days; import org.joda.time.Duration; import org.joda.time.Instant; @@ -152,6 +154,38 @@ public class AvroUtils { new ForLoadedType(ReadableInstant.class); private static final ForLoadedType JODA_INSTANT = new ForLoadedType(Instant.class); + // contains workarounds for third-party methods that accept nullable arguments but lack proper + // annotations + private static class 
NullnessCheckerWorkarounds { + + private static ReflectData newReflectData(Class clazz) { + // getClassLoader returns @Nullable Classloader, but it's ok, as ReflectData constructor + // actually tolerates null classloader argument despite lacking the @Nullable annotation + @SuppressWarnings("nullness") + @NonNull + ClassLoader classLoader = clazz.getClassLoader(); + return new ReflectData(classLoader); + } + + private static void builderSet( + GenericRecordBuilder builder, String fieldName, @Nullable Object value) { + // the value argument can actually be null here, it's not annotated as such in the method + // though, hence this wrapper + builder.set(fieldName, castToNonNull(value)); + } + + private static Object createFixed( + @Nullable Object old, byte[] bytes, org.apache.avro.Schema schema) { + // old is tolerated when null, due to an instanceof check + return GenericData.get().createFixed(castToNonNull(old), bytes, schema); + } + + @SuppressWarnings("nullness") + private static @NonNull T castToNonNull(@Nullable T value) { + return value; + } + } + public static void addLogicalTypeConversions(final GenericData data) { // do not add DecimalConversion by default as schema must have extra 'scale' and 'precision' // properties. avro reflect already handles BigDecimal as string with the 'java-class' property @@ -235,7 +269,9 @@ public static FixedBytesField withSize(int size) { /** Create a {@link FixedBytesField} from a Beam {@link FieldType}. */ public static @Nullable FixedBytesField fromBeamFieldType(FieldType fieldType) { if (fieldType.getTypeName().isLogicalType() - && fieldType.getLogicalType().getIdentifier().equals(FixedBytes.IDENTIFIER)) { + && checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(FixedBytes.IDENTIFIER)) { int length = fieldType.getLogicalType(FixedBytes.class).getLength(); return new FixedBytesField(length); } else { @@ -264,7 +300,7 @@ public FieldType toBeamType() { /** Convert to an AVRO type. 
*/ public org.apache.avro.Schema toAvroType(String name, String namespace) { - return org.apache.avro.Schema.createFixed(name, null, namespace, size); + return org.apache.avro.Schema.createFixed(name, "", namespace, size); } } @@ -463,7 +499,7 @@ private AvroUtils() {} * @param clazz avro class */ public static Schema toBeamSchema(Class clazz) { - ReflectData data = new ReflectData(clazz.getClassLoader()); + ReflectData data = NullnessCheckerWorkarounds.newReflectData(clazz); return toBeamSchema(data.getSchema(clazz)); } @@ -486,10 +522,17 @@ public static Schema toBeamSchema(org.apache.avro.Schema schema) { return builder.build(); } + @EnsuresNonNullIf( + expression = {"#1"}, + result = false) + private static boolean isNullOrEmpty(@Nullable String str) { + return str == null || str.isEmpty(); + } + /** Converts a Beam Schema into an AVRO schema. */ public static org.apache.avro.Schema toAvroSchema( Schema beamSchema, @Nullable String name, @Nullable String namespace) { - final String schemaName = Strings.isNullOrEmpty(name) ? "topLevelRecord" : name; + final String schemaName = isNullOrEmpty(name) ? "topLevelRecord" : name; final String schemaNamespace = namespace == null ? "" : namespace; String childNamespace = !"".equals(schemaNamespace) ? schemaNamespace + "." 
+ schemaName : schemaName; @@ -498,7 +541,7 @@ public static org.apache.avro.Schema toAvroSchema( org.apache.avro.Schema.Field recordField = toAvroField(field, childNamespace); fields.add(recordField); } - return org.apache.avro.Schema.createRecord(schemaName, null, schemaNamespace, false, fields); + return org.apache.avro.Schema.createRecord(schemaName, "", schemaNamespace, false, fields); } public static org.apache.avro.Schema toAvroSchema(Schema beamSchema) { @@ -557,7 +600,8 @@ public static GenericRecord toGenericRecord( GenericRecordBuilder builder = new GenericRecordBuilder(avroSchema); for (int i = 0; i < beamSchema.getFieldCount(); ++i) { Field field = beamSchema.getField(i); - builder.set( + NullnessCheckerWorkarounds.builderSet( + builder, field.getName(), genericFromBeamField( field.getType(), avroSchema.getField(field.getName()).schema(), row.getValue(i))); @@ -567,7 +611,7 @@ public static GenericRecord toGenericRecord( @SuppressWarnings("unchecked") public static SerializableFunction getToRowFunction( - Class clazz, org.apache.avro.@Nullable Schema schema) { + Class clazz, org.apache.avro.Schema schema) { if (GenericRecord.class.equals(clazz)) { Schema beamSchema = toBeamSchema(schema); return (SerializableFunction) getGenericRecordToRowFunction(beamSchema); @@ -662,9 +706,9 @@ public static SerializableFunction getGenericRecordToRowFunc } private static class GenericRecordToRowFn implements SerializableFunction { - private final Schema schema; + private final @Nullable Schema schema; - GenericRecordToRowFn(Schema schema) { + GenericRecordToRowFn(@Nullable Schema schema) { this.schema = schema; } @@ -701,7 +745,7 @@ public static SerializableFunction getRowToGenericRecordFunc } private static class RowToGenericRecordFn implements SerializableFunction { - private transient org.apache.avro.Schema avroSchema; + private transient org.apache.avro.@Nullable Schema avroSchema; RowToGenericRecordFn(org.apache.avro.@Nullable Schema avroSchema) { 
this.avroSchema = avroSchema; @@ -751,7 +795,8 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE public static SchemaCoder schemaCoder(TypeDescriptor type) { @SuppressWarnings("unchecked") Class clazz = (Class) type.getRawType(); - org.apache.avro.Schema avroSchema = new ReflectData(clazz.getClassLoader()).getSchema(clazz); + org.apache.avro.Schema avroSchema = + NullnessCheckerWorkarounds.newReflectData(clazz).getSchema(clazz); Schema beamSchema = toBeamSchema(avroSchema); return SchemaCoder.of( beamSchema, type, getToRowFunction(clazz, avroSchema), getFromRowFunction(clazz)); @@ -790,7 +835,7 @@ public static SchemaCoder schemaCoder(org.apache.avro.Schema sche */ public static SchemaCoder schemaCoder(Class clazz, org.apache.avro.Schema schema) { return SchemaCoder.of( - getSchema(clazz, schema), + checkNotNull(getSchema(clazz, schema)), TypeDescriptor.of(clazz), getToRowFunction(clazz, schema), getFromRowFunction(clazz)); @@ -821,7 +866,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -871,7 +916,8 @@ public List get(TypeDescriptor typeDescriptor) { for (int i = 0; i < classFields.size(); ++i) { java.lang.reflect.Field f = classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { - FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); + FieldValueTypeInformation typeInformation = + FieldValueTypeInformation.forField(typeDescriptor, f, i); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); @@ -895,7 +941,7 @@ public static 
List getFieldTypes( } /** Get generated getters for an AVRO-generated SpecificRecord or a POJO. */ - public static List getGetters( + public static List> getGetters( TypeDescriptor typeDescriptor, Schema schema) { if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { return JavaBeanUtils.getGetters( @@ -968,7 +1014,7 @@ private static FieldType toFieldType(TypeWithNullability type) { break; case FIXED: - fieldType = FixedBytesField.fromAvroType(type.type).toBeamType(); + fieldType = checkNotNull(FixedBytesField.fromAvroType(type.type)).toBeamType(); break; case STRING: @@ -1066,7 +1112,8 @@ private static org.apache.avro.Schema getFieldSchema( break; case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + Schema.LogicalType logicalType = checkNotNull(fieldType.getLogicalType()); + String identifier = logicalType.getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1077,15 +1124,13 @@ private static org.apache.avro.Schema getFieldSchema( } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("char", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("char", checkNotNull(logicalType.getArgument())); } else if (VariableString.IDENTIFIER.equals(identifier) || "NVARCHAR".equals(identifier) || "VARCHAR".equals(identifier) || "LONGNVARCHAR".equals(identifier) || "LONGVARCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("varchar", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("varchar", checkNotNull(logicalType.getArgument())); } else if (EnumerationType.IDENTIFIER.equals(identifier)) { EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); baseType = @@ -1103,7 +1148,7 @@ private static 
org.apache.avro.Schema getFieldSchema( baseType = LogicalTypes.timeMillis().addToSchema(org.apache.avro.Schema.create(Type.INT)); } else { throw new RuntimeException( - "Unhandled logical type " + fieldType.getLogicalType().getIdentifier()); + "Unhandled logical type " + checkNotNull(fieldType.getLogicalType()).getIdentifier()); } break; @@ -1111,22 +1156,23 @@ private static org.apache.avro.Schema getFieldSchema( case ITERABLE: baseType = org.apache.avro.Schema.createArray( - getFieldSchema(fieldType.getCollectionElementType(), fieldName, namespace)); + getFieldSchema( + checkNotNull(fieldType.getCollectionElementType()), fieldName, namespace)); break; case MAP: - if (fieldType.getMapKeyType().getTypeName().isStringType()) { + if (checkNotNull(fieldType.getMapKeyType()).getTypeName().isStringType()) { // Avro only supports string keys in maps. baseType = org.apache.avro.Schema.createMap( - getFieldSchema(fieldType.getMapValueType(), fieldName, namespace)); + getFieldSchema(checkNotNull(fieldType.getMapValueType()), fieldName, namespace)); } else { throw new IllegalArgumentException("Avro only supports maps with string keys"); } break; case ROW: - baseType = toAvroSchema(fieldType.getRowSchema(), fieldName, namespace); + baseType = toAvroSchema(checkNotNull(fieldType.getRowSchema()), fieldName, namespace); break; default: @@ -1167,7 +1213,9 @@ private static org.apache.avro.Schema getFieldSchema( case DECIMAL: BigDecimal decimal = (BigDecimal) value; LogicalType logicalType = typeWithNullability.type.getLogicalType(); - return new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + @SuppressWarnings("nullness") + ByteBuffer result = new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + return result; case DATETIME: if (typeWithNullability.type.getType() == Type.INT) { @@ -1185,7 +1233,7 @@ private static org.apache.avro.Schema getFieldSchema( return ByteBuffer.wrap((byte[]) value); case LOGICAL_TYPE: - String identifier = 
fieldType.getLogicalType().getIdentifier(); + String identifier = checkNotNull(fieldType.getLogicalType()).getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1193,9 +1241,11 @@ private static org.apache.avro.Schema getFieldSchema( if (byteArray.length != fixedBytesField.getSize()) { throw new IllegalArgumentException("Incorrectly sized byte array."); } - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, (byte[]) value, typeWithNullability.type); } else if (VariableBytes.IDENTIFIER.equals(identifier)) { - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, (byte[]) value, typeWithNullability.type); } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { @@ -1239,26 +1289,27 @@ private static org.apache.avro.Schema getFieldSchema( case ARRAY: case ITERABLE: Iterable iterable = (Iterable) value; - List translatedArray = Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); + List<@Nullable Object> translatedArray = + Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); for (Object arrayElement : iterable) { translatedArray.add( genericFromBeamField( - fieldType.getCollectionElementType(), + checkNotNull(fieldType.getCollectionElementType()), typeWithNullability.type.getElementType(), arrayElement)); } return translatedArray; case MAP: - Map map = Maps.newHashMap(); + Map map = Maps.newHashMap(); Map valueMap = (Map) value; for (Map.Entry entry : valueMap.entrySet()) { - Utf8 key = new Utf8((String) entry.getKey()); + Utf8 key = new Utf8((String) checkNotNull(entry.getKey())); map.put( key, genericFromBeamField( - fieldType.getMapValueType(), + checkNotNull(fieldType.getMapValueType()), 
typeWithNullability.type.getValueType(), entry.getValue())); } @@ -1282,8 +1333,8 @@ private static org.apache.avro.Schema getFieldSchema( * @return value converted for {@link Row} */ @SuppressWarnings("unchecked") - public static @Nullable Object convertAvroFieldStrict( - @Nullable Object value, + public static @PolyNull Object convertAvroFieldStrict( + @PolyNull Object value, @Nonnull org.apache.avro.Schema avroSchema, @Nonnull FieldType fieldType) { if (value == null) { @@ -1383,7 +1434,8 @@ private static Object convertBytesStrict(ByteBuffer bb, FieldType fieldType) { private static Object convertFixedStrict(GenericFixed fixed, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "fixed"); - checkArgument(FixedBytes.IDENTIFIER.equals(fieldType.getLogicalType().getIdentifier())); + checkArgument( + FixedBytes.IDENTIFIER.equals(checkNotNull(fieldType.getLogicalType()).getIdentifier())); return fixed.bytes().clone(); // clone because GenericFixed is mutable } @@ -1434,7 +1486,10 @@ private static Object convertBooleanStrict(Boolean value, FieldType fieldType) { private static Object convertEnumStrict(Object value, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "enum"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(EnumerationType.IDENTIFIER)); + checkArgument( + checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(EnumerationType.IDENTIFIER)); EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); return enumerationType.valueOf(value.toString()); } @@ -1442,7 +1497,8 @@ private static Object convertEnumStrict(Object value, FieldType fieldType) { private static Object convertUnionStrict( Object value, org.apache.avro.Schema unionAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "oneOfType"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(OneOfType.IDENTIFIER)); + 
checkArgument( + checkNotNull(fieldType.getLogicalType()).getIdentifier().equals(OneOfType.IDENTIFIER)); OneOfType oneOfType = fieldType.getLogicalType(OneOfType.class); int fieldNumber = GenericData.get().resolveUnion(unionAvroSchema, value); FieldType baseFieldType = oneOfType.getOneOfSchema().getField(fieldNumber).getType(); @@ -1459,7 +1515,7 @@ private static Object convertArrayStrict( FieldType elemFieldType = fieldType.getCollectionElementType(); for (Object value : values) { - ret.add(convertAvroFieldStrict(value, elemAvroSchema, elemFieldType)); + ret.add(convertAvroFieldStrict(value, elemAvroSchema, checkNotNull(elemFieldType))); } return ret; @@ -1470,10 +1526,10 @@ private static Object convertMapStrict( org.apache.avro.Schema valueAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.MAP, "map"); - checkNotNull(fieldType.getMapKeyType()); - checkNotNull(fieldType.getMapValueType()); + FieldType mapKeyType = checkNotNull(fieldType.getMapKeyType()); + FieldType mapValueType = checkNotNull(fieldType.getMapValueType()); - if (!fieldType.getMapKeyType().equals(FieldType.STRING)) { + if (!FieldType.STRING.equals(fieldType.getMapKeyType())) { throw new IllegalArgumentException( "Can't convert 'string' map keys to " + fieldType.getMapKeyType()); } @@ -1482,8 +1538,8 @@ private static Object convertMapStrict( for (Map.Entry value : values.entrySet()) { ret.put( - convertStringStrict(value.getKey(), fieldType.getMapKeyType()), - convertAvroFieldStrict(value.getValue(), valueAvroSchema, fieldType.getMapValueType())); + convertStringStrict(value.getKey(), mapKeyType), + convertAvroFieldStrict(value.getValue(), valueAvroSchema, mapValueType)); } return ret; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/AvroVersionVerificationTest.java similarity index 53% rename from 
sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java rename to sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/AvroVersionVerificationTest.java index f15fc5307374..f9e9a54b0531 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java +++ b/sdks/java/extensions/avro/src/test/java/org/apache/beam/sdk/extensions/avro/AvroVersionVerificationTest.java @@ -15,27 +15,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; +package org.apache.beam.sdk.extensions.avro; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertEquals; +import org.apache.avro.Schema; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Assume; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; -public class NameGeneratorTest { - private static final int MAXIMUM_POSTGRES_TABLE_NAME_LENGTH = 63; - - @Test - public void testGenerateMetadataTableNameRemovesHyphens() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id-12345"); - assertFalse(tableName.contains("-")); - } - +@RunWith(JUnit4.class) +public class AvroVersionVerificationTest { @Test - public void testGenerateMetadataTableNameIsShorterThan64Characters() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id1-maximum-length"); - assertTrue(tableName.length() <= MAXIMUM_POSTGRES_TABLE_NAME_LENGTH); + public void testAvroVersion() { + @Nullable String targetVer = System.getProperty("beam.target.avro.version"); + 
Assume.assumeTrue(!Strings.isNullOrEmpty(targetVer)); + String actualVer = Schema.class.getPackage().getImplementationVersion(); + assertEquals(targetVer, actualVer); } } diff --git a/sdks/java/extensions/google-cloud-platform-core/build.gradle b/sdks/java/extensions/google-cloud-platform-core/build.gradle index 6cb8d3248ac1..8d21df50006b 100644 --- a/sdks/java/extensions/google-cloud-platform-core/build.gradle +++ b/sdks/java/extensions/google-cloud-platform-core/build.gradle @@ -66,12 +66,10 @@ task integrationTestKms(type: Test) { group = "Verification" def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' def gcpTempRoot = project.findProperty('gcpTempRootKms') ?: 'gs://temp-storage-for-end-to-end-tests-cmek' - def gcpGrpcTempRoot = project.findProperty('gcpGrpcTempRoot') ?: 'gs://gcs-grpc-team-apache-beam-testing' def dataflowKmsKey = project.findProperty('dataflowKmsKey') ?: "projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test" systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ "--project=${gcpProject}", "--tempRoot=${gcpTempRoot}", - "--grpcTempRoot=${gcpGrpcTempRoot}", "--dataflowKmsKey=${dataflowKmsKey}", ]) diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java index 18d637254115..1285b88663e7 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.extensions.gcp.options; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.google.cloud.hadoop.gcsio.GoogleCloudStorageReadOptions; import com.google.cloud.hadoop.util.AsyncWriteChannelOptions; 
import java.util.concurrent.ExecutorService; import org.apache.beam.sdk.extensions.gcp.storage.GcsPathValidator; @@ -44,6 +45,15 @@ public interface GcsOptions extends ApplicationNameOptions, GcpOptions, Pipeline void setGcsUtil(GcsUtil value); + @JsonIgnore + @Description( + "The GoogleCloudStorageReadOptions instance that should be used to read from Google Cloud Storage.") + @Default.InstanceFactory(GcsUtil.GcsReadOptionsFactory.class) + @Hidden + GoogleCloudStorageReadOptions getGoogleCloudStorageReadOptions(); + + void setGoogleCloudStorageReadOptions(GoogleCloudStorageReadOptions value); + /** * The ExecutorService instance to use to create threads, can be overridden to specify an * ExecutorService that is compatible with the user's environment. If unset, the default is to use diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystem.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystem.java index 6332051c0ddc..32079ebf55a3 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystem.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystem.java @@ -217,9 +217,19 @@ protected String getScheme() { @Override protected void reportLineage(GcsResourceId resourceId, Lineage lineage) { + reportLineage(resourceId, lineage, LineageLevel.FILE); + } + + @Override + protected void reportLineage(GcsResourceId resourceId, Lineage lineage, LineageLevel level) { GcsPath path = resourceId.getGcsPath(); if (!path.getBucket().isEmpty()) { - lineage.add("gcs", ImmutableList.of(path.getBucket(), path.getObject())); + ImmutableList.Builder segments = + ImmutableList.builder().add(path.getBucket()); + if (level != LineageLevel.TOP_LEVEL && !path.getObject().isEmpty()) { + 
segments.add(path.getObject()); + } + lineage.add("gcs", segments.build()); } else { LOG.warn("Report Lineage on relative path {} is unsupported", path.getObject()); } diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java index 8d3596f17b3b..d58154132a72 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java @@ -123,6 +123,14 @@ public static GcsCountersOptions create( } } + public static class GcsReadOptionsFactory + implements DefaultValueFactory { + @Override + public GoogleCloudStorageReadOptions create(PipelineOptions options) { + return GoogleCloudStorageReadOptions.DEFAULT; + } + } + /** * This is a {@link DefaultValueFactory} able to create a {@link GcsUtil} using any transport * flags specified on the {@link PipelineOptions}. @@ -153,7 +161,8 @@ public GcsUtil create(PipelineOptions options) { : null, gcsOptions.getEnableBucketWriteMetricCounter() ? gcsOptions.getGcsWriteCounterPrefix() - : null)); + : null), + gcsOptions.getGoogleCloudStorageReadOptions()); } /** Returns an instance of {@link GcsUtil} based on the given parameters. 
*/ @@ -164,7 +173,8 @@ public static GcsUtil create( ExecutorService executorService, Credentials credentials, @Nullable Integer uploadBufferSizeBytes, - GcsCountersOptions gcsCountersOptions) { + GcsCountersOptions gcsCountersOptions, + GoogleCloudStorageReadOptions gcsReadOptions) { return new GcsUtil( storageClient, httpRequestInitializer, @@ -173,7 +183,8 @@ public static GcsUtil create( credentials, uploadBufferSizeBytes, null, - gcsCountersOptions); + gcsCountersOptions, + gcsReadOptions); } } @@ -249,7 +260,8 @@ public static boolean isWildcard(GcsPath spec) { Credentials credentials, @Nullable Integer uploadBufferSizeBytes, @Nullable Integer rewriteDataOpBatchLimit, - GcsCountersOptions gcsCountersOptions) { + GcsCountersOptions gcsCountersOptions, + GoogleCloudStorageReadOptions gcsReadOptions) { this.storageClient = storageClient; this.httpRequestInitializer = httpRequestInitializer; this.uploadBufferSizeBytes = uploadBufferSizeBytes; @@ -260,6 +272,7 @@ public static boolean isWildcard(GcsPath spec) { googleCloudStorageOptions = GoogleCloudStorageOptions.builder() .setAppName("Beam") + .setReadChannelOptions(gcsReadOptions) .setGrpcEnabled(shouldUseGrpc) .build(); googleCloudStorage = @@ -565,7 +578,9 @@ private SeekableByteChannel wrapInCounting( public SeekableByteChannel open(GcsPath path) throws IOException { String bucket = path.getBucket(); SeekableByteChannel channel = - googleCloudStorage.open(new StorageResourceId(path.getBucket(), path.getObject())); + googleCloudStorage.open( + new StorageResourceId(path.getBucket(), path.getObject()), + this.googleCloudStorageOptions.getReadChannelOptions()); return wrapInCounting(channel, bucket); } diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java index f5075a3f2c55..26d98125a3af 100644 --- 
a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java @@ -55,6 +55,8 @@ public void testGcpCoreApiSurface() throws Exception { classesInPackage("com.google.api.services.storage"), classesInPackage("com.google.auth"), classesInPackage("com.fasterxml.jackson.annotation"), + classesInPackage("com.google.cloud.hadoop.gcsio"), + classesInPackage("com.google.common.collect"), // Via gcs-connector ReadOptions builder classesInPackage("java"), classesInPackage("javax"), classesInPackage("org.apache.beam.sdk"), diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystemTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystemTest.java index 0b79cde1f187..f2ff7118f95d 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystemTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/storage/GcsFileSystemTest.java @@ -23,6 +23,9 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.api.services.storage.model.Objects; @@ -38,6 +41,7 @@ import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.MatchResult.Status; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptionsFactory; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; @@ -235,6 +239,20 @@ public void testMatchNonGlobs() throws Exception { contains(toFilenames(matchResults.get(4)).toArray())); } + @Test + public void testReportLineageOnBucket() { + verifyLineage("gs://testbucket", ImmutableList.of("testbucket")); + verifyLineage("gs://testbucket/", ImmutableList.of("testbucket")); + verifyLineage("gs://testbucket/foo/bar.txt", ImmutableList.of("testbucket", "foo/bar.txt")); + } + + private void verifyLineage(String uri, List expected) { + GcsResourceId path = GcsResourceId.fromGcsPath(GcsPath.fromUri(uri)); + Lineage mockLineage = mock(Lineage.class); + gcsFileSystem.reportLineage(path, mockLineage); + verify(mockLineage, times(1)).add("gcs", expected); + } + private StorageObject createStorageObject(String gcsFilename, long fileSize) { GcsPath gcsPath = GcsPath.fromUri(gcsFilename); // Google APIs will use null for empty files. 
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java index 6f1e0e985c24..6477564f01a1 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java @@ -21,7 +21,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; import com.google.protobuf.ByteString; import java.io.IOException; @@ -35,10 +34,7 @@ import org.apache.beam.sdk.extensions.gcp.util.GcsUtil.CreateOptions; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.io.FileSystems; -import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.ExperimentalOptions; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.testing.UsesKms; @@ -99,8 +95,6 @@ public void testWriteAndReadGcsWithGrpc() throws IOException { "%s/GcsUtilIT-%tF-% writeGcsTextFile(gcsUtil, wrongFilename, testContent)); - // Write a test file in a bucket with gRPC enabled. 
- GcsGrpcOptions grpcOptions = options.as(GcsGrpcOptions.class); - assertNotNull(grpcOptions.getGrpcTempRoot()); - String tempLocationWithGrpc = grpcOptions.getGrpcTempRoot() + "/temp"; + String tempLocationWithGrpc = options.getTempRoot() + "/temp"; String filename = String.format(outputPattern, tempLocationWithGrpc, new Date()); writeGcsTextFile(gcsUtil, filename, testContent); @@ -132,15 +117,6 @@ public void testWriteAndReadGcsWithGrpc() throws IOException { gcsUtil.remove(Collections.singletonList(filename)); } - public interface GcsGrpcOptions extends PipelineOptions { - /** Get tempRoot in a gRPC-enabled bucket. */ - @Description("TempRoot in a gRPC-enabled bucket") - String getGrpcTempRoot(); - - /** Set the tempRoot in a gRPC-enabled bucket. */ - void setGrpcTempRoot(String grpcTempRoot); - } - void writeGcsTextFile(GcsUtil gcsUtil, String filename, String content) throws IOException { GcsPath gcsPath = GcsPath.fromUri(filename); try (WritableByteChannel channel = diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java index bd7f46ec8951..97082572ce41 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java @@ -177,6 +177,32 @@ public void testCreationWithGcsUtilProvided() { assertSame(gcsUtil, pipelineOptions.getGcsUtil()); } + @Test + public void testCreationWithExplicitGoogleCloudStorageReadOptions() throws Exception { + GoogleCloudStorageReadOptions readOptions = + GoogleCloudStorageReadOptions.builder() + .setFadvise(GoogleCloudStorageReadOptions.Fadvise.AUTO) + .setSupportGzipEncoding(true) + .setFastFailOnNotFound(false) + .build(); + + GcsOptions 
pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + pipelineOptions.setGoogleCloudStorageReadOptions(readOptions); + + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + GoogleCloudStorage googleCloudStorageMock = Mockito.spy(GoogleCloudStorage.class); + Mockito.when(googleCloudStorageMock.open(Mockito.any(), Mockito.any())) + .thenReturn(Mockito.mock(SeekableByteChannel.class)); + gcsUtil.setCloudStorageImpl(googleCloudStorageMock); + + assertEquals(readOptions, pipelineOptions.getGoogleCloudStorageReadOptions()); + + // Assert read options are passed to GCS calls + pipelineOptions.getGcsUtil().open(GcsPath.fromUri("gs://bucket/path")); + Mockito.verify(googleCloudStorageMock, Mockito.times(1)) + .open(StorageResourceId.fromStringPath("gs://bucket/path"), readOptions); + } + @Test public void testMultipleThreadsCanCompleteOutOfOrderWithDefaultThreadPool() throws Exception { GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); @@ -1630,7 +1656,8 @@ public static GcsUtilMock createMock(PipelineOptions options) { : null, gcsOptions.getEnableBucketWriteMetricCounter() ? 
gcsOptions.getGcsWriteCounterPrefix() - : null)); + : null), + gcsOptions.getGoogleCloudStorageReadOptions()); } private GcsUtilMock( @@ -1641,7 +1668,8 @@ private GcsUtilMock( Credentials credentials, @Nullable Integer uploadBufferSizeBytes, @Nullable Integer rewriteDataOpBatchLimit, - GcsCountersOptions gcsCountersOptions) { + GcsCountersOptions gcsCountersOptions, + GoogleCloudStorageReadOptions gcsReadOptions) { super( storageClient, httpRequestInitializer, @@ -1650,7 +1678,8 @@ private GcsUtilMock( credentials, uploadBufferSizeBytes, rewriteDataOpBatchLimit, - gcsCountersOptions); + gcsCountersOptions, + gcsReadOptions); } @Override diff --git a/sdks/java/extensions/ordered/build.gradle b/sdks/java/extensions/ordered/build.gradle index 10c9785b9eed..8bee1901bd3a 100644 --- a/sdks/java/extensions/ordered/build.gradle +++ b/sdks/java/extensions/ordered/build.gradle @@ -28,6 +28,12 @@ dependencies { implementation library.java.vendored_guava_32_1_2_jre testImplementation library.java.junit testImplementation library.java.hamcrest + testImplementation library.java.slf4j_jdk14 testImplementation project(path: ':sdks:java:core') + testImplementation 'junit:junit:4.13.1' + testImplementation project(path: ':runners:google-cloud-dataflow-java') testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") + testImplementation project(path: ":runners:google-cloud-dataflow-java") + testImplementation project(path: ":sdks:java:extensions:google-cloud-platform-core") + testImplementation project(path: ":sdks:java:io:google-cloud-platform") } \ No newline at end of file diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ContiguousSequenceRange.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ContiguousSequenceRange.java new file mode 100644 index 000000000000..c16cf9328dcd --- /dev/null +++ 
b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ContiguousSequenceRange.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.ordered; + +import com.google.auto.value.AutoValue; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.joda.time.Instant; + +/** A range of contiguous event sequences and the latest timestamp of the events in the range. 
*/ +@AutoValue +public abstract class ContiguousSequenceRange { + public static final ContiguousSequenceRange EMPTY = + ContiguousSequenceRange.of( + Long.MIN_VALUE, Long.MIN_VALUE, Instant.ofEpochMilli(Long.MIN_VALUE)); + + /** @return inclusive starting sequence */ + public abstract long getStart(); + + /** @return exclusive end sequence */ + public abstract long getEnd(); + + /** @return latest timestamp of all events in the range */ + public abstract Instant getTimestamp(); + + public static ContiguousSequenceRange of(long start, long end, Instant timestamp) { + return new AutoValue_ContiguousSequenceRange(start, end, timestamp); + } + + static class CompletedSequenceRangeCoder extends CustomCoder { + + private static final CompletedSequenceRangeCoder INSTANCE = new CompletedSequenceRangeCoder(); + + static CompletedSequenceRangeCoder of() { + return INSTANCE; + } + + private CompletedSequenceRangeCoder() {} + + @Override + public void encode( + ContiguousSequenceRange value, @UnknownKeyFor @NonNull @Initialized OutputStream outStream) + throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull + @Initialized IOException { + VarLongCoder.of().encode(value.getStart(), outStream); + VarLongCoder.of().encode(value.getEnd(), outStream); + InstantCoder.of().encode(value.getTimestamp(), outStream); + } + + @Override + public ContiguousSequenceRange decode(@UnknownKeyFor @NonNull @Initialized InputStream inStream) + throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull + @Initialized IOException { + long start = VarLongCoder.of().decode(inStream); + long end = VarLongCoder.of().decode(inStream); + Instant timestamp = InstantCoder.of().decode(inStream); + return ContiguousSequenceRange.of(start, end, timestamp); + } + } +} diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/EventExaminer.java 
b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/EventExaminer.java index 1e4fe7565517..b5de67f16ced 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/EventExaminer.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/EventExaminer.java @@ -31,7 +31,8 @@ public interface EventExaminer> extends Serializable { /** - * Is this event the first expected event for the given key and window? + * Is this event the first expected event for the given key and window if the per key sequence is + * used? In case of global sequence it determines the first global sequence event. * * @param sequenceNumber the sequence number of the event as defined by the key of the input * PCollection to {@link OrderedEventProcessor} @@ -41,8 +42,8 @@ public interface EventExaminer> boolean isInitialEvent(long sequenceNumber, EventT event); /** - * If the event was the first event in the sequence, create the state to hold the required data - * needed for processing. This data will be persisted. + * If the event was the first event for a given key, create the state to hold the required data + * needed for processing. This data will be persisted in a Beam state. * * @param event the first event in the sequence. * @return the state to persist. @@ -53,6 +54,8 @@ public interface EventExaminer> /** * Is this event the last expected event for a given key and window? * + *

    Note, this method is not used yet with global sequences. + * * @param sequenceNumber of the event * @param event being processed * @return true if the last event. There are cases where it's impossible to know whether it's the diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/GlobalSequenceTracker.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/GlobalSequenceTracker.java new file mode 100644 index 000000000000..aa12c30a5317 --- /dev/null +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/GlobalSequenceTracker.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.ordered; + +import org.apache.beam.sdk.extensions.ordered.ContiguousSequenceRange.CompletedSequenceRangeCoder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.AfterFirst; +import org.apache.beam.sdk.transforms.windowing.AfterPane; +import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TimestampedValue; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; + +/** + * PTransform to produce the side input of the maximum contiguous range of sequence numbers. + * + * @param type of event key + * @param type of event + * @param type of processing result + * @param type of state + */ +class GlobalSequenceTracker< + EventKeyT, EventT, ResultT, StateT extends MutableState> + extends PTransform< + PCollection>>>, + PCollectionView> { + + private final Combine.GloballyAsSingletonView< + TimestampedValue>>, ContiguousSequenceRange> + sideInputProducer; + private final @Nullable Duration frequencyOfGeneration; + private final int maxElementsBeforeReevaluatingGlobalSequence; + + /** + * Constructor used in batch pipelines. 
+ * + * @param sideInputProducer + */ + public GlobalSequenceTracker( + Combine.GloballyAsSingletonView< + TimestampedValue>>, ContiguousSequenceRange> + sideInputProducer) { + this.sideInputProducer = sideInputProducer; + this.frequencyOfGeneration = null; + this.maxElementsBeforeReevaluatingGlobalSequence = 0; + } + + public GlobalSequenceTracker( + Combine.GloballyAsSingletonView< + TimestampedValue>>, ContiguousSequenceRange> + sideInputProducer, + Duration globalSequenceGenerationFrequency, + int maxElementsBeforeReevaluatingGlobalSequence) { + this.sideInputProducer = sideInputProducer; + this.frequencyOfGeneration = globalSequenceGenerationFrequency; + this.maxElementsBeforeReevaluatingGlobalSequence = maxElementsBeforeReevaluatingGlobalSequence; + } + + @Override + public PCollectionView expand( + PCollection>>> input) { + input + .getPipeline() + .getCoderRegistry() + .registerCoderForClass(ContiguousSequenceRange.class, CompletedSequenceRangeCoder.of()); + + if (frequencyOfGeneration != null) { + // This branch will only be executed in case of streaming pipelines. + // For batch pipelines the side input should only be computed once. 
+ input = + input.apply( + "Triggering Setup", + // Reproduce the windowing of the input PCollection, but change the triggering + // in order to create a slowing changing side input + Window.>>>into( + (WindowFn>>, ?>) + input.getWindowingStrategy().getWindowFn()) + .accumulatingFiredPanes() + .withAllowedLateness(input.getWindowingStrategy().getAllowedLateness()) + .triggering( + Repeatedly.forever( + AfterFirst.of( + AfterPane.elementCountAtLeast( + maxElementsBeforeReevaluatingGlobalSequence), + AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(frequencyOfGeneration))))); + } + return input.apply("Create Side Input", sideInputProducer); + } +} diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/GlobalSequencesProcessorDoFn.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/GlobalSequencesProcessorDoFn.java new file mode 100644 index 000000000000..64c2d119c97d --- /dev/null +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/GlobalSequencesProcessorDoFn.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.ordered; + +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.ordered.ProcessingState.ProcessingStateCoder; +import org.apache.beam.sdk.state.OrderedListState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Main Stateful DoFn used to process events in the global sequence mode. 
+ * + * @param + * @param + * @param + * @param + */ +class GlobalSequencesProcessorDoFn< + EventT, EventKeyT, ResultT, StateT extends MutableState> + extends ProcessorDoFn { + + private static final Logger LOG = LoggerFactory.getLogger(GlobalSequencesProcessorDoFn.class); + + private static final String BATCH_EMISSION_TIMER = "batchTimer"; + + @TimerId(BATCH_EMISSION_TIMER) + @SuppressWarnings("unused") + private final TimerSpec batchTimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + private static final String BUFFERED_EVENTS = "bufferedEvents"; + + @StateId(BUFFERED_EVENTS) + @SuppressWarnings("unused") + private final StateSpec> bufferedEventsSpec; + + @StateId(PROCESSING_STATE) + @SuppressWarnings("unused") + private final StateSpec>> processingStateSpec; + + @StateId(MUTABLE_STATE) + @SuppressWarnings("unused") + private final StateSpec> mutableStateSpec; + + @StateId(WINDOW_CLOSED) + @SuppressWarnings("unused") + private final StateSpec> windowClosedSpec; + + @TimerId(STATUS_EMISSION_TIMER) + @SuppressWarnings("unused") + private final TimerSpec statusEmissionTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final PCollectionView latestContiguousRangeSideInput; + + private final Duration maxLateness; + + GlobalSequencesProcessorDoFn( + EventExaminer eventExaminer, + Coder eventCoder, + Coder stateCoder, + Coder keyCoder, + TupleTag> mainOutputTupleTag, + TupleTag> statusTupleTag, + Duration statusUpdateFrequency, + TupleTag>>> unprocessedEventTupleTag, + boolean produceStatusUpdateOnEveryEvent, + long maxNumberOfResultsToProduce, + PCollectionView latestContiguousRangeSideInput, + Duration maxLateness) { + super( + eventExaminer, + mainOutputTupleTag, + statusTupleTag, + statusUpdateFrequency, + unprocessedEventTupleTag, + produceStatusUpdateOnEveryEvent, + maxNumberOfResultsToProduce); + + this.latestContiguousRangeSideInput = latestContiguousRangeSideInput; + this.bufferedEventsSpec = StateSpecs.orderedList(eventCoder); + 
this.processingStateSpec = StateSpecs.value(ProcessingStateCoder.of(keyCoder)); + this.mutableStateSpec = StateSpecs.value(stateCoder); + this.windowClosedSpec = StateSpecs.value(BooleanCoder.of()); + this.maxLateness = maxLateness; + } + + @Override + boolean checkForFirstOrLastEvent() { + return false; + } + + @Override + boolean checkForSequenceGapInBufferedEvents() { + return false; + } + + @ProcessElement + public void processElement( + ProcessContext context, + @Element KV> eventAndSequence, + @StateId(BUFFERED_EVENTS) OrderedListState bufferedEventsProxy, + @AlwaysFetched @StateId(PROCESSING_STATE) + ValueState> processingStateProxy, + @StateId(MUTABLE_STATE) ValueState mutableStateProxy, + @TimerId(STATUS_EMISSION_TIMER) Timer statusEmissionTimer, + @TimerId(BATCH_EMISSION_TIMER) Timer batchEmissionTimer, + MultiOutputReceiver outputReceiver, + BoundedWindow window) { + + ContiguousSequenceRange lastContiguousRange = context.sideInput(latestContiguousRangeSideInput); + + EventT event = eventAndSequence.getValue().getValue(); + EventKeyT key = eventAndSequence.getKey(); + long sequence = eventAndSequence.getValue().getKey(); + + if (LOG.isTraceEnabled()) { + LOG.trace(key + ": " + sequence + " lastRange: " + lastContiguousRange); + } + + ProcessingState processingState = processingStateProxy.read(); + + if (processingState == null) { + // This is the first time we see this key/window pair + processingState = new ProcessingState<>(key); + if (statusUpdateFrequency != null) { + // Set up the timer to produce the status of the processing on a regular basis + statusEmissionTimer.offset(statusUpdateFrequency).setRelative(); + } + } + + processingState.updateGlobalSequenceDetails(lastContiguousRange); + + if (event == null) { + // This is a ticker event. We only need to update the state as it relates to the global + // sequence. 
+ processingStateProxy.write(processingState); + + setBatchEmissionTimerIfNeeded(batchEmissionTimer, processingState); + + return; + } + + if (numberOfResultsBeforeBundleStart == null) { + // Per key processing is synchronized by Beam. There is no need to have it here. + numberOfResultsBeforeBundleStart = processingState.getResultCount(); + } + + processingState.eventReceived(); + + StateT state = + processNewEvent( + sequence, + event, + processingState, + mutableStateProxy, + bufferedEventsProxy, + outputReceiver); + + saveStates( + processingStateProxy, + processingState, + mutableStateProxy, + state, + outputReceiver, + window.maxTimestamp()); + + setBatchEmissionTimerIfNeeded(batchEmissionTimer, processingState); + } + + private void setBatchEmissionTimerIfNeeded( + Timer batchEmissionTimer, ProcessingState processingState) { + ContiguousSequenceRange lastCompleteGlobalSequence = processingState.getLastContiguousRange(); + if (lastCompleteGlobalSequence != null + && processingState.thereAreGloballySequencedEventsToBeProcessed()) { + batchEmissionTimer.set(lastCompleteGlobalSequence.getTimestamp().plus(maxLateness)); + } + } + + @OnTimer(BATCH_EMISSION_TIMER) + public void onBatchEmission( + OnTimerContext context, + @StateId(BUFFERED_EVENTS) OrderedListState bufferedEventsState, + @AlwaysFetched @StateId(PROCESSING_STATE) + ValueState> processingStatusState, + @AlwaysFetched @StateId(MUTABLE_STATE) ValueState mutableStateState, + @TimerId(BATCH_EMISSION_TIMER) Timer batchEmissionTimer, + MultiOutputReceiver outputReceiver) { + + // At this point everything in the buffered state is ready to be processed up to the latest + // global sequence. + @Nullable ProcessingState processingState = processingStatusState.read(); + if (processingState == null) { + LOG.warn("Missing the processing state. 
Probably occurred during pipeline drainage"); + return; + } + + StateT state = mutableStateState.read(); + + ContiguousSequenceRange lastContiguousRange = processingState.getLastContiguousRange(); + if (lastContiguousRange == null) { + LOG.warn("Last complete global instance is null."); + return; + } + + Long earliestBufferedSequence = processingState.getEarliestBufferedSequence(); + if (earliestBufferedSequence == null) { + LOG.warn("Earliest buffered sequence is null."); + return; + } + + if (LOG.isTraceEnabled()) { + LOG.trace("Emission timer: " + processingState); + } + + this.numberOfResultsBeforeBundleStart = processingState.getResultCount(); + + state = + processBufferedEventRange( + processingState, + state, + bufferedEventsState, + outputReceiver, + batchEmissionTimer, + lastContiguousRange); + + saveStates( + processingStatusState, + processingState, + mutableStateState, + state, + outputReceiver, + // TODO: validate that this is correct. + context.window().maxTimestamp()); + } + + @OnTimer(STATUS_EMISSION_TIMER) + @SuppressWarnings("unused") + public void onStatusEmission( + MultiOutputReceiver outputReceiver, + @TimerId(STATUS_EMISSION_TIMER) Timer statusEmissionTimer, + @StateId(WINDOW_CLOSED) ValueState windowClosedState, + @StateId(PROCESSING_STATE) ValueState> processingStateState) { + + processStatusTimerEvent( + outputReceiver, statusEmissionTimer, windowClosedState, processingStateState); + } +} diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessor.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessor.java index 935647c0e7e5..fb23a7c8667a 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessor.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessor.java @@ -19,52 +19,44 @@ import 
com.google.auto.value.AutoValue; import java.util.Arrays; -import java.util.Iterator; import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VarLongCoder; -import org.apache.beam.sdk.extensions.ordered.ProcessingState.ProcessingStateCoder; -import org.apache.beam.sdk.extensions.ordered.UnprocessedEvent.Reason; +import org.apache.beam.sdk.extensions.ordered.OrderedProcessingHandler.OrderedProcessingGlobalSequenceHandler; import org.apache.beam.sdk.extensions.ordered.UnprocessedEvent.UnprocessedEventCoder; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.schemas.SchemaRegistry; -import org.apache.beam.sdk.state.OrderedListState; -import org.apache.beam.sdk.state.StateSpec; -import org.apache.beam.sdk.state.StateSpecs; -import org.apache.beam.sdk.state.TimeDomain; -import org.apache.beam.sdk.state.Timer; -import org.apache.beam.sdk.state.TimerSpec; -import org.apache.beam.sdk.state.TimerSpecs; -import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; +import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import 
org.apache.beam.sdk.values.TypeDescriptor; -import org.joda.time.Duration; import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Transform for processing ordered events. Events are grouped by the key and within each key they * are applied according to the provided sequence. Events which arrive out of sequence are buffered * and processed after all the missing events for a given key have arrived. * - * @param - * @param - * @param + *

    There are two sequencing modes - a sequence per key and a global sequence. See {@link + * OrderedProcessingHandler} for details on how to configure this transform. + * + * @param type of event + * @param type of event key + * @param type of the state */ @AutoValue @SuppressWarnings({"nullness", "TypeNameShadowing"}) @@ -74,6 +66,18 @@ public abstract class OrderedEventProcessor< PCollection>>, OrderedEventProcessorResult> { + public static final String GLOBAL_SEQUENCE_TRACKER = "global_sequence_tracker"; + + /** + * Create the transform. + * + * @param handler provides the configuration of this transform + * @param type of event + * @param type of event key + * @param type of the result object + * @param type of the state to store + * @return the transform + */ public static < EventTypeT, EventKeyTypeT, @@ -129,10 +133,67 @@ public OrderedEventProcessorResult expand( throw new RuntimeException("Unable to get result coder", e); } - PCollectionTuple processingResult = + KvCoder mainOutputCoder = KvCoder.of(keyCoder, resultCoder); + KvCoder processingStatusCoder = + KvCoder.of(keyCoder, getOrderedProcessingStatusCoder(pipeline)); + KvCoder>> unprocessedEventsCoder = + KvCoder.of( + keyCoder, KvCoder.of(VarLongCoder.of(), new UnprocessedEventCoder<>(eventCoder))); + + if (handler instanceof OrderedProcessingGlobalSequenceHandler) { + OrderedProcessingGlobalSequenceHandler + globalSequenceHandler = + (OrderedProcessingGlobalSequenceHandler) handler; + + return expandGlobalSequenceProcessing( + input, + mainOutput, + statusOutput, + unprocessedEventOutput, + handler, + pipeline, + keyCoder, + eventCoder, + stateCoder, + mainOutputCoder, + processingStatusCoder, + unprocessedEventsCoder, + globalSequenceHandler); + } else { + return expandPerKeyProcessing( + input, + mainOutput, + statusOutput, + unprocessedEventOutput, + handler, + pipeline, + keyCoder, + eventCoder, + stateCoder, + mainOutputCoder, + processingStatusCoder, + unprocessedEventsCoder); + } + } + + 
private OrderedEventProcessorResult expandPerKeyProcessing( + PCollection>> input, + TupleTag> mainOutput, + TupleTag> statusOutput, + TupleTag>>> unprocessedEventOutput, + OrderedProcessingHandler handler, + Pipeline pipeline, + Coder keyCoder, + Coder eventCoder, + Coder stateCoder, + KvCoder mainOutputCoder, + KvCoder processingStatusCoder, + KvCoder>> unprocessedEventsCoder) { + PCollectionTuple processingResult; + processingResult = input.apply( ParDo.of( - new OrderedProcessorDoFn<>( + new SequencePerKeyProcessorDoFn<>( handler.getEventExaminer(), eventCoder, stateCoder, @@ -146,13 +207,6 @@ public OrderedEventProcessorResult expand( .withOutputTags( mainOutput, TupleTagList.of(Arrays.asList(statusOutput, unprocessedEventOutput)))); - - KvCoder mainOutputCoder = KvCoder.of(keyCoder, resultCoder); - KvCoder processingStatusCoder = - KvCoder.of(keyCoder, getOrderedProcessingStatusCoder(pipeline)); - KvCoder>> unprocessedEventsCoder = - KvCoder.of( - keyCoder, KvCoder.of(VarLongCoder.of(), new UnprocessedEventCoder<>(eventCoder))); return new OrderedEventProcessorResult<>( pipeline, processingResult.get(mainOutput).setCoder(mainOutputCoder), @@ -163,6 +217,84 @@ public OrderedEventProcessorResult expand( unprocessedEventOutput); } + private OrderedEventProcessorResult expandGlobalSequenceProcessing( + PCollection>> input, + TupleTag> mainOutput, + TupleTag> statusOutput, + TupleTag>>> unprocessedEventOutput, + OrderedProcessingHandler handler, + Pipeline pipeline, + Coder keyCoder, + Coder eventCoder, + Coder stateCoder, + KvCoder mainOutputCoder, + KvCoder processingStatusCoder, + KvCoder>> unprocessedEventsCoder, + OrderedProcessingGlobalSequenceHandler + globalSequenceHandler) { + PCollectionTuple processingResult; + boolean streamingProcessing = input.isBounded() == IsBounded.UNBOUNDED; + + final PCollectionView latestContiguousRange = + input + .apply("Convert to SequenceAndTimestamp", ParDo.of(new ToTimestampedEventConverter<>())) + .apply( + "Global 
Sequence Tracker", + streamingProcessing + ? new GlobalSequenceTracker<>( + globalSequenceHandler.getGlobalSequenceCombiner(), + globalSequenceHandler.getContiguousSequenceRangeReevaluationFrequency(), + globalSequenceHandler + .getMaxElementCountToTriggerContinuousSequenceRangeReevaluation()) + : new GlobalSequenceTracker<>( + globalSequenceHandler.getGlobalSequenceCombiner())); + + if (streamingProcessing) { + PCollection>> tickers = + input.apply( + "Create Tickers", + new PerKeyTickerGenerator<>( + keyCoder, + eventCoder, + globalSequenceHandler.getContiguousSequenceRangeReevaluationFrequency())); + + input = + PCollectionList.of(input) + .and(tickers) + .apply("Combine Events and Tickers", Flatten.pCollections()) + .setCoder(tickers.getCoder()); + } + processingResult = + input.apply( + ParDo.of( + new GlobalSequencesProcessorDoFn<>( + handler.getEventExaminer(), + eventCoder, + stateCoder, + keyCoder, + mainOutput, + statusOutput, + handler.getStatusUpdateFrequency(), + unprocessedEventOutput, + handler.isProduceStatusUpdateOnEveryEvent(), + handler.getMaxOutputElementsPerBundle(), + latestContiguousRange, + input.getWindowingStrategy().getAllowedLateness())) + .withOutputTags( + mainOutput, + TupleTagList.of(Arrays.asList(statusOutput, unprocessedEventOutput))) + .withSideInput(GLOBAL_SEQUENCE_TRACKER, latestContiguousRange)); + return new OrderedEventProcessorResult<>( + pipeline, + processingResult.get(mainOutput).setCoder(mainOutputCoder), + mainOutput, + processingResult.get(statusOutput).setCoder(processingStatusCoder), + statusOutput, + processingResult.get(unprocessedEventOutput).setCoder(unprocessedEventsCoder), + unprocessedEventOutput, + latestContiguousRange); + } + private static Coder getOrderedProcessingStatusCoder(Pipeline pipeline) { SchemaRegistry schemaRegistry = pipeline.getSchemaRegistry(); Coder result; @@ -179,497 +311,16 @@ private static Coder getOrderedProcessingStatusCoder(Pi return result; } - /** - * Main DoFn for processing 
ordered events. - * - * @param - * @param - * @param - */ - static class OrderedProcessorDoFn< - EventTypeT, - EventKeyTypeT, - ResultTypeT, - StateTypeT extends MutableState> - extends DoFn>, KV> { - - private static final Logger LOG = LoggerFactory.getLogger(OrderedProcessorDoFn.class); - - private static final String PROCESSING_STATE = "processingState"; - private static final String MUTABLE_STATE = "mutableState"; - private static final String BUFFERED_EVENTS = "bufferedEvents"; - private static final String STATUS_EMISSION_TIMER = "statusTimer"; - private static final String LARGE_BATCH_EMISSION_TIMER = "largeBatchTimer"; - private static final String WINDOW_CLOSED = "windowClosed"; - private final EventExaminer eventExaminer; - - @StateId(BUFFERED_EVENTS) - @SuppressWarnings("unused") - private final StateSpec> bufferedEventsSpec; - - @StateId(PROCESSING_STATE) - @SuppressWarnings("unused") - private final StateSpec>> processingStateSpec; - - @SuppressWarnings("unused") - @StateId(MUTABLE_STATE) - private final StateSpec> mutableStateSpec; - - @StateId(WINDOW_CLOSED) - @SuppressWarnings("unused") - private final StateSpec> windowClosedSpec; - - @TimerId(STATUS_EMISSION_TIMER) - @SuppressWarnings("unused") - private final TimerSpec statusEmissionTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); - - @TimerId(LARGE_BATCH_EMISSION_TIMER) - @SuppressWarnings("unused") - private final TimerSpec largeBatchEmissionTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME); - - private final TupleTag> statusTupleTag; - private final Duration statusUpdateFrequency; - - private final TupleTag> mainOutputTupleTag; - private final TupleTag>>> - unprocessedEventsTupleTag; - private final boolean produceStatusUpdateOnEveryEvent; - - private final long maxNumberOfResultsToProduce; - - private Long numberOfResultsBeforeBundleStart; - - /** - * Stateful DoFn to do the bulk of processing. 
- * - * @param eventExaminer - * @param eventCoder - * @param stateCoder - * @param keyCoder - * @param mainOutputTupleTag - * @param statusTupleTag - * @param statusUpdateFrequency - * @param unprocessedEventTupleTag - * @param produceStatusUpdateOnEveryEvent - * @param maxNumberOfResultsToProduce - */ - OrderedProcessorDoFn( - EventExaminer eventExaminer, - Coder eventCoder, - Coder stateCoder, - Coder keyCoder, - TupleTag> mainOutputTupleTag, - TupleTag> statusTupleTag, - Duration statusUpdateFrequency, - TupleTag>>> - unprocessedEventTupleTag, - boolean produceStatusUpdateOnEveryEvent, - long maxNumberOfResultsToProduce) { - this.eventExaminer = eventExaminer; - this.bufferedEventsSpec = StateSpecs.orderedList(eventCoder); - this.mutableStateSpec = StateSpecs.value(stateCoder); - this.processingStateSpec = StateSpecs.value(ProcessingStateCoder.of(keyCoder)); - this.windowClosedSpec = StateSpecs.value(BooleanCoder.of()); - this.mainOutputTupleTag = mainOutputTupleTag; - this.statusTupleTag = statusTupleTag; - this.unprocessedEventsTupleTag = unprocessedEventTupleTag; - this.statusUpdateFrequency = statusUpdateFrequency; - this.produceStatusUpdateOnEveryEvent = produceStatusUpdateOnEveryEvent; - this.maxNumberOfResultsToProduce = maxNumberOfResultsToProduce; - } - - @StartBundle - public void onBundleStart() { - numberOfResultsBeforeBundleStart = null; - } - - @FinishBundle - public void onBundleFinish() { - // This might be necessary because this field is also used in a Timer - numberOfResultsBeforeBundleStart = null; - } + static class ToTimestampedEventConverter + extends DoFn< + KV>, TimestampedValue>>> { @ProcessElement - public void processElement( - @StateId(BUFFERED_EVENTS) OrderedListState bufferedEventsState, - @AlwaysFetched @StateId(PROCESSING_STATE) - ValueState> processingStateState, - @StateId(MUTABLE_STATE) ValueState mutableStateState, - @TimerId(STATUS_EMISSION_TIMER) Timer statusEmissionTimer, - @TimerId(LARGE_BATCH_EMISSION_TIMER) Timer 
largeBatchEmissionTimer, - @Element KV> eventAndSequence, - MultiOutputReceiver outputReceiver, - BoundedWindow window) { - - EventKeyTypeT key = eventAndSequence.getKey(); - long sequence = eventAndSequence.getValue().getKey(); - EventTypeT event = eventAndSequence.getValue().getValue(); - - ProcessingState processingState = processingStateState.read(); - - if (processingState == null) { - // This is the first time we see this key/window pair - processingState = new ProcessingState<>(key); - if (statusUpdateFrequency != null) { - // Set up the timer to produce the status of the processing on a regular basis - statusEmissionTimer.offset(statusUpdateFrequency).setRelative(); - } - } - - if (numberOfResultsBeforeBundleStart == null) { - // Per key processing is synchronized by Beam. There is no need to have it here. - numberOfResultsBeforeBundleStart = processingState.getResultCount(); - } - - processingState.eventReceived(); - - StateTypeT state = - processNewEvent( - sequence, - event, - processingState, - mutableStateState, - bufferedEventsState, - outputReceiver); - - processBufferedEvents( - processingState, state, bufferedEventsState, outputReceiver, largeBatchEmissionTimer); - - saveStates( - processingStateState, - processingState, - mutableStateState, - state, - outputReceiver, - window.maxTimestamp()); - - checkIfProcessingIsCompleted(processingState); - } - - private boolean checkIfProcessingIsCompleted(ProcessingState processingState) { - boolean result = processingState.isProcessingCompleted(); - if (result) { - LOG.info("Processing for key '" + processingState.getKey() + "' is completed."); - } - return result; - } - - private void saveStates( - ValueState> processingStatusState, - ProcessingState processingStatus, - ValueState currentStateState, - StateTypeT state, - MultiOutputReceiver outputReceiver, - Instant windowTimestamp) { - // There is always a change to the processing status - processingStatusState.write(processingStatus); - - // Stored state 
may not have changes if the element was out of sequence. - if (state != null) { - currentStateState.write(state); - } - - if (produceStatusUpdateOnEveryEvent) { - // During pipeline draining the window timestamp is set to a large value in the future. - // Producing an event before that results in error, that's why this logic exist. - Instant statusTimestamp = windowTimestamp; - - emitProcessingStatus(processingStatus, outputReceiver, statusTimestamp); - } - } - - private void emitProcessingStatus( - ProcessingState processingState, - MultiOutputReceiver outputReceiver, - Instant statusTimestamp) { - outputReceiver - .get(statusTupleTag) - .outputWithTimestamp( - KV.of( - processingState.getKey(), - OrderedProcessingStatus.create( - processingState.getLastOutputSequence(), - processingState.getBufferedEventCount(), - processingState.getEarliestBufferedSequence(), - processingState.getLatestBufferedSequence(), - processingState.getEventsReceived(), - processingState.getResultCount(), - processingState.getDuplicates(), - processingState.isLastEventReceived())), - statusTimestamp); - } - - /** - * Process the just received event. - * - * @return newly created or updated State. If null is returned - the event wasn't processed. - */ - private StateTypeT processNewEvent( - long currentSequence, - EventTypeT currentEvent, - ProcessingState processingState, - ValueState currentStateState, - OrderedListState bufferedEventsState, - MultiOutputReceiver outputReceiver) { - if (currentSequence == Long.MAX_VALUE) { - // OrderedListState can't handle the timestamp based on MAX_VALUE. - // To avoid exceptions, we DLQ this event. 
- outputReceiver - .get(unprocessedEventsTupleTag) - .output( - KV.of( - processingState.getKey(), - KV.of( - currentSequence, - UnprocessedEvent.create( - currentEvent, Reason.sequence_id_outside_valid_range)))); - return null; - } - - if (processingState.hasAlreadyBeenProcessed(currentSequence)) { - outputReceiver - .get(unprocessedEventsTupleTag) - .output( - KV.of( - processingState.getKey(), - KV.of( - currentSequence, UnprocessedEvent.create(currentEvent, Reason.duplicate)))); - return null; - } - - StateTypeT state; - boolean thisIsTheLastEvent = eventExaminer.isLastEvent(currentSequence, currentEvent); - if (eventExaminer.isInitialEvent(currentSequence, currentEvent)) { - // First event of the key/window - // What if it's a duplicate event - it will reset everything. Shall we drop/DLQ anything - // that's before the processingState.lastOutputSequence? - state = eventExaminer.createStateOnInitialEvent(currentEvent); - - processingState.eventAccepted(currentSequence, thisIsTheLastEvent); - - ResultTypeT result = state.produceResult(); - if (result != null) { - outputReceiver.get(mainOutputTupleTag).output(KV.of(processingState.getKey(), result)); - processingState.resultProduced(); - } - - // Nothing else to do. We will attempt to process buffered events later. 
- return state; - } - - if (processingState.isNextEvent(currentSequence)) { - // Event matches expected sequence - state = currentStateState.read(); - - try { - state.mutate(currentEvent); - } catch (Exception e) { - outputReceiver - .get(unprocessedEventsTupleTag) - .output( - KV.of( - processingState.getKey(), - KV.of(currentSequence, UnprocessedEvent.create(currentEvent, e)))); - return null; - } - - ResultTypeT result = state.produceResult(); - if (result != null) { - outputReceiver.get(mainOutputTupleTag).output(KV.of(processingState.getKey(), result)); - processingState.resultProduced(); - } - processingState.eventAccepted(currentSequence, thisIsTheLastEvent); - - return state; - } - - // Event is not ready to be processed yet - Instant eventTimestamp = Instant.ofEpochMilli(currentSequence); - bufferedEventsState.add(TimestampedValue.of(currentEvent, eventTimestamp)); - processingState.eventBuffered(currentSequence, thisIsTheLastEvent); - - // This will signal that the state hasn't been mutated and we don't need to save it. - return null; - } - - /** Process buffered events. 
*/ - private void processBufferedEvents( - ProcessingState processingState, - StateTypeT state, - OrderedListState bufferedEventsState, - MultiOutputReceiver outputReceiver, - Timer largeBatchEmissionTimer) { - if (state == null) { - // Only when the current event caused a state mutation and the state is passed to this - // method should we attempt to process buffered events - return; - } - - if (!processingState.readyToProcessBufferedEvents()) { - return; - } - - if (reachedMaxResultCountForBundle(processingState, largeBatchEmissionTimer)) { - // No point in trying to process buffered events - return; - } - - Instant startRange = Instant.ofEpochMilli(processingState.getEarliestBufferedSequence()); - Instant endRange = Instant.ofEpochMilli(processingState.getLatestBufferedSequence() + 1); - Instant endClearRange = null; - - // readRange is efficiently implemented and will bring records in batches - Iterable> events = - bufferedEventsState.readRange(startRange, endRange); - - Iterator> bufferedEventsIterator = events.iterator(); - while (bufferedEventsIterator.hasNext()) { - TimestampedValue timestampedEvent = bufferedEventsIterator.next(); - Instant eventTimestamp = timestampedEvent.getTimestamp(); - long eventSequence = eventTimestamp.getMillis(); - - EventTypeT bufferedEvent = timestampedEvent.getValue(); - if (processingState.checkForDuplicateBatchedEvent(eventSequence)) { - outputReceiver - .get(unprocessedEventsTupleTag) - .output( - KV.of( - processingState.getKey(), - KV.of( - eventSequence, - UnprocessedEvent.create(bufferedEvent, Reason.duplicate)))); - continue; - } - - if (eventSequence > processingState.getLastOutputSequence() + 1) { - processingState.foundSequenceGap(eventSequence); - // Records will be cleared up to this element - endClearRange = Instant.ofEpochMilli(eventSequence); - break; - } - - // This check needs to be done after we checked for sequence gap and before we - // attempt to process the next element which can result in a new result. 
- if (reachedMaxResultCountForBundle(processingState, largeBatchEmissionTimer)) { - endClearRange = Instant.ofEpochMilli(eventSequence); - break; - } - - try { - state.mutate(bufferedEvent); - } catch (Exception e) { - outputReceiver - .get(unprocessedEventsTupleTag) - .output( - KV.of( - processingState.getKey(), - KV.of(eventSequence, UnprocessedEvent.create(bufferedEvent, e)))); - // There is a chance that the next event will have the same sequence number and will - // process successfully. - continue; - } - - ResultTypeT result = state.produceResult(); - if (result != null) { - outputReceiver.get(mainOutputTupleTag).output(KV.of(processingState.getKey(), result)); - processingState.resultProduced(); - } - processingState.processedBufferedEvent(eventSequence); - // Remove this record also - endClearRange = Instant.ofEpochMilli(eventSequence + 1); - } - - bufferedEventsState.clearRange(startRange, endClearRange); - } - - private boolean reachedMaxResultCountForBundle( - ProcessingState processingState, Timer largeBatchEmissionTimer) { - boolean exceeded = - processingState.resultsProducedInBundle(numberOfResultsBeforeBundleStart) - >= maxNumberOfResultsToProduce; - if (exceeded) { - LOG.info( - "Setting the timer to output next batch of events for key '" - + processingState.getKey() - + "'"); - // See GroupIntoBatches for examples on how to hold the timestamp. - // TODO: test that on draining the pipeline all the results are still produced correctly. 
- // See: https://github.com/apache/beam/issues/30781 - largeBatchEmissionTimer.offset(Duration.millis(1)).setRelative(); - } - return exceeded; - } - - @OnTimer(LARGE_BATCH_EMISSION_TIMER) - public void onBatchEmission( - OnTimerContext context, - @StateId(BUFFERED_EVENTS) OrderedListState bufferedEventsState, - @AlwaysFetched @StateId(PROCESSING_STATE) - ValueState> processingStatusState, - @AlwaysFetched @StateId(MUTABLE_STATE) ValueState currentStateState, - @TimerId(LARGE_BATCH_EMISSION_TIMER) Timer largeBatchEmissionTimer, - MultiOutputReceiver outputReceiver) { - ProcessingState processingState = processingStatusState.read(); - if (processingState == null) { - LOG.warn("Processing state is empty. Ignore it if the pipeline is being cancelled."); - return; - } - StateTypeT state = currentStateState.read(); - if (state == null) { - LOG.warn("Mutable state is empty. Ignore it if the pipeline is being cancelled."); - return; - } - - LOG.debug("Starting to process batch for key '" + processingState.getKey() + "'"); - - this.numberOfResultsBeforeBundleStart = processingState.getResultCount(); - - processBufferedEvents( - processingState, state, bufferedEventsState, outputReceiver, largeBatchEmissionTimer); - - saveStates( - processingStatusState, - processingState, - currentStateState, - state, - outputReceiver, - // TODO: validate that this is correct. - context.window().maxTimestamp()); - - checkIfProcessingIsCompleted(processingState); - } - - @OnTimer(STATUS_EMISSION_TIMER) - @SuppressWarnings("unused") - public void onStatusEmission( - MultiOutputReceiver outputReceiver, - @TimerId(STATUS_EMISSION_TIMER) Timer statusEmissionTimer, - @StateId(WINDOW_CLOSED) ValueState windowClosedState, - @StateId(PROCESSING_STATE) - ValueState> processingStateState) { - - ProcessingState currentState = processingStateState.read(); - if (currentState == null) { - // This could happen if the state has been purged already during the draining. 
- // It means that there is nothing that we can do and we just need to return. - LOG.warn( - "Current processing state is null in onStatusEmission() - most likely the pipeline is shutting down."); - return; - } - - emitProcessingStatus(currentState, outputReceiver, Instant.now()); - - Boolean windowClosed = windowClosedState.read(); - if (!currentState.isProcessingCompleted() - // Stop producing statuses if we are finished for a particular key - && (windowClosed == null || !windowClosed)) { - statusEmissionTimer.offset(statusUpdateFrequency).setRelative(); - } - } - - @OnWindowExpiration - public void onWindowExpiration(@StateId(WINDOW_CLOSED) ValueState windowClosedState) { - windowClosedState.write(true); + public void convert( + @Element KV> element, + @Timestamp Instant timestamp, + OutputReceiver>>> outputReceiver) { + outputReceiver.output(TimestampedValue.of(element, timestamp)); } } } diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorResult.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorResult.java index f61df6254b25..48b9fafc99af 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorResult.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorResult.java @@ -18,10 +18,12 @@ package org.apache.beam.sdk.extensions.ordered; import java.util.Map; +import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; @@ -29,10 +31,15 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; /** - * The result of the ordered processing. Two PCollections are returned: + * The result of the ordered processing. Three PCollections are returned: *

  • output - the key/value of the mutated states + *
  • unprocessedEvents - the key/value of the events that failed to be processed and the failure + * reason *
  • processingStatuses - the key/value of the status of processing for a particular key * + *

    In case of global sequence processing, the result also contains PCollectionView of the + * latest contiguous sequence range + * * @param * @param */ @@ -48,6 +55,8 @@ public class OrderedEventProcessorResult implements POutp unprocessedEventPCollection; private final TupleTag>>> unprocessedEventTupleTag; + private final @Nullable PCollectionView latestContiguousRange; + OrderedEventProcessorResult( Pipeline pipeline, PCollection> outputPCollection, @@ -57,6 +66,27 @@ public class OrderedEventProcessorResult implements POutp PCollection>>> unprocessedEventPCollection, TupleTag>>> unprocessedEventTupleTag) { + this( + pipeline, + outputPCollection, + outputPCollectionTupleTag, + eventProcessingStatusPCollection, + eventProcessingStatusTupleTag, + unprocessedEventPCollection, + unprocessedEventTupleTag, + null); + } + + OrderedEventProcessorResult( + Pipeline pipeline, + PCollection> outputPCollection, + TupleTag> outputPCollectionTupleTag, + PCollection> eventProcessingStatusPCollection, + TupleTag> eventProcessingStatusTupleTag, + PCollection>>> unprocessedEventPCollection, + TupleTag>>> unprocessedEventTupleTag, + @Nullable PCollectionView latestContiguousRange) { + this.pipeline = pipeline; this.outputPCollection = outputPCollection; this.outputPCollectionTupleTag = outputPCollectionTupleTag; @@ -64,6 +94,7 @@ public class OrderedEventProcessorResult implements POutp this.eventProcessingStatusTupleTag = eventProcessingStatusTupleTag; this.unprocessedEventPCollection = unprocessedEventPCollection; this.unprocessedEventTupleTag = unprocessedEventTupleTag; + this.latestContiguousRange = latestContiguousRange; } private final Pipeline pipeline; @@ -104,4 +135,8 @@ public PCollection> output() { public PCollection>>> unprocessedEvents() { return unprocessedEventPCollection; } + + public @Nullable PCollectionView latestContiguousRange() { + return latestContiguousRange; + } } diff --git 
a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingHandler.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingHandler.java index 444fdb118091..d8ad13330a1a 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingHandler.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingHandler.java @@ -22,7 +22,11 @@ import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.ordered.combiner.DefaultSequenceCombiner; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; @@ -30,6 +34,11 @@ /** * Parent class for Ordered Processing configuration handlers. * + *

    There are two types of processing - when the sequence numbers are contiguous per key and the + * per-key sequences are independent of each other, and when there is a global sequence shared by + * all keys. In case of the global sequence processing the custom handler must extend {@link + * OrderedProcessingGlobalSequenceHandler}. + * * @param type of events to be processed * @param type of keys which will be used to group the events * @param type of internal State which will be used for processing @@ -217,4 +226,75 @@ public int getMaxOutputElementsPerBundle() { public void setMaxOutputElementsPerBundle(int maxOutputElementsPerBundle) { this.maxOutputElementsPerBundle = maxOutputElementsPerBundle; } + + /** + * Parent class for Ordered Processing configuration handlers to handle processing of the events + * where global sequence is used. + * + * @param <EventT> type of events to be processed + * @param <KeyT> type of keys which will be used to group the events + * @param <StateT> type of internal State which will be used for processing + * @param <ResultT> type of the result of the processing which will be output + */ + public abstract static class OrderedProcessingGlobalSequenceHandler< + EventT, KeyT, StateT extends MutableState, ResultT> + extends OrderedProcessingHandler { + + public OrderedProcessingGlobalSequenceHandler( + Class eventTClass, + Class keyTClass, + Class stateTClass, + Class resultTClass) { + super(eventTClass, keyTClass, stateTClass, resultTClass); + } + + /** + * Provide the global sequence combiner. Default is to use {@link DefaultSequenceCombiner}. + * + * @return combiner + */ + public GloballyAsSingletonView< + TimestampedValue>>, ContiguousSequenceRange> + getGlobalSequenceCombiner() { + return Combine.globally(new DefaultSequenceCombiner(getEventExaminer())) + .asSingletonView(); + } + + /** + * How frequently the combiner should reevaluate the maximum range? This parameter only affects + * the behaviour of streaming pipelines. + * + *

    This parameter is used together with {@link + * OrderedProcessingGlobalSequenceHandler#getMaxElementCountToTriggerContinuousSequenceRangeReevaluation()}. + * The re-evaluation will occur as soon as the number of new elements exceeds the threshold or + * the time exceeds the frequency. + * + *

    Notice that some runners cache the output of side inputs and this parameter might not + * appear to have an effect unless the cache time-to-live is equal to or less than this frequency. + * For the Dataflow runner, see the + * Dataflow streaming pipeline options documentation. + * + * @return frequency of reevaluating the {@link ContiguousSequenceRange}. Default - every + * second. + * @see + * OrderedProcessingGlobalSequenceHandler#getMaxElementCountToTriggerContinuousSequenceRangeReevaluation() + */ + public Duration getContiguousSequenceRangeReevaluationFrequency() { + return Duration.standardSeconds(1); + } + + /** + * Number of new elements to trigger the re-evaluation. + * + *

    See {@link + * OrderedProcessingGlobalSequenceHandler#getContiguousSequenceRangeReevaluationFrequency()} for + * additional details. + * + * @return batch size. Default - 1000. + */ + public int getMaxElementCountToTriggerContinuousSequenceRangeReevaluation() { + return 1000; + } + } } diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingStatus.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingStatus.java index 6659bd2e2b92..7a556de1017b 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingStatus.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/OrderedProcessingStatus.java @@ -30,16 +30,16 @@ public abstract class OrderedProcessingStatus { public static OrderedProcessingStatus create( - Long lastOutputSequence, + @Nullable Long lastProcessedSequence, long numberOfBufferedEvents, - Long earliestBufferedSequence, - Long latestBufferedSequence, + @Nullable Long earliestBufferedSequence, + @Nullable Long latestBufferedSequence, long numberOfReceivedEvents, long resultCount, long duplicateCount, boolean lastEventReceived) { return new AutoValue_OrderedProcessingStatus.Builder() - .setLastProcessedSequence(lastOutputSequence) + .setLastProcessedSequence(lastProcessedSequence) .setNumberOfBufferedEvents(numberOfBufferedEvents) .setEarliestBufferedSequence(earliestBufferedSequence) .setLatestBufferedSequence(latestBufferedSequence) @@ -55,8 +55,7 @@ public static OrderedProcessingStatus create( * @return Last sequence processed. If null is returned - no elements for the given key and window * have been processed yet. */ - @Nullable - public abstract Long getLastProcessedSequence(); + public abstract @Nullable Long getLastProcessedSequence(); /** @return Number of events received out of sequence and buffered. 
*/ public abstract long getNumberOfBufferedEvents(); @@ -129,13 +128,13 @@ public final int hashCode() { @AutoValue.Builder public abstract static class Builder { - public abstract Builder setLastProcessedSequence(Long value); + public abstract Builder setLastProcessedSequence(@Nullable Long value); public abstract Builder setNumberOfBufferedEvents(long value); - public abstract Builder setEarliestBufferedSequence(Long value); + public abstract Builder setEarliestBufferedSequence(@Nullable Long value); - public abstract Builder setLatestBufferedSequence(Long value); + public abstract Builder setLatestBufferedSequence(@Nullable Long value); public abstract Builder setNumberOfReceivedEvents(long value); diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/PerKeyTickerGenerator.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/PerKeyTickerGenerator.java new file mode 100644 index 000000000000..a18ba53f5266 --- /dev/null +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/PerKeyTickerGenerator.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.ordered; + +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * PTransform to generate per key tickers with certain frequency. 
+ * + * @param + * @param + */ +class PerKeyTickerGenerator + extends PTransform< + PCollection>>, + PCollection>>> { + + private static final Logger LOG = LoggerFactory.getLogger(PerKeyTickerGenerator.class); + + private final Coder eventKeyCoder; + private final Coder eventCoder; + private final Duration tickerFrequency; + + PerKeyTickerGenerator( + Coder eventKeyCoder, Coder eventCoder, Duration tickerFrequency) { + this.eventKeyCoder = eventKeyCoder; + this.eventCoder = eventCoder; + this.tickerFrequency = tickerFrequency; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized PCollection>> expand( + PCollection>> input) { + return input + .apply( + "Generate Tickers", + ParDo.of(new PerKeyTickerGeneratorDoFn<>(eventKeyCoder, tickerFrequency))) + .setCoder( + KvCoder.of(eventKeyCoder, KvCoder.of(VarLongCoder.of(), NullableCoder.of(eventCoder)))); + } + + static class PerKeyTickerGeneratorDoFn + extends DoFn>, KV>> { + + private static final String STATE = "state"; + private static final String TIMER = "timer"; + + @StateId(STATE) + @SuppressWarnings("unused") + private final StateSpec> stateSpec; + + @TimerId(TIMER) + @SuppressWarnings("unused") + private final TimerSpec tickerTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final Duration tickerFrequency; + + PerKeyTickerGeneratorDoFn(Coder keyCoder, Duration tickerFrequency) { + stateSpec = StateSpecs.value(keyCoder); + this.tickerFrequency = tickerFrequency; + } + + @ProcessElement + public void process( + @Element KV> element, + @AlwaysFetched @StateId(STATE) ValueState state, + @TimerId(TIMER) Timer tickerTimer) { + @Nullable EventKeyT keyValue = state.read(); + if (keyValue != null) { + return; + } + + tickerTimer.offset(tickerFrequency).setRelative(); + + state.write(element.getKey()); + } + + @OnTimer(TIMER) + public void onTimer( + @StateId(STATE) ValueState state, + @TimerId(TIMER) Timer tickerTimer, + OutputReceiver>> outputReceiver) { + + @Nullable EventKeyT key = 
state.read(); + if (key == null) { + LOG.error("Expected to get the key from the state, but got null"); + return; + } + + // Null value will be an indicator to the main transform that the element is a ticker + outputReceiver.output(KV.of(key, KV.of(0L, null))); + tickerTimer.offset(tickerFrequency).setRelative(); + } + } +} diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ProcessingState.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ProcessingState.java index 4b591a37faab..425eb4444a63 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ProcessingState.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ProcessingState.java @@ -51,6 +51,8 @@ class ProcessingState { private long resultCount; + @Nullable private ContiguousSequenceRange lastCompleteGlobalSequence; + private KeyT key; public ProcessingState(KeyT key) { @@ -59,6 +61,7 @@ public ProcessingState(KeyT key) { this.lastOutputSequence = null; this.earliestBufferedSequence = null; this.latestBufferedSequence = null; + this.lastCompleteGlobalSequence = null; } /** @@ -130,6 +133,15 @@ public KeyT getKey() { return key; } + public @Nullable ContiguousSequenceRange getLastContiguousRange() { + return lastCompleteGlobalSequence; + } + + public void setLastCompleteGlobalSequence( + @Nullable ContiguousSequenceRange lastCompleteGlobalSequence) { + this.lastCompleteGlobalSequence = lastCompleteGlobalSequence; + } + /** * Current event matched the sequence and was processed. 
* @@ -229,6 +241,32 @@ public int hashCode() { key); } + @Override + public String toString() { + return "ProcessingState{" + + "lastOutputSequence=" + + lastOutputSequence + + ", latestBufferedSequence=" + + latestBufferedSequence + + ", earliestBufferedSequence=" + + earliestBufferedSequence + + ", bufferedEventCount=" + + bufferedEventCount + + ", lastEventReceived=" + + lastEventReceived + + ", eventsReceived=" + + eventsReceived + + ", duplicates=" + + duplicates + + ", resultCount=" + + resultCount + + ", lastCompleteGlobalSequence=" + + lastCompleteGlobalSequence + + ", key=" + + key + + '}'; + } + public boolean isProcessingCompleted() { return lastEventReceived && bufferedEventCount == 0; } @@ -274,6 +312,23 @@ public long resultsProducedInBundle(long numberOfResultsBeforeBundleStart) { return resultCount - numberOfResultsBeforeBundleStart; } + public void updateGlobalSequenceDetails(ContiguousSequenceRange updated) { + if (thereAreGloballySequencedEventsToBeProcessed()) { + // We don't update the timer if we can already process events in the onTimer batch. + // Otherwise, it's possible that we will be pushing the timer to later timestamps + // without a chance to run and produce output. + return; + } + this.lastCompleteGlobalSequence = updated; + } + + public boolean thereAreGloballySequencedEventsToBeProcessed() { + return bufferedEventCount > 0 + && lastCompleteGlobalSequence != null + && earliestBufferedSequence != null + && earliestBufferedSequence < lastCompleteGlobalSequence.getEnd(); + } + /** * Coder for the processing status. 
* @@ -287,6 +342,9 @@ static class ProcessingStateCoder extends Coder> { private static final VarIntCoder INTEGER_CODER = VarIntCoder.of(); private static final BooleanCoder BOOLEAN_CODER = BooleanCoder.of(); + private static final NullableCoder SEQUENCE_AND_TIMESTAMP_CODER = + NullableCoder.of(ContiguousSequenceRange.CompletedSequenceRangeCoder.of()); + private Coder keyCoder; private ProcessingStateCoder(Coder keyCoder) { @@ -308,6 +366,7 @@ public void encode(ProcessingState value, OutputStream outStream) throws I LONG_CODER.encode(value.getResultCount(), outStream); BOOLEAN_CODER.encode(value.isLastEventReceived(), outStream); keyCoder.encode(value.getKey(), outStream); + SEQUENCE_AND_TIMESTAMP_CODER.encode(value.getLastContiguousRange(), outStream); } @Override @@ -321,17 +380,23 @@ public ProcessingState decode(InputStream inStream) throws IOException { long resultCount = LONG_CODER.decode(inStream); boolean isLastEventReceived = BOOLEAN_CODER.decode(inStream); KeyT key = keyCoder.decode(inStream); - - return new ProcessingState<>( - key, - lastOutputSequence, - earliestBufferedSequence, - latestBufferedSequence, - bufferedRecordCount, - recordsReceivedCount, - duplicates, - resultCount, - isLastEventReceived); + ContiguousSequenceRange lastCompleteGlobalSequence = + SEQUENCE_AND_TIMESTAMP_CODER.decode(inStream); + + ProcessingState result = + new ProcessingState<>( + key, + lastOutputSequence, + earliestBufferedSequence, + latestBufferedSequence, + bufferedRecordCount, + recordsReceivedCount, + duplicates, + resultCount, + isLastEventReceived); + result.setLastCompleteGlobalSequence(lastCompleteGlobalSequence); + + return result; } @Override diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ProcessorDoFn.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ProcessorDoFn.java new file mode 100644 index 000000000000..a05b0829074a --- /dev/null +++ 
b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/ProcessorDoFn.java @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.ordered; + +import java.util.Iterator; +import javax.annotation.Nullable; +import org.apache.beam.sdk.extensions.ordered.UnprocessedEvent.Reason; +import org.apache.beam.sdk.state.OrderedListState; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base DoFn for processing ordered events. 
+ * + * @param type of the events to process + * @param event key type + * @param state type + */ +abstract class ProcessorDoFn< + EventT, EventKeyT, ResultT, StateT extends MutableState> + extends DoFn>, KV> { + + private static final Logger LOG = LoggerFactory.getLogger(ProcessorDoFn.class); + + protected static final String PROCESSING_STATE = "processingState"; + protected static final String MUTABLE_STATE = "mutableState"; + + protected static final String STATUS_EMISSION_TIMER = "statusTimer"; + protected static final String WINDOW_CLOSED = "windowClosed"; + protected final EventExaminer eventExaminer; + + private final TupleTag> statusTupleTag; + protected final Duration statusUpdateFrequency; + + protected final TupleTag> mainOutputTupleTag; + protected final TupleTag>>> + unprocessedEventsTupleTag; + private final boolean produceStatusUpdateOnEveryEvent; + + private final long maxNumberOfResultsToProduce; + + protected @Nullable Long numberOfResultsBeforeBundleStart = 0L; + + ProcessorDoFn( + EventExaminer eventExaminer, + TupleTag> mainOutputTupleTag, + TupleTag> statusTupleTag, + Duration statusUpdateFrequency, + TupleTag>>> unprocessedEventTupleTag, + boolean produceStatusUpdateOnEveryEvent, + long maxNumberOfResultsToProduce) { + this.eventExaminer = eventExaminer; + + this.mainOutputTupleTag = mainOutputTupleTag; + this.statusTupleTag = statusTupleTag; + this.unprocessedEventsTupleTag = unprocessedEventTupleTag; + this.statusUpdateFrequency = statusUpdateFrequency; + this.produceStatusUpdateOnEveryEvent = produceStatusUpdateOnEveryEvent; + this.maxNumberOfResultsToProduce = maxNumberOfResultsToProduce; + } + + @StartBundle + public void onBundleStart() { + numberOfResultsBeforeBundleStart = null; + } + + @FinishBundle + public void onBundleFinish() { + // This might be necessary because this field is also used in a Timer + numberOfResultsBeforeBundleStart = null; + } + + /** @return true if each event needs to be examined. 
*/ + abstract boolean checkForFirstOrLastEvent(); + + /** + * Process the just received event. + * + * @return newly created or updated State. If null is returned - the event wasn't processed. + */ + protected @javax.annotation.Nullable StateT processNewEvent( + long currentSequence, + EventT currentEvent, + ProcessingState processingState, + ValueState currentStateState, + OrderedListState bufferedEventsState, + MultiOutputReceiver outputReceiver) { + if (currentSequence == Long.MAX_VALUE) { + // OrderedListState can't handle the timestamp based on MAX_VALUE. + // To avoid exceptions, we DLQ this event. + outputReceiver + .get(unprocessedEventsTupleTag) + .output( + KV.of( + processingState.getKey(), + KV.of( + currentSequence, + UnprocessedEvent.create( + currentEvent, Reason.sequence_id_outside_valid_range)))); + return null; + } + + if (processingState.hasAlreadyBeenProcessed(currentSequence)) { + outputReceiver + .get(unprocessedEventsTupleTag) + .output( + KV.of( + processingState.getKey(), + KV.of(currentSequence, UnprocessedEvent.create(currentEvent, Reason.duplicate)))); + return null; + } + + StateT state; + boolean thisIsTheLastEvent = + checkForFirstOrLastEvent() && eventExaminer.isLastEvent(currentSequence, currentEvent); + if (checkForFirstOrLastEvent() && eventExaminer.isInitialEvent(currentSequence, currentEvent)) { + // First event of the key/window + // What if it's a duplicate event - it will reset everything. Shall we drop/DLQ anything + // that's before the processingState.lastOutputSequence? + state = eventExaminer.createStateOnInitialEvent(currentEvent); + + processingState.eventAccepted(currentSequence, thisIsTheLastEvent); + + ResultT result = state.produceResult(); + if (result != null) { + outputReceiver.get(mainOutputTupleTag).output(KV.of(processingState.getKey(), result)); + processingState.resultProduced(); + } + + // Nothing else to do. We will attempt to process buffered events later. 
+ return state; + } + + if (processingState.isNextEvent(currentSequence)) { + // Event matches expected sequence + state = currentStateState.read(); + if (state == null) { + LOG.warn("Unexpectedly got an empty state. Most likely cause is pipeline drainage."); + return null; + } + + try { + state.mutate(currentEvent); + } catch (Exception e) { + outputReceiver + .get(unprocessedEventsTupleTag) + .output( + KV.of( + processingState.getKey(), + KV.of(currentSequence, UnprocessedEvent.create(currentEvent, e)))); + return null; + } + + ResultT result = state.produceResult(); + if (result != null) { + outputReceiver.get(mainOutputTupleTag).output(KV.of(processingState.getKey(), result)); + processingState.resultProduced(); + } + processingState.eventAccepted(currentSequence, thisIsTheLastEvent); + + return state; + } + + // Event is not ready to be processed yet + bufferEvent( + currentSequence, currentEvent, processingState, bufferedEventsState, thisIsTheLastEvent); + + // This will signal that the state hasn't been mutated. We don't need to save it. + return null; + } + + protected void saveStates( + ValueState> processingStatusState, + ProcessingState processingStatus, + ValueState currentStateState, + @Nullable StateT state, + MultiOutputReceiver outputReceiver, + Instant windowTimestamp) { + // There is always a change to the processing status + processingStatusState.write(processingStatus); + + // Stored state may not have changes if the element was out of sequence. + if (state != null) { + currentStateState.write(state); + } + + if (produceStatusUpdateOnEveryEvent) { + // During pipeline draining the window timestamp is set to a large value in the future. + // Producing an event before that results in an error, which is why this logic exists.
+ Instant statusTimestamp = windowTimestamp; + + emitProcessingStatus(processingStatus, outputReceiver, statusTimestamp); + } + } + + void processStatusTimerEvent( + MultiOutputReceiver outputReceiver, + Timer statusEmissionTimer, + ValueState windowClosedState, + ValueState> processingStateState) { + ProcessingState currentState = processingStateState.read(); + if (currentState == null) { + // This could happen if the state has been purged already during the draining. + // It means that there is nothing that we can do. + LOG.warn( + "Current processing state is null in onStatusEmission() - most likely the pipeline is shutting down."); + return; + } + + emitProcessingStatus(currentState, outputReceiver, Instant.now()); + + Boolean windowClosed = windowClosedState.read(); + if (!currentState.isProcessingCompleted() + // Stop producing statuses if we are finished for a particular key + && (windowClosed == null || !windowClosed)) { + statusEmissionTimer.offset(statusUpdateFrequency).setRelative(); + } + } + + protected void emitProcessingStatus( + ProcessingState processingState, + MultiOutputReceiver outputReceiver, + Instant statusTimestamp) { + if (LOG.isTraceEnabled()) { + LOG.trace("Emitting status for: " + processingState.getKey() + ", " + processingState); + } + outputReceiver + .get(statusTupleTag) + .outputWithTimestamp( + KV.of( + processingState.getKey(), + OrderedProcessingStatus.create( + processingState.getLastOutputSequence(), + processingState.getBufferedEventCount(), + processingState.getEarliestBufferedSequence(), + processingState.getLatestBufferedSequence(), + processingState.getEventsReceived(), + processingState.getResultCount(), + processingState.getDuplicates(), + processingState.isLastEventReceived())), + statusTimestamp); + } + + protected boolean reachedMaxResultCountForBundle( + ProcessingState processingState, Timer largeBatchEmissionTimer) { + boolean exceeded = + processingState.resultsProducedInBundle( + numberOfResultsBeforeBundleStart 
== null ? 0 : numberOfResultsBeforeBundleStart) + >= maxNumberOfResultsToProduce; + if (exceeded) { + if (LOG.isTraceEnabled()) { + LOG.trace( + "Setting the timer to output next batch of events for key '" + + processingState.getKey() + + "'"); + } + // See GroupIntoBatches for examples on how to hold the timestamp. + // TODO: test that on draining the pipeline all the results are still produced correctly. + // See: https://github.com/apache/beam/issues/30781 + largeBatchEmissionTimer.offset(Duration.millis(1)).setRelative(); + } + return exceeded; + } + + private void bufferEvent( + long currentSequence, + EventT currentEvent, + ProcessingState processingState, + OrderedListState bufferedEventsState, + boolean thisIsTheLastEvent) { + Instant eventTimestamp = fromLong(currentSequence); + bufferedEventsState.add(TimestampedValue.of(currentEvent, eventTimestamp)); + processingState.eventBuffered(currentSequence, thisIsTheLastEvent); + } + + abstract boolean checkForSequenceGapInBufferedEvents(); + + @Nullable + StateT processBufferedEventRange( + ProcessingState processingState, + @Nullable StateT state, + OrderedListState bufferedEventsState, + MultiOutputReceiver outputReceiver, + Timer largeBatchEmissionTimer, + ContiguousSequenceRange contiguousSequenceRange) { + Long earliestBufferedSequence = processingState.getEarliestBufferedSequence(); + Long latestBufferedSequence = processingState.getLatestBufferedSequence(); + if (earliestBufferedSequence == null || latestBufferedSequence == null) { + return state; + } + Instant startRange = fromLong(earliestBufferedSequence); + Instant endRange = fromLong(latestBufferedSequence + 1); + + // readRange is efficiently implemented and will bring records in batches + Iterable> events = bufferedEventsState.readRange(startRange, endRange); + + Instant endClearRange = startRange; // it will get re-adjusted later. 
+ + Iterator> bufferedEventsIterator = events.iterator(); + while (bufferedEventsIterator.hasNext()) { + TimestampedValue timestampedEvent = bufferedEventsIterator.next(); + Instant eventTimestamp = timestampedEvent.getTimestamp(); + long eventSequence = eventTimestamp.getMillis(); + + EventT bufferedEvent = timestampedEvent.getValue(); + boolean skipProcessing = false; + boolean beforeInitialSequence = false; + + if (contiguousSequenceRange != null && eventSequence < contiguousSequenceRange.getStart()) { + // In case of global sequence processing - remove the elements below the range start + skipProcessing = true; + beforeInitialSequence = true; + endClearRange = fromLong(eventSequence); + } + if (processingState.checkForDuplicateBatchedEvent(eventSequence)) { + // There could be multiple events under the same sequence number. Only the first one + // will get processed. The rest are considered duplicates. + skipProcessing = true; + } + + if (skipProcessing) { + outputReceiver + .get(unprocessedEventsTupleTag) + .output( + KV.of( + processingState.getKey(), + KV.of( + eventSequence, + UnprocessedEvent.create( + bufferedEvent, + beforeInitialSequence + ? Reason.before_initial_sequence + : Reason.duplicate)))); + // TODO: When there is a large number of duplicates this can cause a situation where + // we produce too much output and the runner will start throwing unrecoverable errors. + // Need to add counting logic to accumulate both the normal and DLQ outputs. + continue; + } + + Long lastOutputSequence = processingState.getLastOutputSequence(); + boolean currentEventIsNextInSequence = + lastOutputSequence != null && eventSequence == lastOutputSequence + 1; + boolean continueProcessing = + checkForSequenceGapInBufferedEvents() + ? 
currentEventIsNextInSequence + : (eventSequence < contiguousSequenceRange.getEnd() || currentEventIsNextInSequence); + if (!continueProcessing) { + processingState.foundSequenceGap(eventSequence); + // Records will be cleared up to this element + endClearRange = fromLong(eventSequence); + break; + } + + // This check needs to be done after we checked for sequence gap and before we + // attempt to process the next element which can result in a new result. + if (reachedMaxResultCountForBundle(processingState, largeBatchEmissionTimer)) { + endClearRange = fromLong(eventSequence); + break; + } + + // Remove this record also + endClearRange = fromLong(eventSequence + 1); + + try { + if (state == null) { + if (LOG.isTraceEnabled()) { + LOG.trace("Creating a new state: " + processingState.getKey() + " " + bufferedEvent); + } + state = eventExaminer.createStateOnInitialEvent(bufferedEvent); + } else { + if (LOG.isTraceEnabled()) { + LOG.trace("Mutating " + processingState.getKey() + " " + bufferedEvent); + } + state.mutate(bufferedEvent); + } + } catch (Exception e) { + outputReceiver + .get(unprocessedEventsTupleTag) + .output( + KV.of( + processingState.getKey(), + KV.of(eventSequence, UnprocessedEvent.create(bufferedEvent, e)))); + // There is a chance that the next event will have the same sequence number and will + // process successfully. 
+ continue; + } + + ResultT result = state.produceResult(); + if (result != null) { + outputReceiver.get(mainOutputTupleTag).output(KV.of(processingState.getKey(), result)); + processingState.resultProduced(); + } + processingState.processedBufferedEvent(eventSequence); + } + + bufferedEventsState.clearRange(startRange, endClearRange); + + return state; + } + + static Instant fromLong(long value) { + return Instant.ofEpochMilli(value); + } +} diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/SequencePerKeyProcessorDoFn.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/SequencePerKeyProcessorDoFn.java new file mode 100644 index 000000000000..878a0664ac87 --- /dev/null +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/SequencePerKeyProcessorDoFn.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.ordered; + +import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.ordered.ProcessingState.ProcessingStateCoder; +import org.apache.beam.sdk.state.OrderedListState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Stateful DoFn to process per key sequences. + * + * @param event type + * @param event key type + * @param result type + * @param state type + */ +class SequencePerKeyProcessorDoFn< + EventTypeT, + EventKeyTypeT, + ResultTypeT, + StateTypeT extends MutableState> + extends ProcessorDoFn { + + private static final Logger LOG = LoggerFactory.getLogger(SequencePerKeyProcessorDoFn.class); + + private static final String LARGE_BATCH_EMISSION_TIMER = "largeBatchTimer"; + protected static final String BUFFERED_EVENTS = "bufferedEvents"; + + @TimerId(LARGE_BATCH_EMISSION_TIMER) + @SuppressWarnings("unused") + private final TimerSpec largeBatchEmissionTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @StateId(BUFFERED_EVENTS) + @SuppressWarnings("unused") + private final StateSpec> bufferedEventsSpec; + + @SuppressWarnings("unused") + @StateId(MUTABLE_STATE) + private final StateSpec> mutableStateSpec; + + @StateId(WINDOW_CLOSED) + @SuppressWarnings("unused") + private final StateSpec> windowClosedSpec; + + @TimerId(STATUS_EMISSION_TIMER) + @SuppressWarnings("unused") + private final TimerSpec 
statusEmissionTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + @StateId(PROCESSING_STATE) + @SuppressWarnings("unused") + private final StateSpec>> processingStateSpec; + + /** + * Stateful DoFn to do the bulk of processing. + * + * @param eventExaminer + * @param eventCoder + * @param stateCoder + * @param keyCoder + * @param mainOutputTupleTag + * @param statusTupleTag + * @param statusUpdateFrequency + * @param unprocessedEventTupleTag + * @param produceStatusUpdateOnEveryEvent + * @param maxNumberOfResultsToProduce + */ + SequencePerKeyProcessorDoFn( + EventExaminer eventExaminer, + Coder eventCoder, + Coder stateCoder, + Coder keyCoder, + TupleTag> mainOutputTupleTag, + TupleTag> statusTupleTag, + Duration statusUpdateFrequency, + TupleTag>>> unprocessedEventTupleTag, + boolean produceStatusUpdateOnEveryEvent, + long maxNumberOfResultsToProduce) { + super( + eventExaminer, + mainOutputTupleTag, + statusTupleTag, + statusUpdateFrequency, + unprocessedEventTupleTag, + produceStatusUpdateOnEveryEvent, + maxNumberOfResultsToProduce); + this.bufferedEventsSpec = StateSpecs.orderedList(eventCoder); + this.processingStateSpec = StateSpecs.value(ProcessingStateCoder.of(keyCoder)); + this.mutableStateSpec = StateSpecs.value(stateCoder); + this.windowClosedSpec = StateSpecs.value(BooleanCoder.of()); + } + + @Override + boolean checkForFirstOrLastEvent() { + return true; + } + + @Override + boolean checkForSequenceGapInBufferedEvents() { + return true; + } + + @ProcessElement + public void processElement( + @StateId(BUFFERED_EVENTS) OrderedListState bufferedEventsState, + @AlwaysFetched @StateId(PROCESSING_STATE) + ValueState> processingStateState, + @StateId(MUTABLE_STATE) ValueState mutableStateState, + @TimerId(STATUS_EMISSION_TIMER) Timer statusEmissionTimer, + @TimerId(LARGE_BATCH_EMISSION_TIMER) Timer largeBatchEmissionTimer, + @Element KV> eventAndSequence, + MultiOutputReceiver outputReceiver, + BoundedWindow window, + ProcessContext context) { + 
EventKeyTypeT key = eventAndSequence.getKey(); + long sequence = eventAndSequence.getValue().getKey(); + EventTypeT event = eventAndSequence.getValue().getValue(); + + ProcessingState processingState = processingStateState.read(); + + if (processingState == null) { + // This is the first time we see this key/window pair + processingState = new ProcessingState<>(key); + if (statusUpdateFrequency != null) { + // Set up the timer to produce the status of the processing on a regular basis + statusEmissionTimer.offset(statusUpdateFrequency).setRelative(); + } + } + + if (numberOfResultsBeforeBundleStart == null) { + // Per key processing is synchronized by Beam. There is no need to have it here. + numberOfResultsBeforeBundleStart = processingState.getResultCount(); + } + + processingState.eventReceived(); + + StateTypeT state = + processNewEvent( + sequence, + event, + processingState, + mutableStateState, + bufferedEventsState, + outputReceiver); + + processBufferedEvents( + processingState, state, bufferedEventsState, outputReceiver, largeBatchEmissionTimer); + + saveStates( + processingStateState, + processingState, + mutableStateState, + state, + outputReceiver, + window.maxTimestamp()); + + checkIfProcessingIsCompleted(processingState); + } + + private boolean checkIfProcessingIsCompleted(ProcessingState processingState) { + boolean result = processingState.isProcessingCompleted(); + if (result && LOG.isTraceEnabled()) { + LOG.trace("Processing for key '" + processingState.getKey() + "' is completed."); + } + return result; + } + + /** Process buffered events. 
*/ + private void processBufferedEvents( + ProcessingState processingState, + @Nullable StateTypeT state, + OrderedListState bufferedEventsState, + MultiOutputReceiver outputReceiver, + Timer largeBatchEmissionTimer) { + if (state == null) { + // Only when the current event caused a state mutation and the state is passed to this + // method should we attempt to process buffered events + return; + } + + if (!processingState.readyToProcessBufferedEvents()) { + return; + } + + if (reachedMaxResultCountForBundle(processingState, largeBatchEmissionTimer)) { + // No point in trying to process buffered events + return; + } + + // Technically this block is not needed because these preconditions are checked + // earlier. Included to keep the linter happy. + Long earliestBufferedSequence = processingState.getEarliestBufferedSequence(); + if (earliestBufferedSequence == null) { + return; + } + Long latestBufferedSequence = processingState.getLatestBufferedSequence(); + if (latestBufferedSequence == null) { + return; + } + + processBufferedEventRange( + processingState, + state, + bufferedEventsState, + outputReceiver, + largeBatchEmissionTimer, + ContiguousSequenceRange.EMPTY); + } + + @OnTimer(LARGE_BATCH_EMISSION_TIMER) + public void onBatchEmission( + OnTimerContext context, + @StateId(BUFFERED_EVENTS) OrderedListState bufferedEventsState, + @AlwaysFetched @StateId(PROCESSING_STATE) + ValueState> processingStatusState, + @AlwaysFetched @StateId(MUTABLE_STATE) ValueState currentStateState, + @TimerId(LARGE_BATCH_EMISSION_TIMER) Timer largeBatchEmissionTimer, + MultiOutputReceiver outputReceiver) { + ProcessingState processingState = processingStatusState.read(); + if (processingState == null) { + LOG.warn("Processing state is empty. Ignore it if the pipeline is being cancelled."); + return; + } + StateTypeT state = currentStateState.read(); + if (state == null) { + LOG.warn("Mutable state is empty. 
Ignore it if the pipeline is being cancelled."); + return; + } + + LOG.debug("Starting to process batch for key '" + processingState.getKey() + "'"); + + this.numberOfResultsBeforeBundleStart = processingState.getResultCount(); + + processBufferedEvents( + processingState, state, bufferedEventsState, outputReceiver, largeBatchEmissionTimer); + + saveStates( + processingStatusState, + processingState, + currentStateState, + state, + outputReceiver, + // TODO: validate that this is correct. + context.window().maxTimestamp()); + + checkIfProcessingIsCompleted(processingState); + } + + @OnTimer(STATUS_EMISSION_TIMER) + @SuppressWarnings("unused") + public void onStatusEmission( + MultiOutputReceiver outputReceiver, + @TimerId(STATUS_EMISSION_TIMER) Timer statusEmissionTimer, + @StateId(WINDOW_CLOSED) ValueState windowClosedState, + @StateId(PROCESSING_STATE) ValueState> processingStateState) { + + processStatusTimerEvent( + outputReceiver, statusEmissionTimer, windowClosedState, processingStateState); + } + + @OnWindowExpiration + public void onWindowExpiration(@StateId(WINDOW_CLOSED) ValueState windowClosedState) { + windowClosedState.write(true); + } +} diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/UnprocessedEvent.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/UnprocessedEvent.java index 2131ef384e22..d7c599277567 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/UnprocessedEvent.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/UnprocessedEvent.java @@ -72,7 +72,8 @@ public enum Reason { duplicate, buffered, sequence_id_outside_valid_range, - exception_thrown + exception_thrown, + before_initial_sequence }; public abstract EventT getEvent(); diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/DefaultSequenceCombiner.java 
b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/DefaultSequenceCombiner.java new file mode 100644 index 000000000000..32e5cbc36e4e --- /dev/null +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/DefaultSequenceCombiner.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.ordered.combiner; + +import java.util.Iterator; +import java.util.function.BiFunction; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.extensions.ordered.ContiguousSequenceRange; +import org.apache.beam.sdk.extensions.ordered.EventExaminer; +import org.apache.beam.sdk.extensions.ordered.MutableState; +import org.apache.beam.sdk.extensions.ordered.combiner.SequenceRangeAccumulator.SequenceRangeAccumulatorCoder; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Default global sequence combiner. + * + *

    Produces the largest {@link ContiguousSequenceRange} of contiguous longs which starts from the + * initial event identified by {@link EventExaminer#isInitialEvent(long, EventT)}. + * + *

    This combiner currently doesn't use {@link EventExaminer#isLastEvent(long, EventT)}. + * + * @param type of key + * @param type of event + * @param type of state + */ +public class DefaultSequenceCombiner> + extends CombineFn< + TimestampedValue>>, + SequenceRangeAccumulator, + ContiguousSequenceRange> { + + private static final Logger LOG = LoggerFactory.getLogger(DefaultSequenceCombiner.class); + + public static final BiFunction<@NonNull Instant, @Nullable Instant, @Nullable Instant> + OLDEST_TIMESTAMP_SELECTOR = + (instant1, instant2) -> { + if (instant2 == null) { + return instant1; + } + @NonNull Instant nonNullableSecondValue = instant2; + return instant1.isAfter(nonNullableSecondValue) ? instant1 : nonNullableSecondValue; + }; + private final EventExaminer eventExaminer; + + public DefaultSequenceCombiner(EventExaminer eventExaminer) { + this.eventExaminer = eventExaminer; + } + + @Override + public SequenceRangeAccumulator createAccumulator() { + return new SequenceRangeAccumulator(); + } + + @Override + public SequenceRangeAccumulator addInput( + SequenceRangeAccumulator accum, TimestampedValue>> event) { + long sequence = event.getValue().getValue().getKey(); + + accum.add( + sequence, + event.getTimestamp(), + eventExaminer.isInitialEvent(sequence, event.getValue().getValue().getValue())); + + return accum; + } + + @Override + public SequenceRangeAccumulator mergeAccumulators( + Iterable accumulators) { + // There should be at least one accumulator. 
+ Iterator iterator = accumulators.iterator(); + SequenceRangeAccumulator result = iterator.next(); + while (iterator.hasNext()) { + result.merge(iterator.next()); + } + return result; + } + + @Override + public ContiguousSequenceRange extractOutput(SequenceRangeAccumulator accum) { + ContiguousSequenceRange result = accum.largestContinuousRange(); + if (LOG.isTraceEnabled()) { + LOG.trace("Returning completed sequence range: " + result); + } + return result; + } + + @Override + public @UnknownKeyFor @NonNull @Initialized Coder getAccumulatorCoder( + @UnknownKeyFor @NonNull @Initialized CoderRegistry registry, + @UnknownKeyFor @NonNull @Initialized + Coder>>> inputCoder) + throws @UnknownKeyFor @NonNull @Initialized CannotProvideCoderException { + return SequenceRangeAccumulatorCoder.of(); + } +} diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulator.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulator.java new file mode 100644 index 000000000000..89dc912afc90 --- /dev/null +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulator.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.ordered.combiner; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.SortedMap; +import java.util.TreeMap; +import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.extensions.ordered.ContiguousSequenceRange; +import org.apache.commons.lang3.tuple.Pair; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.joda.time.Instant; + +/** Default accumulator used to combine sequence ranges. */ +public class SequenceRangeAccumulator { + + private static Instant max(Instant a, Instant b) { + return a.isAfter(b) ? a : b; + } + + /** + * The tree contains a set of non-overlapping contiguous ranges, where the key is the lower + * inclusive start of the range, left value of the pair is the inclusive end of the range and the + * right value of the pair is the maximum timestamp in the range. + * + *

    The maximum timestamp is critical for the correctness of the ordered processing. During the + * merge process the merged range is assigned the maximum timestamp of the two ranges that created + * this new range. + */ + private final TreeMap> data = new TreeMap<>(); + + private @Nullable Long initialSequence = null; + + public void add(long sequence, Instant timestamp, boolean isInitialSequence) { + if (isInitialSequence && this.initialSequence != null && sequence != this.initialSequence) { + throw new IllegalStateException( + "There are different initial sequences detected: " + + initialSequence + + " and " + + sequence); + } + + if (sequence == Long.MAX_VALUE) { + // This is an invalid value and DoFns will not process this element. This will also allow + // to produce a ContiguousSequenceRange with the exclusive end value. + return; + } + + if (isInitialSequence) { + this.initialSequence = sequence; + clearRangesBelowInitialSequence(sequence, timestamp); + } else if (initialSequence != null && sequence <= initialSequence) { + // No need to add anything lower than the initial sequence to the accumulator. + return; + } + + long lowerBound = sequence, upperBound = sequence; + + Entry> lowerRange = data.floorEntry(sequence); + if (lowerRange != null) { + long inclusiveUpperBoundary = lowerRange.getValue().getLeft(); + if (sequence <= inclusiveUpperBoundary) { + // Duplicate. No need to adjust the timestamp. + return; + } + + if (inclusiveUpperBoundary + 1 == sequence) { + // The new element extends the lower range. Remove the range. + timestamp = max(timestamp, lowerRange.getValue().getValue()); + lowerBound = lowerRange.getKey(); + data.remove(lowerRange.getKey()); + } + } + + long nextSequenceNumber = sequence + 1; + Pair upperRange = data.get(nextSequenceNumber); + if (upperRange != null) { + // The new element will extend the upper range. Remove the range. 
+ timestamp = max(timestamp, upperRange.getRight()); + upperBound = upperRange.getLeft(); + data.remove(nextSequenceNumber); + } + + data.put(lowerBound, Pair.of(upperBound, timestamp)); + } + + private void clearRangesBelowInitialSequence(long sequence, Instant timestamp) { + // First, adjust the current range, if any + Entry> lowerRange = data.floorEntry(sequence); + if (lowerRange != null + && lowerRange.getKey() < sequence + && lowerRange.getValue().getLeft() > sequence) { + // The sequence is in the middle of the range. Adjust it. + data.remove(lowerRange.getKey()); + data.put( + sequence, + Pair.of( + lowerRange.getValue().getKey(), max(timestamp, lowerRange.getValue().getValue()))); + } + data.subMap(Long.MIN_VALUE, sequence).clear(); + } + + public ContiguousSequenceRange largestContinuousRange() { + if (initialSequence == null) { + return ContiguousSequenceRange.EMPTY; + } + + Entry> firstEntry = data.firstEntry(); + if (firstEntry == null) { + throw new IllegalStateException("First entry is null when initial sequence is set."); + } + Long start = firstEntry.getKey(); + Long end = firstEntry.getValue().getLeft(); + Instant latestTimestamp = firstEntry.getValue().getRight(); + // Upper bound is inclusive, but the ContiguousSequenceRange's end is exclusive. + // The numeric overflow is prevented by dropping the value of Long.MAX. 
+ return ContiguousSequenceRange.of(start, end + 1, latestTimestamp); + } + + public int numberOfRanges() { + return data.size(); + } + + public void merge(SequenceRangeAccumulator another) { + if (this.initialSequence != null + && another.initialSequence != null + && !this.initialSequence.equals(another.initialSequence)) { + throw new IllegalStateException( + "Two accumulators contain different initial sequences: " + + this.initialSequence + + " and " + + another.initialSequence); + } + + if (another.initialSequence != null) { + long newInitialSequence = another.initialSequence; + this.initialSequence = newInitialSequence; + Entry> firstEntry = another.data.firstEntry(); + if (firstEntry != null) { + Instant timestampOfTheInitialRange = firstEntry.getValue().getRight(); + clearRangesBelowInitialSequence(newInitialSequence, timestampOfTheInitialRange); + } + } + + another + .data + .entrySet() + .forEach( + entry -> { + long lowerBound = entry.getKey(); + long upperBound = entry.getValue().getLeft(); + if (this.initialSequence != null) { + if (upperBound < initialSequence) { + // The whole range is below the initial sequence. Ignore it. + return; + } + if (lowerBound < initialSequence) { + // This will cause pruning of the range up to the initial sequence + lowerBound = this.initialSequence; + } + } + + Entry> lowerRange = this.data.floorEntry(lowerBound); + + if (lowerRange != null) { + if (lowerRange.getValue().getLeft() < lowerBound - 1) { + // Nothing to do. There is a lower non-adjacent range. 
+ } else { + // We found an overlapping range and will replace it with a new one + upperBound = Math.max(upperBound, lowerRange.getValue().getLeft()); + lowerBound = lowerRange.getKey(); + } + } + + Entry> upperRange = this.data.floorEntry(upperBound + 1); + if (upperRange == null + || (lowerRange != null + && Objects.equals(upperRange.getKey(), lowerRange.getKey()))) { + // Nothing to do - either there is no adjacent upper range or it equals the lower + // range + } else { + upperBound = Math.max(upperBound, upperRange.getValue().getLeft()); + } + + Instant latestTimestamp = + removeAllRanges(lowerBound, upperBound, entry.getValue().getRight()); + + this.data.put(lowerBound, Pair.of(upperBound, latestTimestamp)); + }); + } + + private Instant removeAllRanges(long lowerBound, long upperBound, Instant currentTimestamp) { + Instant result = currentTimestamp; + SortedMap> rangesToRemove = data.subMap(lowerBound, upperBound); + for (Pair value : rangesToRemove.values()) { + result = result.isAfter(value.getRight()) ? 
result : value.getRight(); + } + rangesToRemove.clear(); + return result; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SequenceRangeAccumulator)) { + return false; + } + SequenceRangeAccumulator that = (SequenceRangeAccumulator) o; + return data.equals(that.data) && Objects.equals(initialSequence, that.initialSequence); + } + + @Override + public int hashCode() { + return Objects.hash(data, initialSequence); + } + + @Override + public String toString() { + return "SequenceRangeAccumulator{initialSequence=" + initialSequence + ", data=" + data + '}'; + } + + public static class SequenceRangeAccumulatorCoder extends CustomCoder { + + private static final SequenceRangeAccumulatorCoder INSTANCE = + new SequenceRangeAccumulatorCoder(); + + public static SequenceRangeAccumulatorCoder of() { + return INSTANCE; + } + + private SequenceRangeAccumulatorCoder() {} + + private final NullableCoder initialSequenceCoder = NullableCoder.of(VarLongCoder.of()); + private final VarIntCoder numberOfRangesCoder = VarIntCoder.of(); + private final VarLongCoder dataCoder = VarLongCoder.of(); + + @Override + public void encode( + SequenceRangeAccumulator value, @UnknownKeyFor @NonNull @Initialized OutputStream outStream) + throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull + @Initialized IOException { + numberOfRangesCoder.encode(value.numberOfRanges(), outStream); + initialSequenceCoder.encode(value.initialSequence, outStream); + for (Entry> entry : value.data.entrySet()) { + dataCoder.encode(entry.getKey(), outStream); + dataCoder.encode(entry.getValue().getLeft(), outStream); + dataCoder.encode(entry.getValue().getRight().getMillis(), outStream); + } + } + + @Override + public SequenceRangeAccumulator decode( + @UnknownKeyFor @NonNull @Initialized InputStream inStream) + throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull + @Initialized 
IOException { + SequenceRangeAccumulator result = new SequenceRangeAccumulator(); + int numberOfRanges = numberOfRangesCoder.decode(inStream); + result.initialSequence = initialSequenceCoder.decode(inStream); + for (int i = 0; i < numberOfRanges; i++) { + long key = dataCoder.decode(inStream); + long upperBound = dataCoder.decode(inStream); + long millis = dataCoder.decode(inStream); + result.data.put(key, Pair.of(upperBound, Instant.ofEpochMilli(millis))); + } + return result; + } + } +} diff --git a/sdks/java/io/kafka/kafka-01103/build.gradle b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/package-info.java similarity index 63% rename from sdks/java/io/kafka/kafka-01103/build.gradle rename to sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/package-info.java index 3a74bf04ef22..0d730d55fb9f 100644 --- a/sdks/java/io/kafka/kafka-01103/build.gradle +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/combiner/package-info.java @@ -4,21 +4,20 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ -project.ext { - delimited="0.11.0.3" - undelimited="01103" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" \ No newline at end of file +/** + * Default implementation of the global sequence combiner used by {@link + * org.apache.beam.sdk.extensions.ordered.OrderedEventProcessor} when processing events using global + * sequences. + */ +package org.apache.beam.sdk.extensions.ordered.combiner; diff --git a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/package-info.java b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/package-info.java index f9d7e3d67bff..4cbbca82a8cf 100644 --- a/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/package-info.java +++ b/sdks/java/extensions/ordered/src/main/java/org/apache/beam/sdk/extensions/ordered/package-info.java @@ -16,7 +16,9 @@ * limitations under the License. */ /** - * Provides a transform for ordered processing. + * Provides a transform for ordered processing. For a detailed reference implementation which uses + * this transform visit {@link https://github.com/GoogleCloudPlatform/dataflow-ordered-processing} * * @see org.apache.beam.sdk.extensions.ordered.OrderedEventProcessor */ diff --git a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorGlobalSequenceTest.java b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorGlobalSequenceTest.java new file mode 100644 index 000000000000..98bc7591f4d7 --- /dev/null +++ b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorGlobalSequenceTest.java @@ -0,0 +1,534 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.ordered; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.extensions.ordered.StringBufferOrderedProcessingHandler.StringBufferOrderedProcessingWithGlobalSequenceHandler; +import org.apache.beam.sdk.extensions.ordered.UnprocessedEvent.Reason; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; + +public class OrderedEventProcessorGlobalSequenceTest extends OrderedEventProcessorTestBase { + + public static final boolean GLOBAL_SEQUENCE = true; + + static { + Logger logger = Logger.getLogger(GlobalSequencesProcessorDoFn.class.getName()); + logger.setLevel(Level.FINEST); + } + + @org.junit.Test + public 
void testPerfectOrderingProcessing() throws CannotProvideCoderException { + Event[] events = { + Event.create(0, "id-1", "a"), + Event.create(1, "id-1", "b"), + Event.create(2, "id-1", "c"), + Event.create(3, "id-1", "d"), + Event.create(4, "id-2", "a"), + Event.create(5, "id-2", "b") + }; + + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + expectedOutput.add(KV.of("id-1", "ab")); + expectedOutput.add(KV.of("id-1", "abc")); + expectedOutput.add(KV.of("id-1", "abcd")); + expectedOutput.add(KV.of("id-2", "a")); + expectedOutput.add(KV.of("id-2", "ab")); + + testGlobalSequenceProcessing( + events, + expectedOutput, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + ContiguousSequenceRange.of(0, 6, new Instant())); + } + + @Test + public void testOutOfSequenceProcessing() throws CannotProvideCoderException { + Event[] events = { + Event.create(2, "id-1", "c"), + Event.create(1, "id-1", "b"), + Event.create(0, "id-1", "a"), + Event.create(3, "id-1", "d"), + Event.create(5, "id-2", "b"), + Event.create(6, "id-2", "c"), + Event.create(8, "id-2", "e"), + Event.create(4, "id-2", "a"), + Event.create(7, "id-2", "d") + }; + + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + expectedOutput.add(KV.of("id-1", "ab")); + expectedOutput.add(KV.of("id-1", "abc")); + expectedOutput.add(KV.of("id-1", "abcd")); + expectedOutput.add(KV.of("id-2", "a")); + expectedOutput.add(KV.of("id-2", "ab")); + expectedOutput.add(KV.of("id-2", "abc")); + expectedOutput.add(KV.of("id-2", "abcd")); + expectedOutput.add(KV.of("id-2", "abcde")); + + testGlobalSequenceProcessing( + events, + expectedOutput, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + ContiguousSequenceRange.of(0, 9, new Instant())); + } + + @Test + public void testHandlingOfDuplicateSequences() throws CannotProvideCoderException { + Event[] events = { + 
Event.create(3, "id-1", "d"), + Event.create(2, "id-1", "c"), + + // Duplicates + Event.create(3, "id-1", "d"), + Event.create(3, "id-1", "d"), + Event.create(0, "id-1", "a"), + Event.create(1, "id-1", "b"), + + // Additional duplicates + Event.create(1, "id-1", "b"), + Event.create(3, "id-1", "d"), + }; + + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + expectedOutput.add(KV.of("id-1", "ab")); + expectedOutput.add(KV.of("id-1", "abc")); + expectedOutput.add(KV.of("id-1", "abcd")); + + Collection>>> duplicates = new ArrayList<>(); + duplicates.add(KV.of("id-1", KV.of(3L, UnprocessedEvent.create("d", Reason.duplicate)))); + duplicates.add(KV.of("id-1", KV.of(3L, UnprocessedEvent.create("d", Reason.duplicate)))); + duplicates.add(KV.of("id-1", KV.of(1L, UnprocessedEvent.create("b", Reason.duplicate)))); + duplicates.add(KV.of("id-1", KV.of(3L, UnprocessedEvent.create("d", Reason.duplicate)))); + + testGlobalSequenceProcessing( + events, + expectedOutput, + duplicates, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + ContiguousSequenceRange.of(0, 4, new Instant())); + } + + @Test + public void testTreatingSequencesBelowInitialAsDuplicates() throws CannotProvideCoderException { + Event[] events = { + Event.create(3, "id-1", "d"), + Event.create(2, "id-1", "c"), + + // Earlier events + Event.create(-1, "id-1", "early-1"), + Event.create(-2, "id-1", "early-2"), + Event.create(0, "id-1", "a"), + Event.create(1, "id-1", "b") + }; + + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + expectedOutput.add(KV.of("id-1", "ab")); + expectedOutput.add(KV.of("id-1", "abc")); + expectedOutput.add(KV.of("id-1", "abcd")); + + Collection>>> duplicates = new ArrayList<>(); + duplicates.add( + KV.of( + "id-1", + KV.of(-1L, UnprocessedEvent.create("early-1", Reason.before_initial_sequence)))); + duplicates.add( + KV.of( + "id-1", + KV.of(-2L, 
UnprocessedEvent.create("early-2", Reason.before_initial_sequence)))); + + testGlobalSequenceProcessing( + events, + expectedOutput, + duplicates, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + ContiguousSequenceRange.of(0, 4, new Instant())); + } + + @Test + public void testHandlingOfCheckedExceptions() throws CannotProvideCoderException { + Event[] events = { + Event.create(0, "id-1", "a"), + Event.create(1, "id-1", "b"), + Event.create(2, "id-1", StringBuilderState.BAD_VALUE), + Event.create(3, "id-1", "c"), + }; + + // This is an interesting case - even though event #2 is not processed it doesn't affect + // the global sequence calculations. It is not considered a gap, and all the subsequent + // events will be processed. + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + expectedOutput.add(KV.of("id-1", "ab")); + expectedOutput.add(KV.of("id-1", "abc")); + + Collection>>> failedEvents = new ArrayList<>(); + failedEvents.add( + KV.of( + "id-1", + KV.of( + 2L, + UnprocessedEvent.create(StringBuilderState.BAD_VALUE, Reason.exception_thrown)))); + + testGlobalSequenceProcessing( + events, + expectedOutput, + failedEvents, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + // Sequence matcher doesn't know if the element is valid or not. 
+ // That's why the elements that are get rejected in the processor still count when + // calculating the global sequence + ContiguousSequenceRange.of(0, 4, new Instant())); + } + + @Test + public void testProcessingWithEveryOtherResultEmission() throws CannotProvideCoderException { + Event[] events = { + Event.create(2, "id-1", "c"), + Event.create(1, "id-1", "b"), + Event.create(0, "id-1", "a"), + Event.create(3, "id-1", "d"), + Event.create(4, "id-2", "a"), + Event.create(5, "id-2", "b"), + }; + + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + // Skipped KV.of("id-1", "ab"), + expectedOutput.add(KV.of("id-1", "abc")); + // Skipped KV.of("id-1", "abcd"), + expectedOutput.add(KV.of("id-2", "a")); + // Skipped KV.of("id-2", "ab") + testGlobalSequenceProcessing( + events, + expectedOutput, + EMISSION_FREQUENCY_ON_EVERY_OTHER_EVENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + ContiguousSequenceRange.of(0, 6, new Instant())); + } + + @Test + public void testLargeBufferedOutputInTimer() throws CannotProvideCoderException { + int maxResultsPerOutput = 100; + + // Array of sequences starting with 2 and the last element - 1. 
+ // Output will be buffered until the last event arrives + long[] sequences = new long[maxResultsPerOutput * 3]; + for (int i = 0; i < sequences.length - 1; i++) { + sequences[i] = i + 2L; + } + sequences[sequences.length - 1] = 1; + + List events = new ArrayList<>(sequences.length); + Collection> expectedOutput = new ArrayList<>(sequences.length); + + StringBuilder output = new StringBuilder(); + String outputPerElement = "."; + String key = "id-1"; + + for (long sequence : sequences) { + events.add(Event.create(sequence, key, outputPerElement)); + output.append(outputPerElement); + expectedOutput.add(KV.of(key, output.toString())); + } + + testGlobalSequenceProcessing( + events.toArray(new Event[events.size()]), + expectedOutput, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + 1L /* This dataset assumes 1 as the starting sequence */, + maxResultsPerOutput, + ContiguousSequenceRange.of(1, sequences.length + 1, new Instant())); + } + + @Test + public void testSequenceGapProcessingInBufferedOutput() throws CannotProvideCoderException { + int maxResultsPerOutput = 3; + + long[] sequences = new long[] {2, 3, 7, 8, 9, 10, 1, 4, 5, 6}; + + List events = new ArrayList<>(sequences.length); + List> expectedOutput = new ArrayList<>(sequences.length); + + String key = "id-1"; + + for (long sequence : sequences) { + events.add(Event.create(sequence, key, sequence + "-")); + } + + StringBuilder output = new StringBuilder(); + Arrays.stream(sequences) + .sorted() + .forEach( + sequence -> { + output.append(sequence + "-"); + expectedOutput.add(KV.of(key, output.toString())); + }); + + testGlobalSequenceProcessing( + events.toArray(new Event[events.size()]), + expectedOutput, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + 1L /* This dataset assumes 1 as the starting sequence */, + maxResultsPerOutput, + ContiguousSequenceRange.of(1, 11, new Instant())); + } + + @Test + public void testHandlingOfMaxSequenceNumber() throws CannotProvideCoderException { + Event[] events = { + Event.create(1, 
"id-1", "b"), + Event.create(0, "id-1", "a"), + Event.create(Long.MAX_VALUE, "id-1", "d"), + Event.create(2, "id-1", "c") + }; + + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + expectedOutput.add(KV.of("id-1", "ab")); + expectedOutput.add(KV.of("id-1", "abc")); + + Collection>>> unprocessedEvents = + new ArrayList<>(); + unprocessedEvents.add( + KV.of( + "id-1", + KV.of( + Long.MAX_VALUE, + UnprocessedEvent.create("d", Reason.sequence_id_outside_valid_range)))); + + testGlobalSequenceProcessing( + events, + expectedOutput, + unprocessedEvents, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + ContiguousSequenceRange.of(0, 3, Instant.now())); + } + + @Test + public void testProcessingOfTheLastInput() throws CannotProvideCoderException { + // TODO: fix the test. Need to see that the resulting status reflects the last input + Event[] events = { + Event.create(0, "id-1", "a"), + Event.create(1, "id-1", "b"), + Event.create(2, "id-1", StringEventExaminer.LAST_INPUT) + }; + + Collection> expectedOutput = new ArrayList<>(); + expectedOutput.add(KV.of("id-1", "a")); + expectedOutput.add(KV.of("id-1", "ab")); + expectedOutput.add(KV.of("id-1", "ab" + StringEventExaminer.LAST_INPUT)); + + testGlobalSequenceProcessing( + events, + expectedOutput, + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, + INITIAL_SEQUENCE_OF_0, + LARGE_MAX_RESULTS_PER_OUTPUT, + ContiguousSequenceRange.of(0, 3, new Instant())); + } + + private void testGlobalSequenceProcessing( + Event[] events, + Collection> expectedOutput, + int emissionFrequency, + long initialSequence, + int maxResultsPerOutput, + ContiguousSequenceRange expectedLastCompleteRange) + throws CannotProvideCoderException { + testGlobalSequenceProcessing( + events, + expectedOutput, + NO_EXPECTED_DLQ_EVENTS, + emissionFrequency, + initialSequence, + maxResultsPerOutput, + expectedLastCompleteRange); + } + + private void testGlobalSequenceProcessing( + 
Event[] events, + Collection> expectedOutput, + Collection>>> expectedUnprocessedEvents, + int emissionFrequency, + long initialSequence, + int maxResultsPerOutput, + ContiguousSequenceRange expectedLastCompleteRange) + throws CannotProvideCoderException { + // Test a streaming pipeline + doTest( + events, + null /* expectedStatuses */, + expectedOutput, + expectedUnprocessedEvents, + emissionFrequency, + initialSequence, + maxResultsPerOutput, + false /* produceStatusOnEveryEvent */, + STREAMING, + GLOBAL_SEQUENCE, + expectedLastCompleteRange); + + // Test a batch pipeline + if (runTestsOnDataflowRunner()) { + doTest( + events, + null /* expectedStatuses */, + expectedOutput, + expectedUnprocessedEvents, + emissionFrequency, + initialSequence, + maxResultsPerOutput, + false /* produceStatusOnEveryEvent */, + BATCH, + GLOBAL_SEQUENCE, + expectedLastCompleteRange); + } else { + System.err.println( + "Warning - batch tests didn't run. " + + "DirectRunner doesn't work correctly with this transform in batch mode." + + "Run the tests using Dataflow runner to validate."); + } + } + + @Test + public void testWindowedProcessing() throws CannotProvideCoderException { + + Instant base = new Instant(0); + TestStream values = + TestStream.create(streamingPipeline.getCoderRegistry().getCoder(Event.class)) + .advanceWatermarkTo(base) + .addElements( + // Start of first window + TimestampedValue.of( + Event.create(0, "id-1", "a"), base.plus(Duration.standardSeconds(1))), + TimestampedValue.of( + Event.create(1, "id-1", "b"), base.plus(Duration.standardSeconds(2))), + TimestampedValue.of( + Event.create(0, "id-2", "x"), base.plus(Duration.standardSeconds(1))), + TimestampedValue.of( + Event.create(1, "id-2", "y"), base.plus(Duration.standardSeconds(2))), + TimestampedValue.of( + Event.create(2, "id-2", "z"), base.plus(Duration.standardSeconds(2))), + + // Start of second window. Numbering must start with 0 again. 
+ TimestampedValue.of( + Event.create(0, "id-1", "c"), base.plus(Duration.standardSeconds(10))), + TimestampedValue.of( + Event.create(1, "id-1", "d"), base.plus(Duration.standardSeconds(11)))) + .advanceProcessingTime(Duration.standardMinutes(15)) + .advanceWatermarkToInfinity(); + + Pipeline pipeline = streamingPipeline; + + PCollection rawInput = pipeline.apply("Create Streaming Events", values); + PCollection>> input = + rawInput.apply("To KV", ParDo.of(new MapEventsToKV())); + + input = input.apply("Window input", Window.into(FixedWindows.of(Duration.standardSeconds(5)))); + + StringBufferOrderedProcessingWithGlobalSequenceHandler handler = + new StringBufferOrderedProcessingWithGlobalSequenceHandler( + EMISSION_FREQUENCY_ON_EVERY_ELEMENT, INITIAL_SEQUENCE_OF_0); + handler.setMaxOutputElementsPerBundle(LARGE_MAX_RESULTS_PER_OUTPUT); + handler.setStatusUpdateFrequency(null); + handler.setProduceStatusUpdateOnEveryEvent(false); + + OrderedEventProcessor orderedEventProcessor = + OrderedEventProcessor.create(handler); + + OrderedEventProcessorResult processingResult = + input.apply("Process Events", orderedEventProcessor); + + IntervalWindow window1 = new IntervalWindow(base, base.plus(Duration.standardSeconds(5))); + PAssert.that("Output matches in window 1", processingResult.output()) + .inWindow(window1) + .containsInAnyOrder( + KV.of("id-1", "a"), + KV.of("id-1", "ab"), + KV.of("id-2", "x"), + KV.of("id-2", "xy"), + KV.of("id-2", "xyz")); + + IntervalWindow window2 = + new IntervalWindow( + base.plus(Duration.standardSeconds(10)), base.plus(Duration.standardSeconds(15))); + PAssert.that("Output matches in window 2", processingResult.output()) + .inWindow(window2) + .containsInAnyOrder(KV.of("id-1", "c"), KV.of("id-1", "cd")); + + // TODO: can we make the status assertions work? 
+ // PAssert.that("Statuses match in window 1", processingResult.processingStatuses()) + // .inWindow(window1) + // .containsInAnyOrder( + //// KV.of("id-1", OrderedProcessingStatus.create(0L, 0, null, null, 1, 1, 0, + // false)), + // KV.of("id-1", OrderedProcessingStatus.create(1L, 0, null, null, 2, 2, 0, false)), + //// KV.of("id-2", OrderedProcessingStatus.create(0L, 0, null, null, 1, 1, 0, + // false)), + //// KV.of("id-2", OrderedProcessingStatus.create(1L, 0, null, null, 2, 2, 0, + // false)), + // KV.of("id-2", OrderedProcessingStatus.create(2L, 0, null, null, 3, 3, 0, false)) + // ); + + // PAssert.that("Statuses match in window 2", processingResult.processingStatuses()) + // .inWindow(window2) + // .containsInAnyOrder( + // KV.of("id-1", OrderedProcessingStatus.create(0L, 0, null, null, 1, 1, 0, false)), + // KV.of("id-1", OrderedProcessingStatus.create(1L, 0, null, null, 2, 2, 0, false))); + + PAssert.that("Unprocessed events match", processingResult.unprocessedEvents()) + .containsInAnyOrder(NO_EXPECTED_DLQ_EVENTS); + + pipeline.run(); + } +} diff --git a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorTest.java b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorPerKeySequenceTest.java similarity index 71% rename from sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorTest.java rename to sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorPerKeySequenceTest.java index 6a24021ad667..6909a3bb992c 100644 --- a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorTest.java +++ b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorPerKeySequenceTest.java @@ -20,82 +20,24 @@ import java.util.ArrayList; import java.util.Arrays; import 
java.util.Collection; -import java.util.Collections; import java.util.List; -import java.util.Set; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.extensions.ordered.UnprocessedEvent.Reason; import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.SerializableMatcher; -import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.Reshuffle; -import org.apache.beam.sdk.transforms.windowing.AfterWatermark; import org.apache.beam.sdk.transforms.windowing.FixedWindows; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; -import org.apache.beam.sdk.transforms.windowing.Repeatedly; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TimestampedValue; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; -import org.hamcrest.BaseMatcher; -import org.hamcrest.Description; import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.Rule; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -/** - * Ordered Processing tests use the same testing scenario. Events are sent in or out of sequence. - * Each event is a string for a particular key. The output is a concatenation of all strings. 
- */ -@RunWith(JUnit4.class) -public class OrderedEventProcessorTest { - - public static final boolean LAST_EVENT_RECEIVED = true; - public static final int EMISSION_FREQUENCY_ON_EVERY_ELEMENT = 1; - public static final int INITIAL_SEQUENCE_OF_0 = 0; - public static final boolean DONT_PRODUCE_STATUS_ON_EVERY_EVENT = false; - public static final int LARGE_MAX_RESULTS_PER_OUTPUT = 1000; - public static final int EMISSION_FREQUENCY_ON_EVERY_OTHER_EVENT = 2; - public static final boolean PRODUCE_STATUS_ON_EVERY_EVENT = true; - public static final boolean STREAMING = true; - public static final boolean BATCH = false; - public static final Set>>> NO_EXPECTED_DLQ_EVENTS = - Collections.emptySet(); - @Rule public final transient TestPipeline streamingPipeline = TestPipeline.create(); - @Rule public final transient TestPipeline batchPipeline = TestPipeline.create(); - - static class MapEventsToKV extends DoFn>> { - - @ProcessElement - public void convert( - @Element Event event, OutputReceiver>> outputReceiver) { - outputReceiver.output(KV.of(event.getKey(), KV.of(event.getSequence(), event.getValue()))); - } - } - - static class MapStringBufferStateToString - extends DoFn, KV> { - - @ProcessElement - public void map( - @Element KV element, - OutputReceiver> outputReceiver) { - outputReceiver.output(KV.of(element.getKey(), element.getValue().toString())); - } - } +public class OrderedEventProcessorPerKeySequenceTest extends OrderedEventProcessorTestBase { @Test public void testPerfectOrderingProcessing() throws CannotProvideCoderException { @@ -142,7 +84,7 @@ public void testPerfectOrderingProcessing() throws CannotProvideCoderException { expectedOutput.add(KV.of("id-2", "a")); expectedOutput.add(KV.of("id-2", "ab")); - testProcessing( + testPerKeySequenceProcessing( events, expectedStatuses, expectedOutput, @@ -203,7 +145,7 @@ public void testOutOfSequenceProcessing() throws CannotProvideCoderException { expectedOutput.add(KV.of("id-2", "abcd")); 
expectedOutput.add(KV.of("id-2", "abcde")); - testProcessing( + testPerKeySequenceProcessing( events, expectedStatuses, expectedOutput, @@ -235,7 +177,7 @@ public void testUnfinishedProcessing() throws CannotProvideCoderException { expectedOutput.add(KV.of("id-2", "a")); expectedOutput.add(KV.of("id-2", "ab")); - testProcessing(events, expectedStatuses, expectedOutput, 1, 0, 1000, false); + testPerKeySequenceProcessing(events, expectedStatuses, expectedOutput, 1, 0, 1000, false); } @Test @@ -275,7 +217,7 @@ public void testHandlingOfDuplicateSequences() throws CannotProvideCoderExceptio duplicates.add(KV.of("id-1", KV.of(1L, UnprocessedEvent.create("b", Reason.duplicate)))); duplicates.add(KV.of("id-1", KV.of(3L, UnprocessedEvent.create("d", Reason.duplicate)))); - testProcessing( + testPerKeySequenceProcessing( events, expectedStatuses, expectedOutput, @@ -311,7 +253,7 @@ public void testHandlingOfCheckedExceptions() throws CannotProvideCoderException 2L, UnprocessedEvent.create(StringBuilderState.BAD_VALUE, Reason.exception_thrown)))); - testProcessing( + testPerKeySequenceProcessing( events, expectedStatuses, expectedOutput, @@ -346,7 +288,7 @@ public void testProcessingWithEveryOtherResultEmission() throws CannotProvideCod // Skipped KV.of("id-1", "abcd"), expectedOutput.add(KV.of("id-2", "a")); // Skipped KV.of("id-2", "ab") - testProcessing( + testPerKeySequenceProcessing( events, expectedStatuses, expectedOutput, @@ -428,7 +370,7 @@ public void testLargeBufferedOutputInTimer() throws CannotProvideCoderException 0, false))); - testProcessing( + testPerKeySequenceProcessing( events.toArray(new Event[events.size()]), expectedStatuses, expectedOutput, @@ -523,7 +465,7 @@ public void testSequenceGapProcessingInBufferedOutput() throws CannotProvideCode OrderedProcessingStatus.create( 10L, 0, null, null, numberOfReceivedEvents, 10L, 0, false))); - testProcessing( + testPerKeySequenceProcessing( events.toArray(new Event[events.size()]), expectedStatuses, 
expectedOutput, @@ -558,7 +500,7 @@ public void testHandlingOfMaxSequenceNumber() throws CannotProvideCoderException Long.MAX_VALUE, UnprocessedEvent.create("c", Reason.sequence_id_outside_valid_range)))); - testProcessing( + testPerKeySequenceProcessing( events, expectedStatuses, expectedOutput, @@ -589,7 +531,7 @@ public void testProcessingOfTheLastInput() throws CannotProvideCoderException { expectedOutput.add(KV.of("id-1", "ab")); expectedOutput.add(KV.of("id-1", "ab" + StringEventExaminer.LAST_INPUT)); - testProcessing( + testPerKeySequenceProcessing( events, expectedStatuses, expectedOutput, @@ -599,6 +541,65 @@ public void testProcessingOfTheLastInput() throws CannotProvideCoderException { DONT_PRODUCE_STATUS_ON_EVERY_EVENT); } + protected void testPerKeySequenceProcessing( + Event[] events, + Collection> expectedStatuses, + Collection> expectedOutput, + int emissionFrequency, + long initialSequence, + int maxResultsPerOutput, + boolean produceStatusOnEveryEvent) + throws CannotProvideCoderException { + testPerKeySequenceProcessing( + events, + expectedStatuses, + expectedOutput, + NO_EXPECTED_DLQ_EVENTS, + emissionFrequency, + initialSequence, + maxResultsPerOutput, + produceStatusOnEveryEvent); + } + + protected void testPerKeySequenceProcessing( + Event[] events, + Collection> expectedStatuses, + Collection> expectedOutput, + Collection>>> expectedUnprocessedEvents, + int emissionFrequency, + long initialSequence, + int maxResultsPerOutput, + boolean produceStatusOnEveryEvent) + throws CannotProvideCoderException { + // Test a streaming pipeline + doTest( + events, + expectedStatuses, + expectedOutput, + expectedUnprocessedEvents, + emissionFrequency, + initialSequence, + maxResultsPerOutput, + produceStatusOnEveryEvent, + STREAMING, + false, + ContiguousSequenceRange.EMPTY); + + // Test a batch pipeline + doTest( + events, + expectedStatuses, + expectedOutput, + expectedUnprocessedEvents, + emissionFrequency, + initialSequence, + maxResultsPerOutput, + 
produceStatusOnEveryEvent, + BATCH, + false, + ContiguousSequenceRange.EMPTY); + } + @Test public void testWindowedProcessing() throws CannotProvideCoderException { @@ -684,223 +685,4 @@ public void testWindowedProcessing() throws CannotProvideCoderException { pipeline.run(); } - - private void testProcessing( - Event[] events, - Collection> expectedStatuses, - Collection> expectedOutput, - int emissionFrequency, - long initialSequence, - int maxResultsPerOutput, - boolean produceStatusOnEveryEvent) - throws CannotProvideCoderException { - testProcessing( - events, - expectedStatuses, - expectedOutput, - NO_EXPECTED_DLQ_EVENTS, - emissionFrequency, - initialSequence, - maxResultsPerOutput, - produceStatusOnEveryEvent); - } - - private void testProcessing( - Event[] events, - Collection> expectedStatuses, - Collection> expectedOutput, - Collection>>> expectedUnprocessedEvents, - int emissionFrequency, - long initialSequence, - int maxResultsPerOutput, - boolean produceStatusOnEveryEvent) - throws CannotProvideCoderException { - doTest( - events, - expectedStatuses, - expectedOutput, - expectedUnprocessedEvents, - emissionFrequency, - initialSequence, - maxResultsPerOutput, - produceStatusOnEveryEvent, - STREAMING); - doTest( - events, - expectedStatuses, - expectedOutput, - expectedUnprocessedEvents, - emissionFrequency, - initialSequence, - maxResultsPerOutput, - produceStatusOnEveryEvent, - BATCH); - } - - /** - * The majority of the tests use this method. Testing is done in the global window. 
- * - * @param events - * @param expectedStatuses - * @param expectedOutput - * @param expectedUnprocessedEvents - * @param emissionFrequency - * @param initialSequence - * @param maxResultsPerOutput - * @param produceStatusOnEveryEvent - * @param streaming - * @throws @UnknownKeyFor @NonNull @Initialized CannotProvideCoderException - */ - private void doTest( - Event[] events, - Collection> expectedStatuses, - Collection> expectedOutput, - Collection>>> expectedUnprocessedEvents, - int emissionFrequency, - long initialSequence, - int maxResultsPerOutput, - boolean produceStatusOnEveryEvent, - boolean streaming) - throws @UnknownKeyFor @NonNull @Initialized CannotProvideCoderException { - - Pipeline pipeline = streaming ? streamingPipeline : batchPipeline; - - PCollection rawInput = - streaming - ? createStreamingPCollection(pipeline, events) - : createBatchPCollection(pipeline, events); - PCollection>> input = - rawInput.apply("To KV", ParDo.of(new MapEventsToKV())); - - StringBufferOrderedProcessingHandler handler = - new StringBufferOrderedProcessingHandler(emissionFrequency, initialSequence); - handler.setMaxOutputElementsPerBundle(maxResultsPerOutput); - if (produceStatusOnEveryEvent) { - handler.setProduceStatusUpdateOnEveryEvent(true); - // This disables status updates emitted on timers. - handler.setStatusUpdateFrequency(null); - } else { - handler.setStatusUpdateFrequency( - streaming ? Duration.standardMinutes(5) : Duration.standardSeconds(1)); - } - OrderedEventProcessor orderedEventProcessor = - OrderedEventProcessor.create(handler); - - OrderedEventProcessorResult processingResult = - input.apply("Process Events", orderedEventProcessor); - - PAssert.that("Output matches", processingResult.output()).containsInAnyOrder(expectedOutput); - - if (streaming) { - // Only in streaming the events will arrive in a pre-determined order and the statuses - // will be deterministic. 
In batch pipelines events can be processed in any order, - // so we skip status verification and rely on the output and unprocessed event matches. - PAssert.that("Statuses match", processingResult.processingStatuses()) - .containsInAnyOrder(expectedStatuses); - } - - // This is a temporary workaround until PAssert changes. - boolean unprocessedEventsHaveExceptionStackTrace = false; - for (KV>> event : expectedUnprocessedEvents) { - if (event.getValue().getValue().getReason() == Reason.exception_thrown) { - unprocessedEventsHaveExceptionStackTrace = true; - break; - } - } - - if (unprocessedEventsHaveExceptionStackTrace) { - PAssert.thatSingleton( - "Unprocessed event count", - processingResult - .unprocessedEvents() - .apply( - "Window", - Window.>>>into( - new GlobalWindows()) - .triggering(Repeatedly.forever(AfterWatermark.pastEndOfWindow())) - .discardingFiredPanes()) - .apply("Count", Count.globally())) - .isEqualTo((long) expectedUnprocessedEvents.size()); - } else { - PAssert.that("Unprocessed events match", processingResult.unprocessedEvents()) - .containsInAnyOrder(expectedUnprocessedEvents); - } - pipeline.run(); - } - - private @UnknownKeyFor @NonNull @Initialized PCollection createBatchPCollection( - Pipeline pipeline, Event[] events) { - return pipeline - .apply("Create Batch Events", Create.of(Arrays.asList(events))) - .apply("Reshuffle", Reshuffle.viaRandomKey()); - } - - private @UnknownKeyFor @NonNull @Initialized PCollection createStreamingPCollection( - Pipeline pipeline, Event[] events) - throws @UnknownKeyFor @NonNull @Initialized CannotProvideCoderException { - Instant now = Instant.now().minus(Duration.standardMinutes(20)); - TestStream.Builder messageFlow = - TestStream.create(pipeline.getCoderRegistry().getCoder(Event.class)) - .advanceWatermarkTo(now); - - int delayInMilliseconds = 0; - for (Event e : events) { - messageFlow = - messageFlow - .advanceWatermarkTo(now.plus(Duration.millis(++delayInMilliseconds))) - .addElements(e); - } - - // 
Needed to force the processing time based timers. - messageFlow = messageFlow.advanceProcessingTime(Duration.standardMinutes(15)); - return pipeline.apply("Create Streaming Events", messageFlow.advanceWatermarkToInfinity()); - } - - /** - * Unprocessed event's explanation contains stacktraces which makes tests very brittle because it - * requires hardcoding the line numbers in the code. We use this matcher to only compare on the - * first line of the explanation. - */ - static class UnprocessedEventMatcher - extends BaseMatcher>>> - implements SerializableMatcher>>> { - - private KV>> element; - - public UnprocessedEventMatcher(KV>> element) { - this.element = element; - } - - @Override - public boolean matches(Object actual) { - KV>> toMatch = - (KV>>) actual; - - UnprocessedEvent originalEvent = element.getValue().getValue(); - UnprocessedEvent eventToMatch = toMatch.getValue().getValue(); - - return element.getKey().equals(toMatch.getKey()) - && element.getValue().getKey().equals(toMatch.getValue().getKey()) - && originalEvent.getEvent().equals(eventToMatch.getEvent()) - && originalEvent.getReason() == eventToMatch.getReason() - && normalizeExplanation(originalEvent.getExplanation()) - .equals(normalizeExplanation(eventToMatch.getExplanation())); - } - - @Override - public void describeTo(Description description) { - description.appendText("Just some text..."); - } - - static String normalizeExplanation(String value) { - if (value == null) { - return ""; - } - String firstLine = value.split("\n", 1)[0]; - if (firstLine.contains("Exception")) { - return firstLine; - } - return value; - } - } } diff --git a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorTestBase.java b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorTestBase.java new file mode 100644 index 000000000000..fd651b919df1 --- /dev/null +++ 
b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/OrderedEventProcessorTestBase.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.ordered; + +import static org.hamcrest.MatcherAssert.assertThat; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.TestDataflowPipelineOptions; +import org.apache.beam.runners.dataflow.TestDataflowRunner; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.ordered.StringBufferOrderedProcessingHandler.StringBufferOrderedProcessingWithGlobalSequenceHandler; +import org.apache.beam.sdk.extensions.ordered.UnprocessedEvent.Reason; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.SerializableMatcher; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.Count; +import 
org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.windowing.AfterWatermark; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; +import org.apache.beam.sdk.values.PCollectionView; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; + +/** + * Ordered Processing tests use the same testing scenario. Events are sent in or out of sequence. + * Each event is a string for a particular key. The output is a concatenation of all strings. 
+ */ +public class OrderedEventProcessorTestBase { + + public static final boolean LAST_EVENT_RECEIVED = true; + public static final int EMISSION_FREQUENCY_ON_EVERY_ELEMENT = 1; + public static final int INITIAL_SEQUENCE_OF_0 = 0; + public static final boolean DONT_PRODUCE_STATUS_ON_EVERY_EVENT = false; + public static final int LARGE_MAX_RESULTS_PER_OUTPUT = 1000; + public static final int EMISSION_FREQUENCY_ON_EVERY_OTHER_EVENT = 2; + public static final boolean PRODUCE_STATUS_ON_EVERY_EVENT = true; + public static final boolean STREAMING = true; + public static final boolean BATCH = false; + public static final Set>>> NO_EXPECTED_DLQ_EVENTS = + Collections.emptySet(); + @Rule public final transient TestPipeline streamingPipeline = TestPipeline.create(); + @Rule public final transient TestPipeline batchPipeline = TestPipeline.create(); + + protected boolean runTestsOnDataflowRunner() { + return Boolean.getBoolean("run-tests-on-dataflow"); + } + + protected String getSystemProperty(String name) { + String property = System.getProperty(name); + if (property == null) { + throw new IllegalStateException("Unable to find system property '" + name + "'"); + } + return property; + } + + static class MapEventsToKV extends DoFn>> { + + @ProcessElement + public void convert( + @Element Event event, OutputReceiver>> outputReceiver) { + outputReceiver.output(KV.of(event.getKey(), KV.of(event.getSequence(), event.getValue()))); + } + } + + static class MapStringBufferStateToString + extends DoFn, KV> { + + @ProcessElement + public void map( + @Element KV element, + OutputReceiver> outputReceiver) { + outputReceiver.output(KV.of(element.getKey(), element.getValue().toString())); + } + } + + /** + * The majority of the tests use this method. Testing is done in the global window. 
+ * + * @throws @UnknownKeyFor @NonNull @Initialized CannotProvideCoderException + */ + protected void doTest( + Event[] events, + @Nullable Collection> expectedStatuses, + Collection> expectedOutput, + Collection>>> expectedUnprocessedEvents, + int emissionFrequency, + long initialSequence, + int maxResultsPerOutput, + boolean produceStatusOnEveryEvent, + boolean streaming, + boolean isGlobalSequence, + @Nullable ContiguousSequenceRange expectedLastCompletedSequence) + throws @UnknownKeyFor @NonNull @Initialized CannotProvideCoderException { + + Pipeline pipeline = streaming ? streamingPipeline : batchPipeline; + if (runTestsOnDataflowRunner()) { + pipeline.getOptions().setRunner(TestDataflowRunner.class); + TestDataflowPipelineOptions options = + pipeline.getOptions().as(TestDataflowPipelineOptions.class); + options.setExperiments(Arrays.asList("disable_runner_v2")); + options.setTempRoot("gs://" + getSystemProperty("temp_dataflow_bucket")); + } + PCollection rawInput = + streaming + ? createStreamingPCollection(pipeline, events) + : createBatchPCollection(pipeline, events); + PCollection>> input = + rawInput.apply("To KV", ParDo.of(new MapEventsToKV())); + + OrderedProcessingHandler handler = + isGlobalSequence + ? new StringBufferOrderedProcessingWithGlobalSequenceHandler( + emissionFrequency, initialSequence) + : new StringBufferOrderedProcessingHandler(emissionFrequency, initialSequence); + handler.setMaxOutputElementsPerBundle(maxResultsPerOutput); + if (produceStatusOnEveryEvent) { + handler.setProduceStatusUpdateOnEveryEvent(true); + // This disables status updates emitted on timers. + handler.setStatusUpdateFrequency(null); + } else { + handler.setStatusUpdateFrequency( + streaming ? 
Duration.standardMinutes(5) : Duration.standardSeconds(1)); + } + + OrderedEventProcessor orderedEventProcessor = + OrderedEventProcessor.create(handler); + + OrderedEventProcessorResult processingResult = + input.apply("Process Events", orderedEventProcessor); + + PAssert.that("Output matches", processingResult.output()).containsInAnyOrder(expectedOutput); + + if (streaming && expectedStatuses != null) { + // Only in a streaming pipeline the events will arrive in a pre-determined order and the + // statuses + // will be deterministic. In batch pipelines events can be processed in any order, + // so we skip status verification and rely on the output and unprocessed event matches. + PAssert.that("Statuses match", processingResult.processingStatuses()) + .containsInAnyOrder(expectedStatuses); + } + + // This is a temporary workaround until PAssert changes. + boolean unprocessedEventsHaveExceptionStackTrace = false; + for (KV>> event : expectedUnprocessedEvents) { + if (event.getValue().getValue().getReason() == Reason.exception_thrown) { + unprocessedEventsHaveExceptionStackTrace = true; + break; + } + } + + if (unprocessedEventsHaveExceptionStackTrace) { + PAssert.thatSingleton( + "Unprocessed event count", + processingResult + .unprocessedEvents() + .apply( + "Window", + Window.>>>into( + new GlobalWindows()) + .triggering(Repeatedly.forever(AfterWatermark.pastEndOfWindow())) + .discardingFiredPanes()) + .apply("Count", Count.globally())) + .isEqualTo((long) expectedUnprocessedEvents.size()); + } else { + PAssert.that("Unprocessed events match", processingResult.unprocessedEvents()) + .containsInAnyOrder(expectedUnprocessedEvents); + } + + if (expectedLastCompletedSequence != null && processingResult.latestContiguousRange() != null) { + PCollection globalSequences = + rawInput.apply( + "Publish Global Sequences", + new GlobalSequenceRangePublisher( + processingResult.latestContiguousRange(), + handler.getKeyCoder(pipeline, input.getCoder()), + 
handler.getEventCoder(pipeline, input.getCoder()))); + PAssert.that("CompletedSequenceRange verification", globalSequences) + .satisfies(new LastExpectedGlobalSequenceRangeMatcher(expectedLastCompletedSequence)); + } + pipeline.run(); + } + + static class LastExpectedGlobalSequenceRangeMatcher + implements SerializableFunction, Void> { + + private final long expectedStart; + private final long expectedEnd; + + LastExpectedGlobalSequenceRangeMatcher(ContiguousSequenceRange expected) { + this.expectedStart = expected.getStart(); + this.expectedEnd = expected.getEnd(); + } + + @Override + public Void apply(Iterable input) { + StringBuilder listOfRanges = new StringBuilder("["); + Iterator iterator = input.iterator(); + ContiguousSequenceRange lastRange = null; + while (iterator.hasNext()) { + lastRange = iterator.next(); + + if (listOfRanges.length() > 1) { + listOfRanges.append(", "); + } + listOfRanges.append(lastRange); + } + listOfRanges.append(']'); + boolean foundExpectedRange = + lastRange != null + && lastRange.getStart() == expectedStart + && lastRange.getEnd() == expectedEnd; + + assertThat( + "Expected range not found: [" + + expectedStart + + '-' + + expectedEnd + + "], received ranges: " + + listOfRanges, + foundExpectedRange); + return null; + } + } + + private @UnknownKeyFor @NonNull @Initialized PCollection createBatchPCollection( + Pipeline pipeline, Event[] events) { + return pipeline + .apply("Create Batch Events", Create.of(Arrays.asList(events))) + .apply("Reshuffle", Reshuffle.viaRandomKey()); + } + + private @UnknownKeyFor @NonNull @Initialized PCollection createStreamingPCollection( + Pipeline pipeline, Event[] events) + throws @UnknownKeyFor @NonNull @Initialized CannotProvideCoderException { + Instant now = Instant.now().minus(Duration.standardMinutes(20)); + TestStream.Builder messageFlow = + TestStream.create(pipeline.getCoderRegistry().getCoder(Event.class)) + .advanceWatermarkTo(now); + + int delayInMilliseconds = 0; + for (Event e : 
events) { + messageFlow = + messageFlow + .advanceWatermarkTo(now.plus(Duration.millis(++delayInMilliseconds))) + .addElements(e); + } + + // Needed to force the processing time based timers. + messageFlow = messageFlow.advanceProcessingTime(Duration.standardMinutes(15)); + return pipeline.apply("Create Streaming Events", messageFlow.advanceWatermarkToInfinity()); + } + + /** + * Unprocessed event's explanation contains stacktraces which makes tests very brittle because it + * requires hardcoding the line numbers in the code. We use this matcher to only compare on the + * first line of the explanation. + */ + static class UnprocessedEventMatcher + extends BaseMatcher>>> + implements SerializableMatcher>>> { + + private KV>> element; + + public UnprocessedEventMatcher(KV>> element) { + this.element = element; + } + + @Override + public boolean matches(Object actual) { + KV>> toMatch = + (KV>>) actual; + + UnprocessedEvent originalEvent = element.getValue().getValue(); + UnprocessedEvent eventToMatch = toMatch.getValue().getValue(); + + return element.getKey().equals(toMatch.getKey()) + && element.getValue().getKey().equals(toMatch.getValue().getKey()) + && originalEvent.getEvent().equals(eventToMatch.getEvent()) + && originalEvent.getReason() == eventToMatch.getReason() + && normalizeExplanation(originalEvent.getExplanation()) + .equals(normalizeExplanation(eventToMatch.getExplanation())); + } + + @Override + public void describeTo(Description description) { + description.appendText("Just some text..."); + } + + static String normalizeExplanation(String value) { + if (value == null) { + return ""; + } + String firstLine = value.split("\n", 1)[0]; + if (firstLine.contains("Exception")) { + return firstLine; + } + return value; + } + } + + static class GlobalSequenceRangePublisher + extends PTransform, PCollection> { + + private final PCollectionView lastCompletedSequenceRangeView; + private final Coder keyCoder; + private final Coder eventCoder; + + public 
GlobalSequenceRangePublisher( + PCollectionView latestCompletedSequenceRange, + Coder keyCoder, + Coder eventCoder) { + this.lastCompletedSequenceRangeView = latestCompletedSequenceRange; + this.keyCoder = keyCoder; + this.eventCoder = eventCoder; + } + + @Override + public PCollection expand(PCollection input) { + PCollection>> events = + input + // In production pipelines the global sequence will typically be obtained + // by using GenerateSequence. But GenerateSequence doesn't work well with TestStream, + // That's why we use the input events here. + // .apply("Create Ticker", + // GenerateSequence.from(0).to(2).withRate(1, + // Duration.standardSeconds(5))) + .apply("To KV", ParDo.of(new MapEventsToKV())); + if (input.isBounded() == IsBounded.BOUNDED) { + return events.apply( + "Emit SideInput", + ParDo.of(new SideInputEmitter()) + .withSideInput("lastCompletedSequence", lastCompletedSequenceRangeView)); + } else { + PCollection>> tickers = + events.apply( + "Create Tickers", + new PerKeyTickerGenerator<>(keyCoder, eventCoder, Duration.standardSeconds(1))); + return tickers.apply( + "Emit SideInput", + ParDo.of(new SideInputEmitter()) + .withSideInput("lastCompletedSequence", lastCompletedSequenceRangeView)); + } + } + + static class SideInputEmitter + extends DoFn>, ContiguousSequenceRange> { + + @ProcessElement + public void produceCompletedRange( + @SideInput("lastCompletedSequence") ContiguousSequenceRange sideInput, + OutputReceiver outputReceiver) { + outputReceiver.output(sideInput); + } + } + } +} diff --git a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/StringBufferOrderedProcessingHandler.java b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/StringBufferOrderedProcessingHandler.java index 72f3a3cf21b6..1da46c3262e4 100644 --- a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/StringBufferOrderedProcessingHandler.java +++ 
b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/StringBufferOrderedProcessingHandler.java @@ -27,6 +27,24 @@ public class StringBufferOrderedProcessingHandler extends OrderedProcessingHandler { + public static class StringBufferOrderedProcessingWithGlobalSequenceHandler + extends OrderedProcessingGlobalSequenceHandler { + + private final EventExaminer eventExaminer; + + public StringBufferOrderedProcessingWithGlobalSequenceHandler( + int emissionFrequency, long initialSequence) { + super(String.class, String.class, StringBuilderState.class, String.class); + this.eventExaminer = new StringEventExaminer(initialSequence, emissionFrequency); + } + + @Override + @NonNull + public EventExaminer getEventExaminer() { + return eventExaminer; + } + } + private final EventExaminer eventExaminer; public StringBufferOrderedProcessingHandler(int emissionFrequency, long initialSequence) { diff --git a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulatorCoderTest.java b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulatorCoderTest.java new file mode 100644 index 000000000000..0e5b0b7c819a --- /dev/null +++ b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulatorCoderTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.ordered.combiner; + +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import org.apache.beam.sdk.extensions.ordered.combiner.SequenceRangeAccumulator.SequenceRangeAccumulatorCoder; +import org.joda.time.Instant; +import org.junit.Test; + +public class SequenceRangeAccumulatorCoderTest { + + private SequenceRangeAccumulatorCoder coder = SequenceRangeAccumulatorCoder.of(); + + @Test + public void testEncodingEmptyAccumulator() throws IOException { + SequenceRangeAccumulator empty = new SequenceRangeAccumulator(); + + doTestEncodingAndDecoding(empty); + } + + @Test + public void testEncodingAccumulatorWithoutInitialSequence() throws IOException { + SequenceRangeAccumulator accumulator = new SequenceRangeAccumulator(); + accumulator.add(1, Instant.now(), false); + accumulator.add(2, Instant.now(), false); + accumulator.add(3, Instant.now(), false); + accumulator.add(5, Instant.now(), false); + accumulator.add(6, Instant.now(), false); + + doTestEncodingAndDecoding(accumulator); + } + + @Test + public void testEncodingAccumulatorWithInitialSequence() throws IOException { + SequenceRangeAccumulator accumulator = new SequenceRangeAccumulator(); + accumulator.add(1, Instant.now(), true); + accumulator.add(2, Instant.now(), false); + accumulator.add(3, Instant.now(), false); + accumulator.add(5, Instant.now(), false); + accumulator.add(6, Instant.now(), false); + + doTestEncodingAndDecoding(accumulator); + } + + 
private void doTestEncodingAndDecoding(SequenceRangeAccumulator value) throws IOException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + coder.encode(value, output); + + SequenceRangeAccumulator decoded = coder.decode(new ByteArrayInputStream(output.toByteArray())); + assertEquals("Accumulator", value, decoded); + } +} diff --git a/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulatorTest.java b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulatorTest.java new file mode 100644 index 000000000000..4082ce6de758 --- /dev/null +++ b/sdks/java/extensions/ordered/src/test/java/org/apache/beam/sdk/extensions/ordered/combiner/SequenceRangeAccumulatorTest.java @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.extensions.ordered.combiner; + +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.beam.sdk.extensions.ordered.ContiguousSequenceRange; +import org.joda.time.Instant; +import org.junit.Assert; +import org.junit.Test; + +public class SequenceRangeAccumulatorTest { + + // Atomic just in case tests are run in parallel + private static final AtomicLong currentTicker = new AtomicLong(); + + static Instant nextTimestamp() { + return Instant.ofEpochMilli(currentTicker.getAndIncrement()); + } + + static Instant eventTimestamp(Event[] events, long eventSequence) { + for (Event e : events) { + if (e.sequence == eventSequence) { + return e.timestamp; + } + } + throw new IllegalStateException("Unable to find event with sequence " + eventSequence); + } + + static class Event { + + long sequence; + Instant timestamp; + boolean initialEvent; + + Event(long sequence, Instant ts) { + this(sequence, ts, false); + } + + Event(long sequence, Instant ts, boolean initialEvent) { + this.sequence = sequence; + this.timestamp = ts; + this.initialEvent = initialEvent; + } + } + + @Test + public void testSimpleAccumulation() { + Event[] events = + new Event[] { + new Event(1, nextTimestamp(), true), + new Event(2, nextTimestamp()), + new Event(3, nextTimestamp()) + }; + + doTestAccumulation(events, ContiguousSequenceRange.of(1, 4, eventTimestamp(events, 3)), 1); + } + + @Test + public void testReverseArrivalHandling() { + Event[] events = + new Event[] { + new Event(3, nextTimestamp()), + new Event(2, nextTimestamp()), + new Event(1, nextTimestamp(), true) + }; + + Instant timestampOfEventNumber1 = eventTimestamp(events, 1); + doTestAccumulation(events, ContiguousSequenceRange.of(1, 4, timestampOfEventNumber1), 1); + } + + @Test + public void testPartialRangeAccumulation() { + Event[] events = + new Event[] { + new Event(1, nextTimestamp(), true), + new Event(2, nextTimestamp()), + new Event(3, nextTimestamp()), + new 
Event(5, nextTimestamp()), + new Event(7, nextTimestamp()), + }; + + doTestAccumulation(events, ContiguousSequenceRange.of(1, 4, eventTimestamp(events, 3)), 3); + } + + @Test + public void testMergingRangeAccumulation() { + Event[] events = + new Event[] { + new Event(1, nextTimestamp(), true), + new Event(2, nextTimestamp()), + new Event(3, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(7, nextTimestamp()), + new Event(6, nextTimestamp()), + }; + + doTestAccumulation(events, ContiguousSequenceRange.of(1, 4, eventTimestamp(events, 3)), 2); + } + + @Test + public void testNoStartEvent() { + Event[] events = + new Event[] { + new Event(2, nextTimestamp()), + new Event(3, nextTimestamp()), + new Event(1, nextTimestamp()), + new Event(5, nextTimestamp()), + }; + + doTestAccumulation(events, ContiguousSequenceRange.EMPTY, 2); + } + + @Test + public void testNoEventsAccumulation() { + Event[] events = new Event[] {}; + + doTestAccumulation(events, ContiguousSequenceRange.EMPTY, 0); + } + + @Test + public void testRemovingRangesBelowInitialSequenceDuringAccumulation() { + Event[] events = + new Event[] { + // First range + new Event(2, nextTimestamp()), + new Event(3, nextTimestamp()), + new Event(1, nextTimestamp()), + + // Second range + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()), + + // This event should prune everything below + new Event(7, nextTimestamp(), true), + }; + + doTestAccumulation(events, ContiguousSequenceRange.of(7, 8, eventTimestamp(events, 7)), 1); + } + + @Test + public void testRemovingElementsBelowInitialSequenceDuringAccumulation() { + + Event[] events = + new Event[] { + // First range + new Event(2, nextTimestamp()), + new Event(3, nextTimestamp()), + new Event(1, nextTimestamp()), + + // Second range + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()), + new Event(7, nextTimestamp()), + new Event(8, nextTimestamp()), + + // This event should reduce the range. 
+ new Event(7, nextTimestamp(), true), + }; + + Instant timestampOfTheLastEvent = events[events.length - 1].timestamp; + doTestAccumulation(events, ContiguousSequenceRange.of(7, 9, timestampOfTheLastEvent), 1); + } + + private static void doTestAccumulation( + Event[] events, ContiguousSequenceRange expectedResult, int expectedNumberOfRanges) { + SequenceRangeAccumulator accumulator = new SequenceRangeAccumulator(); + Arrays.stream(events).forEach(e -> accumulator.add(e.sequence, e.timestamp, e.initialEvent)); + + Assert.assertEquals( + "Accumulated results", expectedResult, accumulator.largestContinuousRange()); + + Assert.assertEquals("Number of ranges", expectedNumberOfRanges, accumulator.numberOfRanges()); + } + + @Test + public void testEmptyMerge() { + Event[] set1 = new Event[] {}; + Event[] set2 = new Event[] {}; + + ContiguousSequenceRange expectedResult = ContiguousSequenceRange.EMPTY; + int expectedNumberOfRanges = 0; + + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + @Test + public void testMergingNonEmptyWithEmpty() { + Event[] set1 = + new Event[] { + new Event(3, nextTimestamp()), + new Event(2, nextTimestamp()), + new Event(1, nextTimestamp(), true) + }; + Event[] set2 = new Event[] {}; + + ContiguousSequenceRange expectedResult = + ContiguousSequenceRange.of(1, 4, eventTimestamp(set1, 1L)); + int expectedNumberOfRanges = 1; + + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + @Test + public void testMergingWithLowerNonAdjacentRange() { + Event[] set1 = + new Event[] { + new Event(1, nextTimestamp(), true), new Event(2, nextTimestamp()), + }; + Event[] set2 = + new Event[] { + new Event(4, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()) + }; + + ContiguousSequenceRange expectedResult = + ContiguousSequenceRange.of(1, 3, eventTimestamp(set1, 2L)); + int expectedNumberOfRanges = 2; + + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + 
@Test + public void testMergingWithoutAnyInitialEvents() { + Event[] set1 = + new Event[] { + new Event(1, nextTimestamp()), new Event(2, nextTimestamp()), + }; + Event[] set2 = + new Event[] { + new Event(4, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()) + }; + + ContiguousSequenceRange expectedResult = ContiguousSequenceRange.EMPTY; + int expectedNumberOfRanges = 2; + + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + @Test + public void testMergingAdjacentRanges() { + Event[] set1 = + new Event[] { + new Event(1, nextTimestamp(), true), new Event(2, nextTimestamp()), + }; + Event[] set2 = + new Event[] { + new Event(3, nextTimestamp()), + new Event(4, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()) + }; + + ContiguousSequenceRange expectedResult = + ContiguousSequenceRange.of(1, 7, eventTimestamp(set2, 6L)); + int expectedNumberOfRanges = 1; + + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + @Test + public void testPruningSequencesBelowInitial() { + Event[] set1 = + new Event[] { + new Event(1, nextTimestamp()), new Event(2, nextTimestamp()), + }; + Event[] set2 = + new Event[] { + new Event(3, nextTimestamp(), true), + new Event(4, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()) + }; + + ContiguousSequenceRange expectedResult = + ContiguousSequenceRange.of(3, 7, eventTimestamp(set2, 6L)); + int expectedNumberOfRanges = 1; + + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + @Test + public void testDuplicateHandling() { + Event[] set1 = + new Event[] { + new Event(1, nextTimestamp(), true), + new Event(2, nextTimestamp()), + new Event(3, nextTimestamp()), + new Event(5, nextTimestamp()), + }; + Event[] set2 = + new Event[] { + new Event(3, nextTimestamp()), + new Event(4, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()) + }; + + 
ContiguousSequenceRange expectedResult = + ContiguousSequenceRange.of(1, 7, eventTimestamp(set2, 6L)); + int expectedNumberOfRanges = 1; + + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + @Test + public void testExceptionThrownIfThereAreDifferentInitialSequences() { + Event[] set1 = + new Event[] { + new Event(1, nextTimestamp(), true), new Event(2, nextTimestamp()), + }; + Event[] set2 = + new Event[] { + new Event(3, nextTimestamp(), true), + new Event(4, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()) + }; + + try { + doTestMerging(set1, set2, ContiguousSequenceRange.EMPTY, 0); + Assert.fail("Expected to throw an exception"); + } catch (IllegalStateException e) { + Assert.assertEquals( + "Exception message", + "Two accumulators contain different initial sequences: 1 and 3", + e.getMessage()); + } + } + + @Test + public void testSelectingHighestTimestampWhenMerging() { + Event[] set1 = + new Event[] { + new Event(1, nextTimestamp(), true), + new Event(2, Instant.ofEpochMilli(currentTicker.get() + 10000)), + }; + Event[] set2 = + new Event[] { + new Event(3, nextTimestamp()), + new Event(4, nextTimestamp()), + new Event(5, nextTimestamp()), + new Event(6, nextTimestamp()) + }; + + ContiguousSequenceRange expectedResult = + ContiguousSequenceRange.of(1, 7, eventTimestamp(set1, 2L)); + int expectedNumberOfRanges = 1; + doTestMerging(set1, set2, expectedResult, expectedNumberOfRanges); + } + + private static void doTestMerging( + Event[] set1, + Event[] set2, + ContiguousSequenceRange expectedResult, + int expectedNumberOfRanges) { + // Try to merge both set2 to set1 and set1 to set2 - both must return the same results + mergeAndTest(set1, set2, expectedResult, expectedNumberOfRanges, "set1"); + mergeAndTest(set2, set1, expectedResult, expectedNumberOfRanges, "set2"); + } + + private static void mergeAndTest( + Event[] set1, + Event[] set2, + ContiguousSequenceRange expectedResult, + int 
expectedNumberOfRanges, + String firstSetName) { + final SequenceRangeAccumulator a1 = new SequenceRangeAccumulator(); + Arrays.stream(set1).forEach(e -> a1.add(e.sequence, e.timestamp, e.initialEvent)); + + final SequenceRangeAccumulator a2 = new SequenceRangeAccumulator(); + Arrays.stream(set2).forEach(e -> a2.add(e.sequence, e.timestamp, e.initialEvent)); + + a1.merge(a2); + + Assert.assertEquals( + "Accumulated results - " + firstSetName, expectedResult, a1.largestContinuousRange()); + + Assert.assertEquals( + "Number of ranges - " + firstSetName, expectedNumberOfRanges, a1.numberOfRanges()); + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index d159e9de44a8..9fe6162ec936 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -104,16 +104,14 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "rawtypes", // 
TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) class ProtoByteBuddyUtils { private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); private static final TypeDescriptor BYTE_STRING_TYPE_DESCRIPTOR = @@ -270,7 +268,7 @@ static class ProtoConvertType extends ConvertType { .build(); @Override - public Type convert(TypeDescriptor typeDescriptor) { + public Type convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.equals(BYTE_STRING_TYPE_DESCRIPTOR) || typeDescriptor.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return byte[].class; @@ -297,7 +295,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.equals(BYTE_STRING_TYPE_DESCRIPTOR) || type.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return new Compound( @@ -372,7 +370,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.isSubtypeOf(TypeDescriptor.of(ByteString.class))) { return new Compound( readValue, @@ -459,7 +457,7 @@ public TypeConversion createSetterConversions(StackManipulati // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = + private static final Map>> CACHED_GETTERS = Maps.newConcurrentMap(); /** @@ -467,35 +465,36 @@ public TypeConversion createSetterConversions(StackManipulati * *

    The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - Class clazz, + public static List> getGetters( + Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { Multimap methods = ReflectUtils.getMethodsMap(clazz); - return CACHED_GETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), - c -> { - List types = - fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); - return types.stream() - .map( - t -> - createGetter( - t, - typeConversionsFactory, - clazz, - methods, - schema.getField(t.getName()), - fieldValueTypeSupplier)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + ClassWithSchema.create(clazz, schema), + c -> { + List types = + fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); + return types.stream() + .map( + t -> + createGetter( + t, + typeConversionsFactory, + clazz, + methods, + schema.getField(t.getName()), + fieldValueTypeSupplier)) + .collect(Collectors.toList()); + }); } - static FieldValueGetter createOneOfGetter( + static FieldValueGetter<@NonNull ProtoT, OneOfType.Value> createOneOfGetter( FieldValueTypeInformation typeInformation, - TreeMap> getterMethodMap, - Class protoClass, + TreeMap> getterMethodMap, + Class protoClass, OneOfType oneOfType, Method getCaseMethod) { Set indices = getterMethodMap.keySet(); @@ -505,7 +504,7 @@ static FieldValueGetter createOneOfGetter( int[] keys = getterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface(BYTE_BUDDY, protoClass, OneOfType.Value.class); builder = builder @@ -514,7 +513,8 @@ static FieldValueGetter createOneOfGetter( .method(ElementMatchers.named("get")) .intercept(new OneOfGetterInstruction(contiguous, keys, getCaseMethod)); - List getters = Lists.newArrayList(getterMethodMap.values()); 
+ List> getters = + Lists.newArrayList(getterMethodMap.values()); builder = builder // Store a field with the list of individual getters. The get() instruction will pick @@ -556,12 +556,12 @@ static FieldValueGetter createOneOfGetter( FieldValueSetter createOneOfSetter( String name, TreeMap> setterMethodMap, - Class protoBuilderClass) { + Class protoBuilderClass) { Set indices = setterMethodMap.keySet(); boolean contiguous = isContiguous(indices); int[] keys = setterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, protoBuilderClass, OneOfType.Value.class); builder = @@ -585,7 +585,8 @@ FieldValueSetter createOneOfSetter( .withParameters(List.class) .intercept(new OneOfSetterConstructor()); - List setters = Lists.newArrayList(setterMethodMap.values()); + List> setters = + Lists.newArrayList(setterMethodMap.values()); try { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) @@ -947,10 +948,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { } } - private static FieldValueGetter createGetter( + private static FieldValueGetter<@NonNull ProtoT, ?> createGetter( FieldValueTypeInformation fieldValueTypeInformation, TypeConversionsFactory typeConversionsFactory, - Class clazz, + Class clazz, Multimap methods, Field field, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -964,21 +965,23 @@ private static FieldValueGetter createGetter( field.getName() + "_case", FieldType.logicalType(oneOfType.getCaseEnumType())); // Create a map of case enum value to getter. This must be sorted, so store in a TreeMap. 
- TreeMap> oneOfGetters = Maps.newTreeMap(); + TreeMap> oneOfGetters = + Maps.newTreeMap(); Map oneOfFieldTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), oneOfType.getOneOfSchema()).stream() .collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f)); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { int protoFieldIndex = getFieldNumber(oneOfField); - FieldValueGetter oneOfFieldGetter = + FieldValueGetter<@NonNull ProtoT, ?> oneOfFieldGetter = createGetter( - oneOfFieldTypes.get(oneOfField.getName()), + Verify.verifyNotNull(oneOfFieldTypes.get(oneOfField.getName())), typeConversionsFactory, clazz, methods, oneOfField, fieldValueTypeSupplier); - oneOfGetters.put(protoFieldIndex, oneOfFieldGetter); + oneOfGetters.put( + protoFieldIndex, (FieldValueGetter<@NonNull ProtoT, OneOfType.Value>) oneOfFieldGetter); } return createOneOfGetter( fieldValueTypeInformation, oneOfGetters, clazz, oneOfType, caseMethod); @@ -987,10 +990,11 @@ private static FieldValueGetter createGetter( } } - private static Class getProtoGeneratedBuilder(Class clazz) { + private static @Nullable Class getProtoGeneratedBuilder( + Class clazz) { String builderClassName = clazz.getName() + "$Builder"; try { - return Class.forName(builderClassName); + return (Class) Class.forName(builderClassName); } catch (ClassNotFoundException e) { return null; } @@ -1018,51 +1022,59 @@ static Method getProtoGetter(Multimap methods, String name, Fiel public static @Nullable SchemaUserTypeCreator getBuilderCreator( - Class protoClass, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { - Class builderClass = getProtoGeneratedBuilder(protoClass); + TypeDescriptor protoTypeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { + Class builderClass = getProtoGeneratedBuilder(protoTypeDescriptor.getRawType()); if (builderClass == null) { return null; } Multimap methods = ReflectUtils.getMethodsMap(builderClass); List> setters = 
schema.getFields().stream() - .map(f -> getProtoFieldValueSetter(f, methods, builderClass)) + .map(f -> getProtoFieldValueSetter(protoTypeDescriptor, f, methods, builderClass)) .collect(Collectors.toList()); - return createBuilderCreator(protoClass, builderClass, setters, schema); + return createBuilderCreator(protoTypeDescriptor.getRawType(), builderClass, setters, schema); } private static FieldValueSetter getProtoFieldValueSetter( - Field field, Multimap methods, Class builderClass) { + TypeDescriptor typeDescriptor, + Field field, + Multimap methods, + Class builderClass) { if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); TreeMap> oneOfSetters = Maps.newTreeMap(); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { - FieldValueSetter setter = getProtoFieldValueSetter(oneOfField, methods, builderClass); + FieldValueSetter setter = + getProtoFieldValueSetter(typeDescriptor, oneOfField, methods, builderClass); oneOfSetters.put(getFieldNumber(oneOfField), setter); } return createOneOfSetter(field.getName(), oneOfSetters, builderClass); } else { Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( - FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), + FieldValueTypeInformation.forSetter( + typeDescriptor, method, protoSetterPrefix(field.getType())), new ProtoTypeConversionsFactory()); } } static SchemaUserTypeCreator createBuilderCreator( Class protoClass, - Class builderClass, + Class builderClass, List> setters, Schema schema) { try { - DynamicType.Builder builder = - BYTE_BUDDY - .with(new InjectPackageStrategy(builderClass)) - .subclass(Supplier.class) - .method(ElementMatchers.named("get")) - .intercept(new BuilderSupplier(protoClass)); - Supplier supplier = + DynamicType.Builder> builder = + (DynamicType.Builder) + BYTE_BUDDY + .with(new InjectPackageStrategy(builderClass)) + 
.subclass(Supplier.class) + .method(ElementMatchers.named("get")) + .intercept(new BuilderSupplier(protoClass)); + Supplier supplier = builder .visit( new AsmVisitorWrapper.ForDeclaredMethods() diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index faf3ad407af5..b0bb9071524b 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -43,12 +43,9 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) public class ProtoMessageSchema extends GetterBasedSchemaProviderV2 { private static final class ProtoClassFieldValueTypeSupplier implements FieldValueTypeSupplier { @@ -72,7 +69,8 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. types.add( @@ -82,7 +80,9 @@ public List get(TypeDescriptor typeDescriptor, Sch } else { // This is a simple field. Add the getter. 
Method method = getProtoGetter(methods, field.getName(), field.getType()); - types.add(FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + types.add( + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } } return types; @@ -96,8 +96,8 @@ public List get(TypeDescriptor typeDescriptor, Sch } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return ProtoByteBuddyUtils.getGetters( targetTypeDescriptor.getRawType(), schema, @@ -117,7 +117,7 @@ public SchemaUserTypeCreator schemaTypeCreator( TypeDescriptor targetTypeDescriptor, Schema schema) { SchemaUserTypeCreator creator = ProtoByteBuddyUtils.getBuilderCreator( - targetTypeDescriptor.getRawType(), schema, new ProtoClassFieldValueTypeSupplier()); + targetTypeDescriptor, schema, new ProtoClassFieldValueTypeSupplier()); if (creator == null) { throw new RuntimeException("Cannot create creator for " + targetTypeDescriptor); } @@ -152,7 +152,8 @@ public static SimpleFunction getRowToProtoBytesFn(Class claz private void checkForDynamicType(TypeDescriptor typeDescriptor) { if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { throw new RuntimeException( - "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoDynamicMessageSchema instead."); + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use" + + " ProtoDynamicMessageSchema instead."); } } diff --git a/sdks/java/extensions/sql/expansion-service/build.gradle b/sdks/java/extensions/sql/expansion-service/build.gradle index b6963cf7547b..b8d78e4e1bb9 100644 --- a/sdks/java/extensions/sql/expansion-service/build.gradle +++ b/sdks/java/extensions/sql/expansion-service/build.gradle @@ -46,3 +46,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = 
[project.findProperty("constructionService.port") ?: "8097"] } + +shadowJar { + outputs.upToDateWhen { false } +} \ No newline at end of file diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java index 372e77c54c67..d0f6427a262e 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java @@ -36,7 +36,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.MockConsumer; @@ -138,10 +137,6 @@ public synchronized void assign(final Collection assigned) { .collect(Collectors.toList()); super.assign(realPartitions); assignedPartitions.set(ImmutableList.copyOf(realPartitions)); - for (TopicPartition tp : realPartitions) { - updateBeginningOffsets(ImmutableMap.of(tp, 0L)); - updateEndOffsets(ImmutableMap.of(tp, (long) kafkaRecords.get(tp).size())); - } } // Override offsetsForTimes() in order to look up the offsets by timestamp. 
@Override @@ -163,9 +158,12 @@ public synchronized Map offsetsForTimes( } }; - for (String topic : getTopics()) { - consumer.updatePartitions(topic, partitionInfoMap.get(topic)); - } + partitionInfoMap.forEach(consumer::updatePartitions); + consumer.updateBeginningOffsets( + kafkaRecords.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L))); + consumer.updateEndOffsets( + kafkaRecords.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()))); Runnable recordEnqueueTask = new Runnable() { diff --git a/sdks/java/extensions/sql/zetasql/build.gradle b/sdks/java/extensions/sql/zetasql/build.gradle index 8d6e2aac0bf4..29a3f95402b0 100644 --- a/sdks/java/extensions/sql/zetasql/build.gradle +++ b/sdks/java/extensions/sql/zetasql/build.gradle @@ -27,7 +27,7 @@ applyJavaNature( description = "Apache Beam :: SDKs :: Java :: Extensions :: SQL :: ZetaSQL" ext.summary = "ZetaSQL to Calcite translator" -def zetasql_version = "2022.04.1" +def zetasql_version = "2024.11.1" dependencies { // TODO(https://github.com/apache/beam/issues/21156): Determine how to build without this dependency diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index 2dfc7fe372f5..412cd46001f8 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -29,6 +29,7 @@ import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn; +import 
com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumnBase; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr; import java.util.ArrayList; import java.util.Arrays; @@ -94,7 +95,7 @@ public RelNode convert(ResolvedAggregateScan zetaNode, List inputs) { aggregateCalls = new ArrayList<>(); // For aggregate calls, their input ref follow after GROUP BY input ref. int columnRefoff = groupFieldsListSize; - for (ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) { + for (ResolvedComputedColumnBase computedColumn : zetaNode.getAggregateList()) { AggregateCall aggCall = convertAggCall(computedColumn, columnRefoff, groupSet.size(), input); aggregateCalls.add(aggCall); @@ -144,7 +145,7 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( // LogicalProject should also include columns used by aggregate functions. These columns should // follow after GROUP BY columns. // TODO: remove duplicate columns in projects. - for (ResolvedComputedColumn resolvedComputedColumn : node.getAggregateList()) { + for (ResolvedComputedColumnBase resolvedComputedColumn : node.getAggregateList()) { // Should create Calcite's RexInputRef from ResolvedColumn from ResolvedComputedColumn. 
// TODO: handle aggregate function with more than one argument and handle OVER // TODO: is there is general way for column reference tracking and deduplication for @@ -180,7 +181,7 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( } private AggregateCall convertAggCall( - ResolvedComputedColumn computedColumn, int columnRefOff, int groupCount, RelNode input) { + ResolvedComputedColumnBase computedColumn, int columnRefOff, int groupCount, RelNode input) { ResolvedAggregateFunctionCall aggregateFunctionCall = (ResolvedAggregateFunctionCall) computedColumn.getExpr(); diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java index 14db554d6f0b..0f32451504b3 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java @@ -37,6 +37,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.zetasql.TVFRelation; +import com.google.zetasql.TVFRelation.Column; import com.google.zetasql.TableValuedFunction; import com.google.zetasql.TableValuedFunction.FixedOutputSchemaTVF; import com.google.zetasql.Type; @@ -65,6 +66,7 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.reflect.FieldUtils; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters; import org.apache.beam.sdk.extensions.sql.impl.ZetaSqlUserDefinedSQLNativeTableValuedFunction; @@ -495,9 +497,16 @@ public RexCall convertTableValuedFunction( new 
ZetaSqlUserDefinedSQLNativeTableValuedFunction( new SqlIdentifier(tvf.getName(), SqlParserPos.ZERO), opBinding -> { + TVFRelation rel = fixedOutputSchemaTVF.getOutputSchema(); + // TODO(yathu) revert this workaround when ZetaSQL adds back this API. + List cols; + try { + cols = (List) FieldUtils.readField(rel, "columns", true); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } List relDataTypeFields = - convertTVFRelationColumnsToRelDataTypeFields( - fixedOutputSchemaTVF.getOutputSchema().getColumns()); + convertTVFRelationColumnsToRelDataTypeFields(cols); return new RelRecordType(relDataTypeFields); }, null, diff --git a/sdks/java/harness/build.gradle b/sdks/java/harness/build.gradle index 2de578cb32cf..ed3034c08612 100644 --- a/sdks/java/harness/build.gradle +++ b/sdks/java/harness/build.gradle @@ -30,6 +30,7 @@ dependencies { provided project(path: ":model:pipeline", configuration: "shadow") provided project(path: ":sdks:java:core", configuration: "shadow") provided project(path: ":sdks:java:transform-service:launcher") + provided library.java.google_api_services_dataflow provided library.java.avro provided library.java.jackson_databind provided library.java.joda_time @@ -79,4 +80,5 @@ dependencies { shadowTest project(path: ":sdks:java:core", configuration: "shadowTest") shadowTestRuntimeClasspath library.java.slf4j_jdk14 permitUnusedDeclared library.java.avro + permitUnusedDeclared library.java.google_api_services_dataflow } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 0d520dcf7f5c..300796ac6f12 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -64,7 +64,6 @@ import 
org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; -import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.Coder; @@ -84,6 +83,7 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.SdkHarnessOptions; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.common.ReflectHelpers; @@ -93,6 +93,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; @@ -108,20 +109,19 @@ import org.slf4j.LoggerFactory; /** - * Processes {@link BeamFnApi.ProcessBundleRequest}s and {@link - * BeamFnApi.ProcessBundleSplitRequest}s. + * Processes {@link ProcessBundleRequest}s and {@link BeamFnApi.ProcessBundleSplitRequest}s. * *

    {@link BeamFnApi.ProcessBundleSplitRequest}s use a {@link BundleProcessorCache cache} to * find/create a {@link BundleProcessor}. The creation of a {@link BundleProcessor} uses the - * associated {@link BeamFnApi.ProcessBundleDescriptor} definition; creating runners for each {@link + * associated {@link ProcessBundleDescriptor} definition; creating runners for each {@link * RunnerApi.FunctionSpec}; wiring them together based upon the {@code input} and {@code output} map * definitions. The {@link BundleProcessor} executes the DAG based graph by starting all runners in * reverse topological order, and finishing all runners in forward topological order. * *

    {@link BeamFnApi.ProcessBundleSplitRequest}s finds an {@code active} {@link BundleProcessor} - * associated with a currently processing {@link BeamFnApi.ProcessBundleRequest} and uses it to - * perform a split request. See breaking the - * fusion barrier for further details. + * associated with a currently processing {@link ProcessBundleRequest} and uses it to perform a + * split request. See breaking the fusion + * barrier for further details. */ @SuppressWarnings({ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) @@ -153,7 +153,7 @@ public class ProcessBundleHandler { } private final PipelineOptions options; - private final Function fnApiRegistry; + private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; private final FinalizeBundleHandler finalizeBundleHandler; @@ -170,7 +170,7 @@ public class ProcessBundleHandler { public ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -189,7 +189,8 @@ public ProcessBundleHandler( executionStateSampler, REGISTERED_RUNNER_FACTORIES, processWideCache, - new BundleProcessorCache(), + new BundleProcessorCache( + options.as(SdkHarnessOptions.class).getBundleProcessorCacheTimeout()), dataSampler); } @@ -197,7 +198,7 @@ public ProcessBundleHandler( ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -216,7 +217,7 @@ public ProcessBundleHandler( this.runnerCapabilities = runnerCapabilities; this.runnerAcceptsShortIds = runnerCapabilities.contains( - 
BeamUrns.getUrn(RunnerApi.StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); + BeamUrns.getUrn(StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); this.executionStateSampler = executionStateSampler; this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = @@ -232,7 +233,7 @@ private void createRunnerAndConsumersForPTransformRecursively( String pTransformId, PTransform pTransform, Supplier processBundleInstructionId, - Supplier> cacheTokens, + Supplier> cacheTokens, Supplier> bundleCache, ProcessBundleDescriptor processBundleDescriptor, SetMultimap pCollectionIdsToConsumingPTransforms, @@ -242,7 +243,7 @@ private void createRunnerAndConsumersForPTransformRecursively( PTransformFunctionRegistry finishFunctionRegistry, Consumer addResetFunction, Consumer addTearDownFunction, - BiConsumer> addDataEndpoint, + BiConsumer> addDataEndpoint, Consumer> addTimerEndpoint, Consumer addBundleProgressReporter, BundleSplitListener splitListener, @@ -499,28 +500,29 @@ public BundleFinalizer getBundleFinalizer() { * Processes a bundle, running the start(), process(), and finish() functions. This function is * required to be reentrant. 
*/ - public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest request) throws Exception { - BeamFnApi.ProcessBundleResponse.Builder response = BeamFnApi.ProcessBundleResponse.newBuilder(); - - BundleProcessor bundleProcessor = - bundleProcessorCache.get( - request, - () -> { - try { - return createBundleProcessor( - request.getProcessBundle().getProcessBundleDescriptorId(), - request.getProcessBundle()); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); + @Nullable BundleProcessor bundleProcessor = null; try { + bundleProcessor = + Preconditions.checkNotNull( + bundleProcessorCache.get( + request, + () -> { + try { + return createBundleProcessor( + request.getProcessBundle().getProcessBundleDescriptorId(), + request.getProcessBundle()); + } catch (IOException e) { + throw new RuntimeException(e); + } + })); + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry(); ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); - + ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { stateTracker.start(request.getInstructionId()); try { @@ -596,8 +598,17 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction request.getProcessBundle().getProcessBundleDescriptorId(), bundleProcessor); return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response); } catch (Exception e) { - // Make sure we clean-up from the active set of bundle processors. 
- bundleProcessorCache.discard(bundleProcessor); + LOG.debug( + "Error processing bundle {} with bundleProcessor for {} after exception: {}", + request.getInstructionId(), + request.getProcessBundle().getProcessBundleDescriptorId(), + e.getMessage()); + if (bundleProcessor != null) { + // Make sure we clean up from the active set of bundle processors. + bundleProcessorCache.discard(bundleProcessor); + } + // Ensure that if more data arrives for the instruction it is discarded. + beamFnDataClient.poisonInstructionId(request.getInstructionId()); throw e; } } @@ -639,7 +650,7 @@ private void embedOutboundElementsIfApplicable( } } - public BeamFnApi.InstructionResponse.Builder progress(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder progress(InstructionRequest request) throws Exception { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleProgress().getInstructionId()); @@ -723,7 +734,7 @@ private Map finalMonitoringData(BundleProcessor bundleProces } /** Splits an active bundle. 
*/ - public BeamFnApi.InstructionResponse.Builder trySplit(BeamFnApi.InstructionRequest request) { + public BeamFnApi.InstructionResponse.Builder trySplit(InstructionRequest request) { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleSplit().getInstructionId()); BeamFnApi.ProcessBundleSplitResponse.Builder response = @@ -768,8 +779,8 @@ public void discard() { } private BundleProcessor createBundleProcessor( - String bundleId, BeamFnApi.ProcessBundleRequest processBundleRequest) throws IOException { - BeamFnApi.ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); + String bundleId, ProcessBundleRequest processBundleRequest) throws IOException { + ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); SetMultimap pCollectionIdsToConsumingPTransforms = HashMultimap.create(); BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar = @@ -795,8 +806,7 @@ private BundleProcessor createBundleProcessor( List tearDownFunctions = new ArrayList<>(); // Build a multimap of PCollection ids to PTransform ids which consume said PCollections - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { for (String pCollectionId : entry.getValue().getInputsMap().values()) { pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey()); } @@ -844,8 +854,7 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) { runnerCapabilities); // Create a BeamFnStateClient - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { // Skip anything which isn't a root. // Also force data output transforms to be unconditionally instantiated (see BEAM-10450). 
@@ -920,25 +929,25 @@ public int hashCode() { return super.hashCode(); } - BundleProcessorCache() { - this.cachedBundleProcessors = + BundleProcessorCache(Duration timeout) { + CacheBuilder> builder = CacheBuilder.newBuilder() - .expireAfterAccess(Duration.ofMinutes(1L)) .removalListener( - removalNotification -> { - ((ConcurrentLinkedQueue) removalNotification.getValue()) - .forEach( - bundleProcessor -> { - bundleProcessor.shutdown(); - }); - }) - .build( - new CacheLoader>() { - @Override - public ConcurrentLinkedQueue load(String s) throws Exception { - return new ConcurrentLinkedQueue<>(); - } - }); + removalNotification -> + removalNotification + .getValue() + .forEach(bundleProcessor -> bundleProcessor.shutdown())); + if (timeout.compareTo(Duration.ZERO) > 0) { + builder = builder.expireAfterAccess(timeout); + } + this.cachedBundleProcessors = + builder.build( + new CacheLoader>() { + @Override + public ConcurrentLinkedQueue load(String s) throws Exception { + return new ConcurrentLinkedQueue<>(); + } + }); // We specifically use a weak hash map so that references will automatically go out of scope // and not need to be freed explicitly from the cache. this.activeBundleProcessors = Collections.synchronizedMap(new WeakHashMap<>()); @@ -1086,7 +1095,7 @@ public static BundleProcessor create( abstract HandleStateCallsForBundle getBeamFnStateClient(); - abstract List getInboundEndpointApiServiceDescriptors(); + abstract List getInboundEndpointApiServiceDescriptors(); abstract List> getInboundDataEndpoints(); @@ -1113,7 +1122,7 @@ synchronized List getCacheTokens() { synchronized Cache getBundleCache() { if (this.bundleCache == null) { this.bundleCache = - new Caches.ClearableCache<>( + new ClearableCache<>( Caches.subCache(getProcessWideCache(), "Bundle", this.instructionId)); } return this.bundleCache; @@ -1168,6 +1177,18 @@ void discard() { if (this.bundleCache != null) { this.bundleCache.clear(); } + // setupFunctions are invoked in createBundleProcessor. 
Invoke teardownFunction here as the + // BundleProcessor is already removed from cache and won't be re-used. + for (ThrowingRunnable teardownFunction : Lists.reverse(this.getTearDownFunctions())) { + try { + teardownFunction.run(); + } catch (Throwable e) { + LOG.warn( + "Exceptions are thrown from DoFn.teardown method when trying to discard " + + "ProcessBundleHandler", + e); + } + } getMetricsEnvironmentStateForBundle().discard(); for (BeamFnDataOutboundAggregator aggregator : getOutboundAggregators().values()) { aggregator.discard(); @@ -1175,6 +1196,7 @@ void discard() { } } + // this is called in cachedBundleProcessors removal listener void shutdown() { for (ThrowingRunnable tearDownFunction : getTearDownFunctions()) { LOG.debug("Tearing down function {}", tearDownFunction); @@ -1247,7 +1269,7 @@ public void close() throws Exception { } @Override - public CompletableFuture handle(BeamFnApi.StateRequest.Builder requestBuilder) { + public CompletableFuture handle(StateRequest.Builder requestBuilder) { throw new IllegalStateException( String.format( "State API calls are unsupported because the " diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java index 75f3a24301c9..94d59d0fcb62 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java @@ -55,10 +55,19 @@ void registerReceiver( * successfully. * *

    It is expected that if a bundle fails during processing then the failure will become visible - * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation. + * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation or via + * a call to {@link #poisonInstructionId}. */ void unregisterReceiver(String instructionId, List apiServiceDescriptors); + /** + * Poisons the instruction id, indicating that future data arriving for it should be discarded. + * Unregisters the receiver if it was registered. + * + * @param instructionId the id of the instruction to poison + */ + void poisonInstructionId(String instructionId); + /** * Creates a {@link BeamFnDataOutboundAggregator} for buffering and sending outbound data and * timers over the data plane. It is important that {@link diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 981b115c58e7..cd1ac26e364d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -82,6 +82,14 @@ public void unregisterReceiver( } } + @Override + public void poisonInstructionId(String instructionId) { + LOG.debug("Poisoning instruction {}", instructionId); + for (BeamFnDataGrpcMultiplexer client : multiplexerCache.values()) { + client.poisonInstructionId(instructionId); + } + } + @Override public BeamFnDataOutboundAggregator createOutboundAggregator( ApiServiceDescriptor apiServiceDescriptor, diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 3b9fccfa2a5e..81a2aa6d1cc6 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -105,7 +105,7 @@ public long getWeight() { // many different state subcaches. return 0; } - }; + } /** A mutable iterable that supports prefetch and is backed by a cache. */ static class CachingStateIterable extends PrefetchableIterables.Default { @@ -138,8 +138,8 @@ public long getWeight() { private static long sumWeight(List> blocks) { try { long sum = 0; - for (int i = 0; i < blocks.size(); ++i) { - sum = Math.addExact(sum, blocks.get(i).getWeight()); + for (Block block : blocks) { + sum = Math.addExact(sum, block.getWeight()); } return sum; } catch (ArithmeticException e) { @@ -437,50 +437,59 @@ public boolean hasNext() { if (currentBlock.getValues().size() > currentCachedBlockValueIndex) { return true; } - if (currentBlock.getNextToken() == null) { + final ByteString nextToken = currentBlock.getNextToken(); + if (nextToken == null) { return false; } - Blocks existing = cache.peek(IterableCacheKey.INSTANCE); - boolean isFirstBlock = ByteString.EMPTY.equals(currentBlock.getNextToken()); + // Release the block while we are loading the next one. 
+ currentBlock = + Block.fromValues(new WeightedList<>(Collections.emptyList(), 0L), ByteString.EMPTY); + + @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); + boolean isFirstBlock = ByteString.EMPTY.equals(nextToken); if (existing == null) { - currentBlock = loadNextBlock(currentBlock.getNextToken()); + currentBlock = loadNextBlock(nextToken); if (isFirstBlock) { cache.put( IterableCacheKey.INSTANCE, new BlocksPrefix<>(Collections.singletonList(currentBlock))); } + } else if (isFirstBlock) { + currentBlock = existing.getBlocks().get(0); } else { - if (isFirstBlock) { - currentBlock = existing.getBlocks().get(0); - } else { - checkState( - existing instanceof BlocksPrefix, - "Unexpected blocks type %s, expected a %s.", - existing.getClass(), - BlocksPrefix.class); - List> blocks = existing.getBlocks(); - int currentBlockIndex = 0; - for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { - if (currentBlock - .getNextToken() - .equals(blocks.get(currentBlockIndex).getNextToken())) { - break; - } + checkState( + existing instanceof BlocksPrefix, + "Unexpected blocks type %s, expected a %s.", + existing.getClass(), + BlocksPrefix.class); + List> blocks = existing.getBlocks(); + int currentBlockIndex = 0; + for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { + if (nextToken.equals(blocks.get(currentBlockIndex).getNextToken())) { + break; } - // Load the next block from cache if it was found. - if (currentBlockIndex + 1 < blocks.size()) { - currentBlock = blocks.get(currentBlockIndex + 1); - } else { - // Otherwise load the block from state API. - currentBlock = loadNextBlock(currentBlock.getNextToken()); - - // Append this block to the existing set of blocks if it is logically the next one. 
- if (currentBlockIndex == blocks.size() - 1) { - List> newBlocks = new ArrayList<>(currentBlockIndex + 1); - newBlocks.addAll(blocks); - newBlocks.add(currentBlock); - cache.put(IterableCacheKey.INSTANCE, new BlocksPrefix<>(newBlocks)); - } + } + // Take the next block from the cache if it was found. + if (currentBlockIndex + 1 < blocks.size()) { + currentBlock = blocks.get(currentBlockIndex + 1); + } else { + // Otherwise load the block from state API. + // Remove references on the cached values while we are loading the next block. + existing = null; + blocks = null; + currentBlock = loadNextBlock(nextToken); + existing = cache.peek(IterableCacheKey.INSTANCE); + // Append this block to the existing set of blocks if it is logically the next one + // according to the + // tokens. + if (existing != null + && !existing.getBlocks().isEmpty() + && nextToken.equals( + existing.getBlocks().get(existing.getBlocks().size() - 1).getNextToken())) { + List> newBlocks = new ArrayList<>(currentBlockIndex + 1); + newBlocks.addAll(existing.getBlocks()); + newBlocks.add(currentBlock); + cache.put(IterableCacheKey.INSTANCE, new BlocksPrefix<>(newBlocks)); } } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java index 9328dc86c009..acfd3bb70202 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java @@ -92,6 +92,11 @@ public BeamFnDataOutboundAggregator createOutboundAggregator( boolean collectElementsIfNoFlushes) { throw new UnsupportedOperationException("Unexpected call during test."); } + + @Override + public void poisonInstructionId(String instructionId) { + throw new UnsupportedOperationException("Unexpected call during test."); + } }) .beamFnStateClient( new 
BeamFnStateClient() { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 2d1e323707f7..a69ea5338dc3 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -48,6 +48,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -354,6 +355,10 @@ void reset() throws Exception { private static class TestBundleProcessorCache extends BundleProcessorCache { + TestBundleProcessorCache() { + super(Duration.ZERO); + } + @Override BundleProcessor get( InstructionRequest processBundleRequest, @@ -376,7 +381,7 @@ public void testTrySplitBeforeBundleDoesNotFail() { executionStateSampler, ImmutableMap.of(), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); BeamFnApi.InstructionResponse response = @@ -407,7 +412,7 @@ public void testProgressBeforeBundleDoesNotFail() throws Exception { executionStateSampler, ImmutableMap.of(), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.progress( @@ -487,7 +492,7 @@ public void testOrderOfStartAndFinishCalls() throws Exception { DATA_INPUT_URN, startFinishRecorder, DATA_OUTPUT_URN, startFinishRecorder), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.processBundle( @@ -592,7 +597,7 @@ public void testOrderOfSetupTeardownCalls() throws Exception { executionStateSampler, urnToPTransformRunnerFactoryMap, Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null 
/* dataSampler */); handler.processBundle( @@ -699,7 +704,7 @@ private static InstructionRequest processBundleRequestFor( public void testBundleProcessorIsFoundWhenActive() { BundleProcessor bundleProcessor = mock(BundleProcessor.class); when(bundleProcessor.getInstructionId()).thenReturn("known"); - BundleProcessorCache cache = new BundleProcessorCache(); + BundleProcessorCache cache = new BundleProcessorCache(Duration.ZERO); // Check that an unknown bundle processor is not found assertNull(cache.find("unknown")); @@ -811,7 +816,7 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { throw new IllegalStateException("TestException"); }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -862,7 +867,7 @@ public void testBundleFinalizationIsPropagated() throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); BeamFnApi.InstructionResponse.Builder response = handler.processBundle( @@ -916,7 +921,7 @@ public void testPTransformStartExceptionsArePropagated() { return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -1094,7 +1099,7 @@ public void onCompleted() {} executionStateSampler, urnToPTransformRunnerFactoryMap, Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); } @@ -1427,7 +1432,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -1500,7 +1505,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { return null; }), Caches.noop(), - 
new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -1516,6 +1521,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { // Ensure that we unregister during successful processing verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); + verify(beamFnDataClient).poisonInstructionId(eq("instructionId")); verifyNoMoreInteractions(beamFnDataClient); } @@ -1550,7 +1556,7 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -1646,7 +1652,7 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { } }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -1697,7 +1703,7 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { } }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "State API calls are unsupported", @@ -1786,7 +1792,7 @@ public void reset() { return null; }; - BundleProcessorCache bundleProcessorCache = new BundleProcessorCache(); + BundleProcessorCache bundleProcessorCache = new BundleProcessorCache(Duration.ZERO); ProcessBundleHandler handler = new ProcessBundleHandler( PipelineOptionsFactory.create(), @@ -1929,7 +1935,7 @@ public Object createRunnerForPTransform(Context context) throws IOException { } }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "Timers are unsupported", diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 3489fe766891..514cf61ded40 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -23,14 +23,17 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -281,6 +284,93 @@ public StreamObserver data( } } + @Test + public void testForInboundConsumerThatIsPoisoned() throws Exception { + CountDownLatch waitForClientToConnect = new CountDownLatch(1); + CountDownLatch receivedAElement = new CountDownLatch(1); + Collection> inboundValuesA = new ConcurrentLinkedQueue<>(); + Collection inboundServerValues = new ConcurrentLinkedQueue<>(); + AtomicReference> outboundServerObserver = + new AtomicReference<>(); + CallStreamObserver inboundServerObserver = + TestStreams.withOnNext(inboundServerValues::add).build(); + + Endpoints.ApiServiceDescriptor apiServiceDescriptor = + Endpoints.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID()) + .build(); + Server server = + InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService( + new BeamFnDataGrpc.BeamFnDataImplBase() { + @Override + public StreamObserver data( + StreamObserver outboundObserver) { + 
outboundServerObserver.set(outboundObserver); + waitForClientToConnect.countDown(); + return inboundServerObserver; + } + }) + .build(); + server.start(); + + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnDataGrpcClient clientFactory = + new BeamFnDataGrpcClient( + PipelineOptionsFactory.create(), + (Endpoints.ApiServiceDescriptor descriptor) -> channel, + OutboundObserverFactory.trivial()); + + BeamFnDataInboundObserver observerA = + BeamFnDataInboundObserver.forConsumers( + Arrays.asList( + DataEndpoint.create( + TRANSFORM_ID_A, + CODER, + (WindowedValue elem) -> { + receivedAElement.countDown(); + inboundValuesA.add(elem); + })), + Collections.emptyList()); + CompletableFuture future = + CompletableFuture.runAsync( + () -> { + try { + observerA.awaitCompletion(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + clientFactory.registerReceiver( + INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + + waitForClientToConnect.await(); + outboundServerObserver.get().onNext(ELEMENTS_B_1); + clientFactory.poisonInstructionId(INSTRUCTION_ID_B); + + outboundServerObserver.get().onNext(ELEMENTS_B_1); + outboundServerObserver.get().onNext(ELEMENTS_A_1); + assertTrue(receivedAElement.await(5, TimeUnit.SECONDS)); + + clientFactory.poisonInstructionId(INSTRUCTION_ID_A); + try { + future.get(); + fail(); // We expect the awaitCompletion to fail due to closing. 
+ } catch (Exception ignored) { + } + + outboundServerObserver.get().onNext(ELEMENTS_A_2); + + assertThat(inboundValuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + } finally { + server.shutdownNow(); + } + } + @Test public void testForOutboundConsumer() throws Exception { CountDownLatch waitForInboundServerValuesCompletion = new CountDownLatch(2); diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java index 7ed56efa44bd..75d66c46478a 100644 --- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java +++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/s3/S3FileSystem.java @@ -627,7 +627,17 @@ protected S3ResourceId matchNewResource(String singleResourceSpec, boolean isDir @Override protected void reportLineage(S3ResourceId resourceId, Lineage lineage) { - lineage.add("s3", ImmutableList.of(resourceId.getBucket(), resourceId.getKey())); + reportLineage(resourceId, lineage, LineageLevel.FILE); + } + + @Override + protected void reportLineage(S3ResourceId resourceId, Lineage lineage, LineageLevel level) { + ImmutableList.Builder segments = + ImmutableList.builder().add(resourceId.getBucket()); + if (level != LineageLevel.TOP_LEVEL && !resourceId.getKey().isEmpty()) { + segments.add(resourceId.getKey()); + } + lineage.add("s3", segments.build()); } /** diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java index fbef40f4b5c0..db749d7080e2 100644 --- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java +++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/s3/S3FileSystemTest.java @@ -34,6 +34,7 @@ 
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Matchers.anyObject; import static org.mockito.Matchers.notNull; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -74,6 +75,7 @@ import org.apache.beam.sdk.io.aws.options.S3Options; import org.apache.beam.sdk.io.fs.CreateOptions; import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -1209,6 +1211,21 @@ public void testWriteAndReadWithS3Options() throws IOException { open.close(); } + @Test + public void testReportLineageOnBucket() { + verifyLineage("s3://testbucket", ImmutableList.of("testbucket")); + verifyLineage("s3://testbucket/", ImmutableList.of("testbucket")); + verifyLineage("s3://testbucket/foo/bar.txt", ImmutableList.of("testbucket", "foo/bar.txt")); + } + + private void verifyLineage(String uri, List expected) { + S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("mys3"), client); + S3ResourceId path = S3ResourceId.fromUri(uri); + Lineage mockLineage = mock(Lineage.class); + s3FileSystem.reportLineage(path, mockLineage); + verify(mockLineage, times(1)).add("s3", expected); + } + /** A mockito argument matcher to implement equality on GetObjectMetadataRequest. 
*/ private static class GetObjectMetadataRequestMatcher implements ArgumentMatcher { diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystem.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystem.java index 384c8c627ee7..e851f8333d0b 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystem.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystem.java @@ -658,7 +658,17 @@ protected S3ResourceId matchNewResource(String singleResourceSpec, boolean isDir @Override protected void reportLineage(S3ResourceId resourceId, Lineage lineage) { - lineage.add("s3", ImmutableList.of(resourceId.getBucket(), resourceId.getKey())); + reportLineage(resourceId, lineage, LineageLevel.FILE); + } + + @Override + protected void reportLineage(S3ResourceId resourceId, Lineage lineage, LineageLevel level) { + ImmutableList.Builder segments = + ImmutableList.builder().add(resourceId.getBucket()); + if (level != LineageLevel.TOP_LEVEL && !resourceId.getKey().isEmpty()) { + segments.add(resourceId.getKey()); + } + lineage.add("s3", segments.build()); } /** diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java index acdfcfc1ad09..e8b05a8a319e 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java @@ -19,7 +19,6 @@ import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; -import static org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.getter; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import 
static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets.difference; @@ -46,6 +45,7 @@ import org.apache.beam.sdk.values.RowWithGetters; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; @@ -73,17 +73,20 @@ public class AwsSchemaProvider extends GetterBasedSchemaProviderV2 { return AwsTypes.schemaFor(sdkFields((Class) type.getRawType())); } - @SuppressWarnings("rawtypes") @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { ConverterFactory fromAws = ConverterFactory.fromAws(); Map> sdkFields = sdkFieldsByName((Class) targetTypeDescriptor.getRawType()); - List getters = new ArrayList<>(schema.getFieldCount()); - for (String field : schema.getFieldNames()) { + List> getters = new ArrayList<>(schema.getFieldCount()); + for (@NonNull String field : schema.getFieldNames()) { SdkField sdkField = checkStateNotNull(sdkFields.get(field), "Unknown field"); - getters.add(getter(field, fromAws.create(sdkField::getValueOrDefault, sdkField))); + getters.add( + AwsSchemaUtils.getter( + field, + (SerializableFunction<@NonNull T, Object>) + fromAws.create(sdkField::getValueOrDefault, sdkField))); } return getters; } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java index d36c197d80a4..9e994702fe61 100644 --- 
a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.utils.builder.SdkBuilder; @@ -78,7 +79,7 @@ static SdkBuilderSetter setter(String name, BiConsumer, Object> return new ValueSetter(name, setter); } - static FieldValueGetter getter( + static FieldValueGetter getter( String name, SerializableFunction getter) { return new ValueGetter<>(name, getter); } @@ -107,7 +108,8 @@ public String name() { } } - private static class ValueGetter implements FieldValueGetter { + private static class ValueGetter + implements FieldValueGetter { private final SerializableFunction getter; private final String name; diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystemTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystemTest.java index 423176e52a75..39995b8b3167 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystemTest.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/s3/S3FileSystemTest.java @@ -34,6 +34,7 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -55,6 +56,7 @@ import 
org.apache.beam.sdk.io.aws2.options.S3Options; import org.apache.beam.sdk.io.fs.CreateOptions; import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -1068,6 +1070,21 @@ public void testWriteAndRead() throws IOException { open.close(); } + @Test + public void testReportLineageOnBucket() { + verifyLineage("s3://testbucket", ImmutableList.of("testbucket")); + verifyLineage("s3://testbucket/", ImmutableList.of("testbucket")); + verifyLineage("s3://testbucket/foo/bar.txt", ImmutableList.of("testbucket", "foo/bar.txt")); + } + + private void verifyLineage(String uri, List expected) { + S3FileSystem s3FileSystem = buildMockedS3FileSystem(s3Config("mys3"), client); + S3ResourceId path = S3ResourceId.fromUri(uri); + Lineage mockLineage = mock(Lineage.class); + s3FileSystem.reportLineage(path, mockLineage); + verify(mockLineage, times(1)).add("s3", expected); + } + /** A mockito argument matcher to implement equality on GetHeadObjectRequest. 
*/ private static class GetHeadObjectRequestMatcher implements ArgumentMatcher { diff --git a/sdks/java/io/azure/src/main/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystem.java b/sdks/java/io/azure/src/main/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystem.java index 5137eaf9bb2d..bbb2e22d94ce 100644 --- a/sdks/java/io/azure/src/main/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystem.java +++ b/sdks/java/io/azure/src/main/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystem.java @@ -453,7 +453,12 @@ protected AzfsResourceId matchNewResource(String singleResourceSpec, boolean isD @Override protected void reportLineage(AzfsResourceId resourceId, Lineage lineage) { - if (!Strings.isNullOrEmpty(resourceId.getBlob())) { + reportLineage(resourceId, lineage, LineageLevel.FILE); + } + + @Override + protected void reportLineage(AzfsResourceId resourceId, Lineage lineage, LineageLevel level) { + if (level != LineageLevel.TOP_LEVEL && !Strings.isNullOrEmpty(resourceId.getBlob())) { lineage.add( "abs", ImmutableList.of( diff --git a/sdks/java/io/azure/src/test/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystemTest.java b/sdks/java/io/azure/src/test/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystemTest.java index 545f314688c3..27a2220c2e44 100644 --- a/sdks/java/io/azure/src/test/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystemTest.java +++ b/sdks/java/io/azure/src/test/java/org/apache/beam/sdk/io/azure/blobstore/AzureBlobStoreFileSystemTest.java @@ -25,6 +25,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -51,6 +52,7 @@ import org.apache.beam.sdk.io.azure.options.BlobstoreOptions; 
import org.apache.beam.sdk.io.fs.CreateOptions; import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; @@ -338,4 +340,20 @@ public void testMatchNonGlobs() throws Exception { blobContainerClient.delete(); } + + @Test + public void testReportLineageOnBucket() { + verifyLineage("azfs://account/container", ImmutableList.of("account", "container")); + verifyLineage("azfs://account/container/", ImmutableList.of("account", "container")); + verifyLineage( + "azfs://account/container/foo/bar.txt", + ImmutableList.of("account", "container", "foo/bar.txt")); + } + + private void verifyLineage(String uri, List expected) { + AzfsResourceId path = AzfsResourceId.fromUri(uri); + Lineage mockLineage = mock(Lineage.class); + azureBlobStoreFileSystem.reportLineage(path, mockLineage); + verify(mockLineage, times(1)).add("abs", expected); + } } diff --git a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java index 2bfa694aebc0..970d9483850c 100644 --- a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java +++ b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java @@ -56,7 +56,7 @@ public class DebeziumIOPostgresSqlConnectorIT { @ClassRule public static final PostgreSQLContainer POSTGRES_SQL_CONTAINER = new PostgreSQLContainer<>( - DockerImageName.parse("debezium/example-postgres:latest") + DockerImageName.parse("quay.io/debezium/example-postgres:latest") .asCompatibleSubstituteFor("postgres")) .withPassword("dbz") .withUsername("debezium") diff --git 
a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java index c75621040913..c4b5d2d1f890 100644 --- a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java +++ b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java @@ -46,7 +46,7 @@ public class DebeziumReadSchemaTransformTest { @ClassRule public static final PostgreSQLContainer POSTGRES_SQL_CONTAINER = new PostgreSQLContainer<>( - DockerImageName.parse("debezium/example-postgres:latest") + DockerImageName.parse("quay.io/debezium/example-postgres:latest") .asCompatibleSubstituteFor("postgres")) .withPassword("dbz") .withUsername("debezium") diff --git a/sdks/java/io/expansion-service/build.gradle b/sdks/java/io/expansion-service/build.gradle index d7fef3d82332..421719b8f986 100644 --- a/sdks/java/io/expansion-service/build.gradle +++ b/sdks/java/io/expansion-service/build.gradle @@ -27,8 +27,15 @@ applyJavaNature( shadowClosure: {}, ) +// TODO(https://github.com/apache/beam/pull/32486/) Use library.java.kafka_clients once >=3.1.0 is set as default +configurations.runtimeClasspath { + // Pin kafka-clients version due to <3.1.0 missing auth callback classes + resolutionStrategy.force 'org.apache.kafka:kafka-clients:3.1.2' +} + shadowJar { mergeServiceFiles() + outputs.upToDateWhen { false } } description = "Apache Beam :: SDKs :: Java :: IO :: Expansion Service" @@ -37,6 +44,8 @@ ext.summary = "Expansion service serving several Java IOs" dependencies { implementation project(":sdks:java:expansion-service") permitUnusedDeclared project(":sdks:java:expansion-service") // BEAM-11761 + implementation project(":sdks:java:managed") + permitUnusedDeclared project(":sdks:java:managed") // BEAM-11761 implementation project(":sdks:java:io:iceberg") permitUnusedDeclared 
project(":sdks:java:io:iceberg") // BEAM-11761 implementation project(":sdks:java:io:kafka") @@ -45,6 +54,7 @@ dependencies { permitUnusedDeclared project(":sdks:java:io:kafka:upgrade") // BEAM-11761 // **** IcebergIO runtime dependencies **** + runtimeOnly library.java.hadoop_auth runtimeOnly library.java.hadoop_client // Needed when using GCS as the warehouse location. runtimeOnly library.java.bigdataoss_gcs_connector @@ -52,8 +62,7 @@ dependencies { runtimeOnly ("org.apache.iceberg:iceberg-hive-metastore:1.4.2") runtimeOnly project(path: ":sdks:java:io:iceberg:hive:exec", configuration: "shadow") - // TODO(https://github.com/apache/beam/pull/32486/) Use library.java.kafka_clients once 3.1.2 is set as default - runtimeOnly ("org.apache.kafka:kafka-clients:3.1.2") + runtimeOnly library.java.kafka_clients runtimeOnly library.java.slf4j_jdk14 } diff --git a/sdks/java/io/file-based-io-tests/src/test/java/org/apache/beam/sdk/io/text/TextIOIT.java b/sdks/java/io/file-based-io-tests/src/test/java/org/apache/beam/sdk/io/text/TextIOIT.java index 859c03ed7750..e50a8aba4162 100644 --- a/sdks/java/io/file-based-io-tests/src/test/java/org/apache/beam/sdk/io/text/TextIOIT.java +++ b/sdks/java/io/file-based-io-tests/src/test/java/org/apache/beam/sdk/io/text/TextIOIT.java @@ -154,9 +154,16 @@ public void writeThenReadAll() { PipelineResult result = pipeline.run(); PipelineResult.State pipelineState = result.waitUntilFinish(); - assertEquals( - Lineage.query(result.metrics(), Lineage.Type.SOURCE), - Lineage.query(result.metrics(), Lineage.Type.SINK)); + + Set sources = Lineage.query(result.metrics(), Lineage.Type.SOURCE); + Set sinks = Lineage.query(result.metrics(), Lineage.Type.SINK); + if (numShards != null && numShards <= 100) { + // both should be the full files, if supported by the runner + assertEquals(sources, sinks); + } else { + // if supported by runner, both should be non-empty + assertEquals(sources.isEmpty(), sinks.isEmpty()); + } collectAndPublishMetrics(result); 
// Fail the test if pipeline failed. diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index 3e322d976c1a..0a5a89072963 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ b/sdks/java/io/google-cloud-platform/build.gradle @@ -159,6 +159,7 @@ dependencies { testImplementation project(path: ":sdks:java:extensions:google-cloud-platform-core", configuration: "testRuntimeMigration") testImplementation project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntimeMigration") testImplementation project(path: ":runners:direct-java", configuration: "shadow") + testImplementation project(":sdks:java:managed") testImplementation project(path: ":sdks:java:io:common") testImplementation project(path: ":sdks:java:testing:test-utils") testImplementation library.java.commons_math3 @@ -197,6 +198,7 @@ task integrationTest(type: Test, dependsOn: processTestResources) { "--runner=DirectRunner", "--project=${gcpProject}", "--tempRoot=${gcpTempRoot}", + "--tempLocation=${gcpTempRoot}", "--firestoreDb=${firestoreDb}", "--firestoreHost=${firestoreHost}", "--bigtableChangeStreamInstanceId=${bigtableChangeStreamInstanceId}", @@ -357,4 +359,4 @@ task postCommit { description = "Integration tests of GCP connectors using the DirectRunner." 
dependsOn integrationTest dependsOn integrationTestKms -} +} \ No newline at end of file diff --git a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle index 1288d91964e1..01181721e9a4 100644 --- a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle +++ b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle @@ -36,6 +36,9 @@ dependencies { permitUnusedDeclared project(":sdks:java:io:google-cloud-platform") // BEAM-11761 implementation project(":sdks:java:extensions:schemaio-expansion-service") permitUnusedDeclared project(":sdks:java:extensions:schemaio-expansion-service") // BEAM-11761 + implementation project(":sdks:java:managed") + permitUnusedDeclared project(":sdks:java:managed") // BEAM-11761 + runtimeOnly library.java.slf4j_jdk14 } @@ -44,3 +47,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.test.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } + +shadowJar { + outputs.upToDateWhen { false } +} \ No newline at end of file diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java index 7a5aa2408d2e..d7ca787feea3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java @@ -31,7 +31,6 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -221,11 +220,18 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { case ITERABLE: @Nullable 
FieldType elementType = field.getType().getCollectionElementType(); if (elementType == null) { - throw new RuntimeException("Unexpected null element type!"); + throw new RuntimeException("Unexpected null element type on " + field.getName()); } + TypeName containedTypeName = + Preconditions.checkNotNull( + elementType.getTypeName(), + "Null type name found in contained type at " + field.getName()); Preconditions.checkState( - !Preconditions.checkNotNull(elementType.getTypeName()).isCollectionType(), - "Nested arrays not supported by BigQuery."); + !(containedTypeName.isCollectionType() || containedTypeName.isMapType()), + "Nested container types are not supported by BigQuery. Field " + + field.getName() + + " contains a type " + + containedTypeName.name()); TableFieldSchema elementFieldSchema = fieldDescriptorFromBeamField(Field.of(field.getName(), elementType)); builder = builder.setType(elementFieldSchema.getType()); @@ -244,7 +250,24 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { builder = builder.setType(type); break; case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + @Nullable FieldType keyType = field.getType().getMapKeyType(); + @Nullable FieldType valueType = field.getType().getMapValueType(); + if (keyType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's key on " + field.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's value on " + field.getName()); + } + + builder = + builder + .setType(TableFieldSchema.Type.STRUCT) + .addFields(fieldDescriptorFromBeamField(Field.of("key", keyType))) + .addFields(fieldDescriptorFromBeamField(Field.of("value", valueType))) + .setMode(TableFieldSchema.Mode.REPEATED); + break; default: @Nullable TableFieldSchema.Type primitiveType = PRIMITIVE_TYPES.get(field.getType().getTypeName()); @@ -289,25 +312,34 @@ private static Object toProtoValue( case ROW: return 
messageFromBeamRow(fieldDescriptor.getMessageType(), (Row) value, null, -1); case ARRAY: - List list = (List) value; - @Nullable FieldType arrayElementType = beamFieldType.getCollectionElementType(); - if (arrayElementType == null) { - throw new RuntimeException("Unexpected null element type!"); - } - return list.stream() - .map(v -> toProtoValue(fieldDescriptor, arrayElementType, v)) - .collect(Collectors.toList()); case ITERABLE: Iterable iterable = (Iterable) value; @Nullable FieldType iterableElementType = beamFieldType.getCollectionElementType(); if (iterableElementType == null) { - throw new RuntimeException("Unexpected null element type!"); + throw new RuntimeException("Unexpected null element type: " + fieldDescriptor.getName()); } + return StreamSupport.stream(iterable.spliterator(), false) .map(v -> toProtoValue(fieldDescriptor, iterableElementType, v)) .collect(Collectors.toList()); case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + Map map = (Map) value; + @Nullable FieldType keyType = beamFieldType.getMapKeyType(); + @Nullable FieldType valueType = beamFieldType.getMapValueType(); + if (keyType == null) { + throw new RuntimeException("Unexpected null for key type: " + fieldDescriptor.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null for value type: " + fieldDescriptor.getName()); + } + + return map.entrySet().stream() + .map( + (Map.Entry entry) -> + mapEntryToProtoValue( + fieldDescriptor.getMessageType(), keyType, valueType, entry)) + .collect(Collectors.toList()); default: return scalarToProtoValue(beamFieldType, value); } @@ -337,6 +369,28 @@ static Object scalarToProtoValue(FieldType beamFieldType, Object value) { } } + static Object mapEntryToProtoValue( + Descriptor descriptor, + FieldType keyFieldType, + FieldType valueFieldType, + Map.Entry entryValue) { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + FieldDescriptor keyFieldDescriptor = + 
Preconditions.checkNotNull(descriptor.findFieldByName("key")); + @Nullable Object key = toProtoValue(keyFieldDescriptor, keyFieldType, entryValue.getKey()); + if (key != null) { + builder.setField(keyFieldDescriptor, key); + } + FieldDescriptor valueFieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName("value")); + @Nullable + Object value = toProtoValue(valueFieldDescriptor, valueFieldType, entryValue.getValue()); + if (value != null) { + builder.setField(valueFieldDescriptor, value); + } + return builder.build(); + } + static ByteString serializeBigDecimalToNumeric(BigDecimal o) { return serializeBigDecimal(o, NUMERIC_SCALE, MAX_NUMERIC_VALUE, MIN_NUMERIC_VALUE, "Numeric"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java index cddde05b194c..1af44ba7a012 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java @@ -34,6 +34,8 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Set; import org.apache.avro.Conversions; import org.apache.avro.LogicalType; @@ -50,14 +52,14 @@ import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -/** - * A set of utilities for working with Avro files. - * - *

    These utilities are based on the Avro - * 1.8.1 specification. - */ +/** A set of utilities for working with Avro files. */ class BigQueryAvroUtils { + private static final String VERSION_AVRO = + Optional.ofNullable(Schema.class.getPackage()) + .map(Package::getImplementationVersion) + .orElse(""); + // org.apache.avro.LogicalType static class DateTimeLogicalType extends LogicalType { public DateTimeLogicalType() { @@ -74,6 +76,8 @@ public DateTimeLogicalType() { * export * @see BQ * avro storage + * @see BQ avro + * load */ static Schema getPrimitiveType(TableFieldSchema schema, Boolean useAvroLogicalTypes) { String bqType = schema.getType(); @@ -116,6 +120,9 @@ static Schema getPrimitiveType(TableFieldSchema schema, Boolean useAvroLogicalTy } case "DATETIME": if (useAvroLogicalTypes) { + // BQ export uses a custom logical type + // TODO for load/storage use + // LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType()) return DATETIME_LOGICAL_TYPE.addToSchema(SchemaBuilder.builder().stringType()); } else { return SchemaBuilder.builder().stringBuilder().prop("sqlType", bqType).endString(); @@ -158,6 +165,12 @@ static Schema getPrimitiveType(TableFieldSchema schema, Boolean useAvroLogicalTy @VisibleForTesting static String formatTimestamp(Long timestampMicro) { + String dateTime = formatDatetime(timestampMicro); + return dateTime + " UTC"; + } + + @VisibleForTesting + static String formatDatetime(Long timestampMicro) { // timestampMicro is in "microseconds since epoch" format, // e.g., 1452062291123456L means "2016-01-06 06:38:11.123456 UTC". // Separate into seconds and microseconds. 
@@ -168,11 +181,13 @@ static String formatTimestamp(Long timestampMicro) { timestampSec -= 1; } String dayAndTime = DATE_AND_SECONDS_FORMATTER.print(timestampSec * 1000); - if (micros == 0) { - return String.format("%s UTC", dayAndTime); + return dayAndTime; + } else if (micros % 1000 == 0) { + return String.format("%s.%03d", dayAndTime, micros / 1000); + } else { + return String.format("%s.%06d", dayAndTime, micros); } - return String.format("%s.%06d UTC", dayAndTime, micros); } /** @@ -274,8 +289,7 @@ static TableRow convertGenericRecordToTableRow(GenericRecord record) { case UNION: return convertNullableField(name, schema, v); case MAP: - throw new UnsupportedOperationException( - String.format("Unexpected Avro field schema type %s for field named %s", type, name)); + return convertMapField(name, schema, v); default: return convertRequiredField(name, schema, v); } @@ -296,6 +310,26 @@ private static List convertRepeatedField(String name, Schema elementType return values; } + private static List convertMapField(String name, Schema map, Object v) { + // Avro maps are represented as key/value RECORD. + if (v == null) { + // Handle the case of an empty map. + return new ArrayList<>(); + } + + Schema type = map.getValueType(); + Map elements = (Map) v; + ArrayList values = new ArrayList<>(); + for (Map.Entry element : elements.entrySet()) { + TableRow row = + new TableRow() + .set("key", element.getKey()) + .set("value", convertRequiredField(name, type, element.getValue())); + values.add(row); + } + return values; + } + private static Object convertRequiredField(String name, Schema schema, Object v) { // REQUIRED fields are represented as the corresponding Avro types. For example, a BigQuery // INTEGER type maps to an Avro LONG type. 
@@ -305,45 +339,83 @@ private static Object convertRequiredField(String name, Schema schema, Object v) LogicalType logicalType = schema.getLogicalType(); switch (type) { case BOOLEAN: - // SQL types BOOL, BOOLEAN + // SQL type BOOL (BOOLEAN) return v; case INT: if (logicalType instanceof LogicalTypes.Date) { - // SQL types DATE + // SQL type DATE + // ideally LocalDate but TableRowJsonCoder encodes as String return formatDate((Integer) v); + } else if (logicalType instanceof LogicalTypes.TimeMillis) { + // Write only: SQL type TIME + // ideally LocalTime but TableRowJsonCoder encodes as String + return formatTime(((Integer) v) * 1000L); } else { - throw new UnsupportedOperationException( - String.format("Unexpected Avro field schema type %s for field named %s", type, name)); + // Write only: SQL type INT64 (INT, SMALLINT, INTEGER, BIGINT, TINYINT, BYTEINT) + // ideally Integer but keep consistency with BQ JSON export that uses String + return ((Integer) v).toString(); } case LONG: if (logicalType instanceof LogicalTypes.TimeMicros) { - // SQL types TIME + // SQL type TIME + // ideally LocalTime but TableRowJsonCoder encodes as String return formatTime((Long) v); + } else if (logicalType instanceof LogicalTypes.TimestampMillis) { + // Write only: SQL type TIMESTAMP + // ideally Instant but TableRowJsonCoder encodes as String + return formatTimestamp((Long) v * 1000L); } else if (logicalType instanceof LogicalTypes.TimestampMicros) { - // SQL types TIMESTAMP + // SQL type TIMESTAMP + // ideally Instant but TableRowJsonCoder encodes as String return formatTimestamp((Long) v); + } else if (!(VERSION_AVRO.startsWith("1.8") || VERSION_AVRO.startsWith("1.9")) + && logicalType instanceof LogicalTypes.LocalTimestampMillis) { + // Write only: SQL type DATETIME + // ideally LocalDateTime but TableRowJsonCoder encodes as String + return formatDatetime(((Long) v) * 1000); + } else if (!(VERSION_AVRO.startsWith("1.8") || VERSION_AVRO.startsWith("1.9")) + && logicalType 
instanceof LogicalTypes.LocalTimestampMicros) { + // Write only: SQL type DATETIME + // ideally LocalDateTime but TableRowJsonCoder encodes as String + return formatDatetime((Long) v); } else { - // SQL types INT64 (INT, SMALLINT, INTEGER, BIGINT, TINYINT, BYTEINT) + // SQL type INT64 (INT, SMALLINT, INTEGER, BIGINT, TINYINT, BYTEINT) + // ideally Long if in [2^53+1, 2^53-1] but keep consistency with BQ JSON export that uses + // String return ((Long) v).toString(); } + case FLOAT: + // Write only: SQL type FLOAT64 + // ideally Float but TableRowJsonCoder decodes as Double + return Double.valueOf(v.toString()); case DOUBLE: - // SQL types FLOAT64 + // SQL type FLOAT64 return v; case BYTES: if (logicalType instanceof LogicalTypes.Decimal) { // SQL tpe NUMERIC, BIGNUMERIC + // ideally BigDecimal but TableRowJsonCoder encodes as String return new Conversions.DecimalConversion() .fromBytes((ByteBuffer) v, schema, logicalType) .toString(); } else { - // SQL types BYTES + // SQL type BYTES + // ideally byte[] but TableRowJsonCoder encodes as String return BaseEncoding.base64().encode(((ByteBuffer) v).array()); } case STRING: // SQL types STRING, DATETIME, GEOGRAPHY, JSON // when not using logical type DATE, TIME too return v.toString(); + case ENUM: + // SQL types STRING + return v.toString(); + case FIXED: + // SQL type BYTES + // ideally byte[] but TableRowJsonCoder encodes as String + return BaseEncoding.base64().encode(((ByteBuffer) v).array()); case RECORD: + // SQL types RECORD return convertGenericRecordToTableRow((GenericRecord) v); default: throw new UnsupportedOperationException( diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java deleted file mode 100644 index f634b5ec6f60..000000000000 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import com.google.auto.value.AutoValue; -import org.apache.beam.sdk.schemas.AutoValueSchema; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; - -/** - * Configuration for writing to BigQuery. - * - *

    This class is meant to be used with {@link BigQueryFileLoadsWriteSchemaTransformProvider}. - * - *

    Internal only: This class is actively being worked on, and it will likely change. We - * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam - * repository. - */ -@DefaultSchema(AutoValueSchema.class) -@AutoValue -public abstract class BigQueryFileLoadsWriteSchemaTransformConfiguration { - - /** Instantiates a {@link BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder}. */ - public static Builder builder() { - return new AutoValue_BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder(); - } - - /** - * Writes to the given table specification. See {@link BigQueryIO.Write#to(String)}} for the - * expected format. - */ - public abstract String getTableSpec(); - - /** Specifies whether the table should be created if it does not exist. */ - public abstract String getCreateDisposition(); - - /** Specifies what to do with existing data in the table, in case the table already exists. */ - public abstract String getWriteDisposition(); - - @AutoValue.Builder - public abstract static class Builder { - - /** - * Writes to the given table specification. See {@link BigQueryIO.Write#to(String)}} for the - * expected format. - */ - public abstract Builder setTableSpec(String value); - - /** Specifies whether the table should be created if it does not exist. */ - public abstract Builder setCreateDisposition(String value); - - /** Specifies what to do with existing data in the table, in case the table already exists. */ - public abstract Builder setWriteDisposition(String value); - - /** Builds the {@link BigQueryFileLoadsWriteSchemaTransformConfiguration} configuration. 
*/ - public abstract BigQueryFileLoadsWriteSchemaTransformConfiguration build(); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java deleted file mode 100644 index 3212e2a30348..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import com.google.api.services.bigquery.model.Table; -import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.auto.service.AutoService; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.io.InvalidConfigurationException; -import org.apache.beam.sdk.schemas.transforms.SchemaTransform; -import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; -import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; - -/** - * An implementation of {@link TypedSchemaTransformProvider} for BigQuery write jobs configured - * using {@link BigQueryFileLoadsWriteSchemaTransformConfiguration}. - * - *

    Internal only: This class is actively being worked on, and it will likely change. We - * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam - * repository. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -@Internal -@AutoService(SchemaTransformProvider.class) -public class BigQueryFileLoadsWriteSchemaTransformProvider - extends TypedSchemaTransformProvider { - - private static final String IDENTIFIER = - "beam:schematransform:org.apache.beam:bigquery_fileloads_write:v1"; - static final String INPUT_TAG = "INPUT"; - - /** Returns the expected class of the configuration. */ - @Override - protected Class configurationClass() { - return BigQueryFileLoadsWriteSchemaTransformConfiguration.class; - } - - /** Returns the expected {@link SchemaTransform} of the configuration. */ - @Override - protected SchemaTransform from(BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - return new BigQueryWriteSchemaTransform(configuration); - } - - /** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */ - @Override - public String identifier() { - return IDENTIFIER; - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since a - * single is expected, this returns a list with a single name. - */ - @Override - public List inputCollectionNames() { - return Collections.singletonList(INPUT_TAG); - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. Since - * no output is expected, this returns an empty list. - */ - @Override - public List outputCollectionNames() { - return Collections.emptyList(); - } - - /** - * A {@link SchemaTransform} that performs {@link BigQueryIO.Write}s based on a {@link - * BigQueryFileLoadsWriteSchemaTransformConfiguration}. 
- */ - protected static class BigQueryWriteSchemaTransform extends SchemaTransform { - /** An instance of {@link BigQueryServices} used for testing. */ - private BigQueryServices testBigQueryServices = null; - - private final BigQueryFileLoadsWriteSchemaTransformConfiguration configuration; - - BigQueryWriteSchemaTransform(BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - this.configuration = configuration; - } - - @Override - public void validate(PipelineOptions options) { - if (!configuration.getCreateDisposition().equals(CreateDisposition.CREATE_NEVER.name())) { - return; - } - - BigQueryOptions bigQueryOptions = options.as(BigQueryOptions.class); - - BigQueryServices bigQueryServices = new BigQueryServicesImpl(); - if (testBigQueryServices != null) { - bigQueryServices = testBigQueryServices; - } - - DatasetService datasetService = bigQueryServices.getDatasetService(bigQueryOptions); - TableReference tableReference = BigQueryUtils.toTableReference(configuration.getTableSpec()); - - try { - Table table = datasetService.getTable(tableReference); - if (table == null) { - throw new NullPointerException(); - } - - if (table.getSchema() == null) { - throw new InvalidConfigurationException( - String.format("could not fetch schema for table: %s", configuration.getTableSpec())); - } - - } catch (NullPointerException | InterruptedException | IOException ex) { - throw new InvalidConfigurationException( - String.format( - "could not fetch table %s, error: %s", - configuration.getTableSpec(), ex.getMessage())); - } - } - - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - validate(input); - PCollection rowPCollection = input.get(INPUT_TAG); - Schema schema = rowPCollection.getSchema(); - BigQueryIO.Write write = toWrite(schema); - if (testBigQueryServices != null) { - write = write.withTestServices(testBigQueryServices); - } - - PCollection tableRowPCollection = - rowPCollection.apply( - 
MapElements.into(TypeDescriptor.of(TableRow.class)).via(BigQueryUtils::toTableRow)); - tableRowPCollection.apply(write); - return PCollectionRowTuple.empty(input.getPipeline()); - } - - /** Instantiates a {@link BigQueryIO.Write} from a {@link Schema}. */ - BigQueryIO.Write toWrite(Schema schema) { - TableSchema tableSchema = BigQueryUtils.toTableSchema(schema); - CreateDisposition createDisposition = - CreateDisposition.valueOf(configuration.getCreateDisposition()); - WriteDisposition writeDisposition = - WriteDisposition.valueOf(configuration.getWriteDisposition()); - - return BigQueryIO.writeTableRows() - .to(configuration.getTableSpec()) - .withCreateDisposition(createDisposition) - .withWriteDisposition(writeDisposition) - .withSchema(tableSchema); - } - - /** Setter for testing using {@link BigQueryServices}. */ - @VisibleForTesting - void setTestBigQueryServices(BigQueryServices testBigQueryServices) { - this.testBigQueryServices = testBigQueryServices; - } - - /** Validate a {@link PCollectionRowTuple} input. 
*/ - void validate(PCollectionRowTuple input) { - if (!input.has(INPUT_TAG)) { - throw new IllegalArgumentException( - String.format( - "%s %s is missing expected tag: %s", - getClass().getSimpleName(), input.getClass().getSimpleName(), INPUT_TAG)); - } - - PCollection rowInput = input.get(INPUT_TAG); - Schema sourceSchema = rowInput.getSchema(); - - if (sourceSchema == null) { - throw new IllegalArgumentException( - String.format("%s is null for input of tag: %s", Schema.class, INPUT_TAG)); - } - - if (!configuration.getCreateDisposition().equals(CreateDisposition.CREATE_NEVER.name())) { - return; - } - - BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); - - BigQueryServices bigQueryServices = new BigQueryServicesImpl(); - if (testBigQueryServices != null) { - bigQueryServices = testBigQueryServices; - } - - DatasetService datasetService = bigQueryServices.getDatasetService(bigQueryOptions); - TableReference tableReference = BigQueryUtils.toTableReference(configuration.getTableSpec()); - - try { - Table table = datasetService.getTable(tableReference); - if (table == null) { - throw new NullPointerException(); - } - - TableSchema tableSchema = table.getSchema(); - if (tableSchema == null) { - throw new NullPointerException(); - } - - Schema destinationSchema = BigQueryUtils.fromTableSchema(tableSchema); - if (destinationSchema == null) { - throw new NullPointerException(); - } - - validateMatching(sourceSchema, destinationSchema); - - } catch (NullPointerException | InterruptedException | IOException e) { - throw new InvalidConfigurationException( - String.format( - "could not validate input for create disposition: %s and table: %s, error: %s", - configuration.getCreateDisposition(), - configuration.getTableSpec(), - e.getMessage())); - } - } - - void validateMatching(Schema sourceSchema, Schema destinationSchema) { - if (!sourceSchema.equals(destinationSchema)) { - throw new IllegalArgumentException( - String.format( - 
"source and destination schema mismatch for table: %s", - configuration.getTableSpec())); - } - } - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index 88dfa2c26348..9a7f3a05556c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -2259,6 +2259,7 @@ public static Write applyRowMutations() { .withFormatFunction(RowMutation::getTableRow) .withRowMutationInformationFn(RowMutation::getMutationInformation); } + /** * A {@link PTransform} that writes a {@link PCollection} containing {@link GenericRecord * GenericRecords} to a BigQuery table. @@ -2367,8 +2368,10 @@ public enum Method { abstract WriteDisposition getWriteDisposition(); abstract Set getSchemaUpdateOptions(); + /** Table description. Default is empty. */ abstract @Nullable String getTableDescription(); + /** An option to indicate if table validation is desired. Default is true. 
*/ abstract boolean getValidate(); @@ -3455,7 +3458,10 @@ && getStorageApiTriggeringFrequency(bqOptions) != null) { LOG.error("The Storage API sink does not support the WRITE_TRUNCATE write disposition."); } if (getRowMutationInformationFn() != null) { - checkArgument(getMethod() == Method.STORAGE_API_AT_LEAST_ONCE); + checkArgument( + getMethod() == Method.STORAGE_API_AT_LEAST_ONCE, + "When using row updates on BigQuery, StorageWrite API should execute using" + + " \"at least once\" mode."); checkArgument( getCreateDisposition() == CreateDisposition.CREATE_NEVER || getPrimaryKey() != null, "If specifying CREATE_IF_NEEDED along with row updates, a primary key needs to be specified"); @@ -3741,7 +3747,7 @@ private WriteResult continueExpandTyped( if (rowWriterFactory.getOutputType() == OutputType.JsonTableRow) { LOG.warn( "Found JSON type in TableSchema for 'FILE_LOADS' write method. \n" - + "Make sure the TableRow value is a parsed JSON to ensure the read as a " + + "Make sure the TableRow value is a Jackson JsonNode to ensure the read as a " + "JSON type. 
Otherwise it will read as a raw (escaped) string.\n" + "See https://cloud.google.com/bigquery/docs/loading-data-cloud-storage-json#limitations " + "for limitations."); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java index 9d84abbbbf1a..8c6893ef5798 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java @@ -297,19 +297,15 @@ public void finishBundle(FinishBundleContext c) throws Exception { } for (Map.Entry> entry : writers.entrySet()) { - try { - DestinationT destination = entry.getKey(); - BigQueryRowWriter writer = entry.getValue(); - BigQueryRowWriter.Result result = writer.getResult(); - BoundedWindow window = writerWindows.get(destination); - Preconditions.checkStateNotNull(window); - c.output( - new Result<>(result.resourceId.toString(), result.byteSize, destination), - window.maxTimestamp(), - window); - } catch (Exception e) { - exceptionList.add(e); - } + DestinationT destination = entry.getKey(); + BigQueryRowWriter writer = entry.getValue(); + BigQueryRowWriter.Result result = writer.getResult(); + BoundedWindow window = writerWindows.get(destination); + Preconditions.checkStateNotNull(window); + c.output( + new Result<>(result.resourceId.toString(), result.byteSize, destination), + window.maxTimestamp(), + window); } writers.clear(); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java index e374d459af44..288b94ce081b 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java @@ -76,10 +76,8 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -110,17 +108,13 @@ static class ResultCoder extends AtomicCoder { static final ResultCoder INSTANCE = new ResultCoder(); @Override - public void encode(Result value, @UnknownKeyFor @NonNull @Initialized OutputStream outStream) - throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull - @Initialized IOException { + public void encode(Result value, OutputStream outStream) throws CoderException, IOException { StringUtf8Coder.of().encode(value.getTableName(), outStream); BooleanCoder.of().encode(value.isFirstPane(), outStream); } @Override - public Result decode(@UnknownKeyFor @NonNull @Initialized InputStream inStream) - throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull - @Initialized IOException { + public Result decode(InputStream inStream) throws CoderException, IOException { return new AutoValue_WriteTables_Result( StringUtf8Coder.of().decode(inStream), BooleanCoder.of().decode(inStream)); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java index 8b8e8179ce7d..073de40038b3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -26,6 +27,7 @@ import java.util.Collections; import java.util.List; import javax.annotation.Nullable; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; @@ -33,7 +35,9 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransformConfiguration; import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; @@ -62,7 +66,7 @@ public class BigQueryDirectReadSchemaTransformProvider extends TypedSchemaTransformProvider { - private static final String OUTPUT_TAG = "OUTPUT_ROWS"; + public static final 
String OUTPUT_TAG = "output"; @Override protected Class configurationClass() { @@ -76,7 +80,7 @@ protected SchemaTransform from(BigQueryDirectReadSchemaTransformConfiguration co @Override public String identifier() { - return "beam:schematransform:org.apache.beam:bigquery_storage_read:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ); } @Override @@ -139,6 +143,10 @@ public static Builder builder() { @Nullable public abstract List getSelectedFields(); + @SchemaFieldDescription("Use this Cloud KMS key to encrypt your data") + @Nullable + public abstract String getKmsKey(); + @Nullable /** Builder for the {@link BigQueryDirectReadSchemaTransformConfiguration}. */ @AutoValue.Builder @@ -151,6 +159,8 @@ public abstract static class Builder { public abstract Builder setSelectedFields(List selectedFields); + public abstract Builder setKmsKey(String kmsKey); + /** Builds a {@link BigQueryDirectReadSchemaTransformConfiguration} instance. */ public abstract BigQueryDirectReadSchemaTransformConfiguration build(); } @@ -161,7 +171,7 @@ public abstract static class Builder { * BigQueryDirectReadSchemaTransformConfiguration} and instantiated by {@link * BigQueryDirectReadSchemaTransformProvider}. 
*/ - protected static class BigQueryDirectReadSchemaTransform extends SchemaTransform { + public static class BigQueryDirectReadSchemaTransform extends SchemaTransform { private BigQueryServices testBigQueryServices = null; private final BigQueryDirectReadSchemaTransformConfiguration configuration; @@ -172,6 +182,20 @@ protected static class BigQueryDirectReadSchemaTransform extends SchemaTransform this.configuration = configuration; } + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(BigQueryDirectReadSchemaTransformConfiguration.class) + .apply(configuration) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } + @VisibleForTesting public void setBigQueryServices(BigQueryServices testBigQueryServices) { this.testBigQueryServices = testBigQueryServices; @@ -209,7 +233,10 @@ BigQueryIO.TypedRead createDirectReadTransform() { read = read.withSelectedFields(configuration.getSelectedFields()); } } else { - read = read.fromQuery(configuration.getQuery()); + read = read.fromQuery(configuration.getQuery()).usingStandardSql(); + } + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + read = read.withKmsKey(configuration.getKmsKey()); } if (this.testBigQueryServices != null) { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java new file mode 100644 index 000000000000..7872c91d1f72 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import com.google.auto.service.AutoService; +import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; + +/** + * An implementation of {@link TypedSchemaTransformProvider} for BigQuery write jobs configured + * using {@link org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration}. + * + *

    Internal only: This class is actively being worked on, and it will likely change. We + * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam + * repository. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +@Internal +@AutoService(SchemaTransformProvider.class) +public class BigQueryFileLoadsSchemaTransformProvider + extends TypedSchemaTransformProvider { + + static final String INPUT_TAG = "input"; + + @Override + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { + return new BigQueryFileLoadsSchemaTransform(configuration); + } + + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:bigquery_fileloads:v1"; + } + + @Override + public List inputCollectionNames() { + return Collections.singletonList(INPUT_TAG); + } + + @Override + public List outputCollectionNames() { + return Collections.emptyList(); + } + + public static class BigQueryFileLoadsSchemaTransform extends SchemaTransform { + /** An instance of {@link BigQueryServices} used for testing. 
*/ + private BigQueryServices testBigQueryServices = null; + + private final BigQueryWriteConfiguration configuration; + + BigQueryFileLoadsSchemaTransform(BigQueryWriteConfiguration configuration) { + configuration.validate(); + this.configuration = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection rowPCollection = input.getSinglePCollection(); + BigQueryIO.Write write = + toWrite(rowPCollection.getSchema(), input.getPipeline().getOptions()); + rowPCollection.apply(write); + + return PCollectionRowTuple.empty(input.getPipeline()); + } + + BigQueryIO.Write toWrite(Schema schema, PipelineOptions options) { + PortableBigQueryDestinations dynamicDestinations = + new PortableBigQueryDestinations(schema, configuration); + BigQueryIO.Write write = + BigQueryIO.write() + .to(dynamicDestinations) + .withMethod(BigQueryIO.Write.Method.FILE_LOADS) + .withFormatFunction(BigQueryUtils.toTableRow()) + // TODO(https://github.com/apache/beam/issues/33074) BatchLoad's + // createTempFilePrefixView() doesn't pick up the pipeline option + .withCustomGcsTempLocation( + ValueProvider.StaticValueProvider.of(options.getTempLocation())) + .withWriteDisposition(WriteDisposition.WRITE_APPEND) + .withFormatFunction(dynamicDestinations.getFilterFormatFunction(false)); + + if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { + CreateDisposition createDisposition = + CreateDisposition.valueOf(configuration.getCreateDisposition().toUpperCase()); + write = write.withCreateDisposition(createDisposition); + } + if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { + WriteDisposition writeDisposition = + WriteDisposition.valueOf(configuration.getWriteDisposition().toUpperCase()); + write = write.withWriteDisposition(writeDisposition); + } + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + write = write.withKmsKey(configuration.getKmsKey()); + } + if (testBigQueryServices != null) { + write = 
write.withTestServices(testBigQueryServices); + } + + return write; + } + + /** Setter for testing using {@link BigQueryServices}. */ + @VisibleForTesting + void setTestBigQueryServices(BigQueryServices testBigQueryServices) { + this.testBigQueryServices = testBigQueryServices; + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java new file mode 100644 index 000000000000..555df0d0a2b8 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +public class BigQuerySchemaTransformTranslation { + public static class BigQueryStorageReadSchemaTransformTranslator + extends SchemaTransformTranslation.SchemaTransformPayloadTranslator< + BigQueryDirectReadSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new BigQueryDirectReadSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(BigQueryDirectReadSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + public static class BigQueryWriteSchemaTransformTranslator + extends SchemaTransformTranslation.SchemaTransformPayloadTranslator< + BigQueryWriteSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new BigQueryWriteSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(BigQueryWriteSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class ReadWriteRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? 
extends Class, + ? extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + ., PTransformTranslation.TransformPayloadTranslator>builder() + .put( + BigQueryDirectReadSchemaTransform.class, + new BigQueryStorageReadSchemaTransformTranslator()) + .put(BigQueryWriteSchemaTransform.class, new BigQueryWriteSchemaTransformTranslator()) + .build(); + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index 980d783ec43c..1e53ad3553e0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -17,18 +17,15 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.DESTINATION; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.RECORD; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; -import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.service.AutoService; -import com.google.auto.value.AutoValue; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; -import javax.annotation.Nullable; -import 
org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method; @@ -36,18 +33,13 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; -import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; -import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; +import org.apache.beam.sdk.io.gcp.bigquery.RowMutationInformation; import org.apache.beam.sdk.io.gcp.bigquery.WriteResult; -import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; @@ -59,15 +51,13 @@ import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; -import org.apache.beam.sdk.values.ValueInSingleWindow; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import 
org.joda.time.Duration; /** * An implementation of {@link TypedSchemaTransformProvider} for BigQuery Storage Write API jobs - * configured via {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. + * configured via {@link BigQueryWriteConfiguration}. * *

    Internal only: This class is actively being worked on, and it will likely change. We * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam @@ -78,7 +68,7 @@ }) @AutoService(SchemaTransformProvider.class) public class BigQueryStorageWriteApiSchemaTransformProvider - extends TypedSchemaTransformProvider { + extends TypedSchemaTransformProvider { private static final Integer DEFAULT_TRIGGER_FREQUENCY_SECS = 5; private static final Duration DEFAULT_TRIGGERING_FREQUENCY = Duration.standardSeconds(DEFAULT_TRIGGER_FREQUENCY_SECS); @@ -87,16 +77,23 @@ public class BigQueryStorageWriteApiSchemaTransformProvider private static final String FAILED_ROWS_WITH_ERRORS_TAG = "FailedRowsWithErrors"; // magic string that tells us to write to dynamic destinations protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; + protected static final String ROW_PROPERTY_MUTATION_INFO = "row_mutation_info"; + protected static final String ROW_PROPERTY_MUTATION_TYPE = "mutation_type"; + protected static final String ROW_PROPERTY_MUTATION_SQN = "change_sequence_number"; + protected static final Schema ROW_SCHEMA_MUTATION_INFO = + Schema.builder() + .addStringField("mutation_type") + .addStringField("change_sequence_number") + .build(); @Override - protected SchemaTransform from( - BigQueryStorageWriteApiSchemaTransformConfiguration configuration) { + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { return new BigQueryStorageWriteApiSchemaTransform(configuration); } @Override public String identifier() { - return String.format("beam:schematransform:org.apache.beam:bigquery_storage_write:v2"); + return "beam:schematransform:org.apache.beam:bigquery_storage_write:v2"; } @Override @@ -119,183 +116,17 @@ public List outputCollectionNames() { return Arrays.asList(FAILED_ROWS_TAG, FAILED_ROWS_WITH_ERRORS_TAG, "errors"); } - /** Configuration for writing to BigQuery with Storage Write API. 
*/ - @DefaultSchema(AutoValueSchema.class) - @AutoValue - public abstract static class BigQueryStorageWriteApiSchemaTransformConfiguration { - - static final Map CREATE_DISPOSITIONS = - ImmutableMap.builder() - .put(CreateDisposition.CREATE_IF_NEEDED.name(), CreateDisposition.CREATE_IF_NEEDED) - .put(CreateDisposition.CREATE_NEVER.name(), CreateDisposition.CREATE_NEVER) - .build(); - - static final Map WRITE_DISPOSITIONS = - ImmutableMap.builder() - .put(WriteDisposition.WRITE_TRUNCATE.name(), WriteDisposition.WRITE_TRUNCATE) - .put(WriteDisposition.WRITE_EMPTY.name(), WriteDisposition.WRITE_EMPTY) - .put(WriteDisposition.WRITE_APPEND.name(), WriteDisposition.WRITE_APPEND) - .build(); - - @AutoValue - public abstract static class ErrorHandling { - @SchemaFieldDescription("The name of the output PCollection containing failed writes.") - public abstract String getOutput(); - - public static Builder builder() { - return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration_ErrorHandling - .Builder(); - } - - @AutoValue.Builder - public abstract static class Builder { - public abstract Builder setOutput(String output); - - public abstract ErrorHandling build(); - } - } - - public void validate() { - String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: "; - - // validate output table spec - checkArgument( - !Strings.isNullOrEmpty(this.getTable()), - invalidConfigMessage + "Table spec for a BigQuery Write must be specified."); - - // if we have an input table spec, validate it - if (!this.getTable().equals(DYNAMIC_DESTINATIONS)) { - checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable())); - } - - // validate create and write dispositions - if (!Strings.isNullOrEmpty(this.getCreateDisposition())) { - checkNotNull( - CREATE_DISPOSITIONS.get(this.getCreateDisposition().toUpperCase()), - invalidConfigMessage - + "Invalid create disposition (%s) was specified. 
Available dispositions are: %s", - this.getCreateDisposition(), - CREATE_DISPOSITIONS.keySet()); - } - if (!Strings.isNullOrEmpty(this.getWriteDisposition())) { - checkNotNull( - WRITE_DISPOSITIONS.get(this.getWriteDisposition().toUpperCase()), - invalidConfigMessage - + "Invalid write disposition (%s) was specified. Available dispositions are: %s", - this.getWriteDisposition(), - WRITE_DISPOSITIONS.keySet()); - } - - if (this.getErrorHandling() != null) { - checkArgument( - !Strings.isNullOrEmpty(this.getErrorHandling().getOutput()), - invalidConfigMessage + "Output must not be empty if error handling specified."); - } - - if (this.getAutoSharding() != null - && this.getAutoSharding() - && this.getNumStreams() != null) { - checkArgument( - this.getNumStreams() == 0, - invalidConfigMessage - + "Cannot set a fixed number of streams when auto-sharding is enabled. Please pick only one of the two options."); - } - } - - /** - * Instantiates a {@link BigQueryStorageWriteApiSchemaTransformConfiguration.Builder} instance. - */ - public static Builder builder() { - return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration - .Builder(); - } - - @SchemaFieldDescription( - "The bigquery table to write to. Format: [${PROJECT}:]${DATASET}.${TABLE}") - public abstract String getTable(); - - @SchemaFieldDescription( - "Optional field that specifies whether the job is allowed to create new tables. " - + "The following values are supported: CREATE_IF_NEEDED (the job may create the table), CREATE_NEVER (" - + "the job must fail if the table does not exist already).") - @Nullable - public abstract String getCreateDisposition(); - - @SchemaFieldDescription( - "Specifies the action that occurs if the destination table already exists. 
" - + "The following values are supported: " - + "WRITE_TRUNCATE (overwrites the table data), " - + "WRITE_APPEND (append the data to the table), " - + "WRITE_EMPTY (job must fail if the table is not empty).") - @Nullable - public abstract String getWriteDisposition(); - - @SchemaFieldDescription( - "Determines how often to 'commit' progress into BigQuery. Default is every 5 seconds.") - @Nullable - public abstract Long getTriggeringFrequencySeconds(); - - @SchemaFieldDescription( - "This option enables lower latency for insertions to BigQuery but may ocassionally " - + "duplicate data elements.") - @Nullable - public abstract Boolean getUseAtLeastOnceSemantics(); - - @SchemaFieldDescription( - "This option enables using a dynamically determined number of Storage Write API streams to write to " - + "BigQuery. Only applicable to unbounded data.") - @Nullable - public abstract Boolean getAutoSharding(); - - @SchemaFieldDescription( - "Specifies the number of write streams that the Storage API sink will use. " - + "This parameter is only applicable when writing unbounded data.") - @Nullable - public abstract Integer getNumStreams(); - - @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") - @Nullable - public abstract ErrorHandling getErrorHandling(); - - /** Builder for {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. 
*/ - @AutoValue.Builder - public abstract static class Builder { - - public abstract Builder setTable(String table); - - public abstract Builder setCreateDisposition(String createDisposition); - - public abstract Builder setWriteDisposition(String writeDisposition); - - public abstract Builder setTriggeringFrequencySeconds(Long seconds); - - public abstract Builder setUseAtLeastOnceSemantics(Boolean use); - - public abstract Builder setAutoSharding(Boolean autoSharding); - - public abstract Builder setNumStreams(Integer numStreams); - - public abstract Builder setErrorHandling(ErrorHandling errorHandling); - - /** Builds a {@link BigQueryStorageWriteApiSchemaTransformConfiguration} instance. */ - public abstract BigQueryStorageWriteApiSchemaTransformProvider - .BigQueryStorageWriteApiSchemaTransformConfiguration - build(); - } - } - /** * A {@link SchemaTransform} for BigQuery Storage Write API, configured with {@link - * BigQueryStorageWriteApiSchemaTransformConfiguration} and instantiated by {@link + * BigQueryWriteConfiguration} and instantiated by {@link * BigQueryStorageWriteApiSchemaTransformProvider}. 
*/ - protected static class BigQueryStorageWriteApiSchemaTransform extends SchemaTransform { + public static class BigQueryStorageWriteApiSchemaTransform extends SchemaTransform { private BigQueryServices testBigQueryServices = null; - private final BigQueryStorageWriteApiSchemaTransformConfiguration configuration; + private final BigQueryWriteConfiguration configuration; - BigQueryStorageWriteApiSchemaTransform( - BigQueryStorageWriteApiSchemaTransformConfiguration configuration) { + BigQueryStorageWriteApiSchemaTransform(BigQueryWriteConfiguration configuration) { configuration.validate(); this.configuration = configuration; } @@ -342,34 +173,10 @@ private static class NoOutputDoFn extends DoFn { public void process(ProcessContext c) {} } - private static class RowDynamicDestinations extends DynamicDestinations { - Schema schema; - - RowDynamicDestinations(Schema schema) { - this.schema = schema; - } - - @Override - public String getDestination(ValueInSingleWindow element) { - return element.getValue().getString("destination"); - } - - @Override - public TableDestination getTable(String destination) { - return new TableDestination(destination, null); - } - - @Override - public TableSchema getSchema(String destination) { - return BigQueryUtils.toTableSchema(schema); - } - } - @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { // Check that the input exists - checkArgument(input.has(INPUT_ROWS_TAG), "Missing expected input tag: %s", INPUT_ROWS_TAG); - PCollection inputRows = input.get(INPUT_ROWS_TAG); + PCollection inputRows = input.getSinglePCollection(); BigQueryIO.Write write = createStorageWriteApiTransform(inputRows.getSchema()); @@ -463,40 +270,89 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { BigQueryIO.Write write = BigQueryIO.write() .withMethod(writeMethod) - .withFormatFunction(BigQueryUtils.toTableRow()) .withWriteDisposition(WriteDisposition.WRITE_APPEND); + Schema rowSchema = schema; + boolean 
fetchNestedRecord = false; if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { - checkArgument( - schema.getFieldNames().equals(Arrays.asList("destination", "record")), - "When writing to dynamic destinations, we expect Row Schema with a " - + "\"destination\" string field and a \"record\" Row field."); + validateDynamicDestinationsSchema(schema); + rowSchema = schema.getField(RECORD).getType().getRowSchema(); + fetchNestedRecord = true; + } + if (Boolean.TRUE.equals(configuration.getUseCdcWrites())) { + validateCdcSchema(schema); + rowSchema = schema.getField(RECORD).getType().getRowSchema(); + fetchNestedRecord = true; write = write - .to(new RowDynamicDestinations(schema.getField("record").getType().getRowSchema())) - .withFormatFunction(row -> BigQueryUtils.toTableRow(row.getRow("record"))); - } else { - write = write.to(configuration.getTable()).useBeamSchema(); + .withPrimaryKey(configuration.getPrimaryKey()) + .withRowMutationInformationFn( + row -> + RowMutationInformation.of( + RowMutationInformation.MutationType.valueOf( + row.getRow(ROW_PROPERTY_MUTATION_INFO) + .getString(ROW_PROPERTY_MUTATION_TYPE)), + row.getRow(ROW_PROPERTY_MUTATION_INFO) + .getString(ROW_PROPERTY_MUTATION_SQN))); } + PortableBigQueryDestinations dynamicDestinations = + new PortableBigQueryDestinations(rowSchema, configuration); + write = + write + .to(dynamicDestinations) + .withFormatFunction(dynamicDestinations.getFilterFormatFunction(fetchNestedRecord)); if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = - BigQueryStorageWriteApiSchemaTransformConfiguration.CREATE_DISPOSITIONS.get( - configuration.getCreateDisposition().toUpperCase()); + CreateDisposition.valueOf(configuration.getCreateDisposition().toUpperCase()); write = write.withCreateDisposition(createDisposition); } + if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { WriteDisposition writeDisposition = - 
BigQueryStorageWriteApiSchemaTransformConfiguration.WRITE_DISPOSITIONS.get( - configuration.getWriteDisposition().toUpperCase()); + WriteDisposition.valueOf(configuration.getWriteDisposition().toUpperCase()); write = write.withWriteDisposition(writeDisposition); } - + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + write = write.withKmsKey(configuration.getKmsKey()); + } if (this.testBigQueryServices != null) { write = write.withTestServices(testBigQueryServices); } return write; } + + void validateDynamicDestinationsSchema(Schema schema) { + checkArgument( + schema.getFieldNames().containsAll(Arrays.asList(DESTINATION, RECORD)), + String.format( + "When writing to dynamic destinations, we expect Row Schema with a " + + "\"%s\" string field and a \"%s\" Row field.", + DESTINATION, RECORD)); + } + + private void validateCdcSchema(Schema schema) { + checkArgument( + schema.getFieldNames().containsAll(Arrays.asList(ROW_PROPERTY_MUTATION_INFO, RECORD)), + "When writing using CDC functionality, we expect Row Schema with a " + + "\"" + + ROW_PROPERTY_MUTATION_INFO + + "\" Row field and a \"record\" Row field."); + + Schema mutationSchema = schema.getField(ROW_PROPERTY_MUTATION_INFO).getType().getRowSchema(); + + checkArgument( + mutationSchema != null && mutationSchema.equals(ROW_SCHEMA_MUTATION_INFO), + "When writing using CDC functionality, we expect a \"" + + ROW_PROPERTY_MUTATION_INFO + + "\" field of Row type with schema:\n" + + ROW_SCHEMA_MUTATION_INFO.toString() + + "\n" + + "Received \"" + + ROW_PROPERTY_MUTATION_INFO + + "\" field with schema:\n" + + mutationSchema); + } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java new file mode 100644 index 000000000000..505ce7125cee --- /dev/null +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Configuration for writing to BigQuery with SchemaTransforms. Used by {@link + * BigQueryStorageWriteApiSchemaTransformProvider} and {@link + * BigQueryFileLoadsSchemaTransformProvider}. 
+ */ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class BigQueryWriteConfiguration { + protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; + + @AutoValue + public abstract static class ErrorHandling { + @SchemaFieldDescription("The name of the output PCollection containing failed writes.") + public abstract String getOutput(); + + public static Builder builder() { + return new AutoValue_BigQueryWriteConfiguration_ErrorHandling.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setOutput(String output); + + public abstract ErrorHandling build(); + } + } + + public void validate() { + String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: "; + + // validate output table spec + checkArgument( + !Strings.isNullOrEmpty(this.getTable()), + invalidConfigMessage + "Table spec for a BigQuery Write must be specified."); + + // validate create and write dispositions + String createDisposition = getCreateDisposition(); + if (createDisposition != null && !createDisposition.isEmpty()) { + List createDispositions = + Arrays.stream(BigQueryIO.Write.CreateDisposition.values()) + .map(c -> c.name()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + createDispositions.contains(createDisposition.toUpperCase()), + "Invalid create disposition (%s) was specified. Available dispositions are: %s", + createDisposition, + createDispositions); + } + String writeDisposition = getWriteDisposition(); + if (writeDisposition != null && !writeDisposition.isEmpty()) { + List writeDispostions = + Arrays.stream(BigQueryIO.Write.WriteDisposition.values()) + .map(w -> w.name()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + writeDispostions.contains(writeDisposition.toUpperCase()), + "Invalid write disposition (%s) was specified. 
Available dispositions are: %s", + writeDisposition, + writeDispostions); + } + + ErrorHandling errorHandling = getErrorHandling(); + if (errorHandling != null) { + checkArgument( + !Strings.isNullOrEmpty(errorHandling.getOutput()), + invalidConfigMessage + "Output must not be empty if error handling specified."); + } + + Boolean autoSharding = getAutoSharding(); + Integer numStreams = getNumStreams(); + if (autoSharding != null && autoSharding && numStreams != null) { + checkArgument( + numStreams == 0, + invalidConfigMessage + + "Cannot set a fixed number of streams when auto-sharding is enabled. Please pick only one of the two options."); + } + } + + /** Instantiates a {@link BigQueryWriteConfiguration.Builder} instance. */ + public static Builder builder() { + return new AutoValue_BigQueryWriteConfiguration.Builder(); + } + + @SchemaFieldDescription( + "The bigquery table to write to. Format: [${PROJECT}:]${DATASET}.${TABLE}") + public abstract String getTable(); + + @SchemaFieldDescription( + "Optional field that specifies whether the job is allowed to create new tables. " + + "The following values are supported: CREATE_IF_NEEDED (the job may create the table), CREATE_NEVER (" + + "the job must fail if the table does not exist already).") + @Nullable + public abstract String getCreateDisposition(); + + @SchemaFieldDescription( + "Specifies the action that occurs if the destination table already exists. " + + "The following values are supported: " + + "WRITE_TRUNCATE (overwrites the table data), " + + "WRITE_APPEND (append the data to the table), " + + "WRITE_EMPTY (job must fail if the table is not empty).") + @Nullable + public abstract String getWriteDisposition(); + + @SchemaFieldDescription( + "Determines how often to 'commit' progress into BigQuery. 
Default is every 5 seconds.") + @Nullable + public abstract Long getTriggeringFrequencySeconds(); + + @SchemaFieldDescription( + "This option enables lower latency for insertions to BigQuery but may ocassionally " + + "duplicate data elements.") + @Nullable + public abstract Boolean getUseAtLeastOnceSemantics(); + + @SchemaFieldDescription( + "This option enables using a dynamically determined number of Storage Write API streams to write to " + + "BigQuery. Only applicable to unbounded data.") + @Nullable + public abstract Boolean getAutoSharding(); + + @SchemaFieldDescription( + "Specifies the number of write streams that the Storage API sink will use. " + + "This parameter is only applicable when writing unbounded data.") + @Nullable + public abstract Integer getNumStreams(); + + @SchemaFieldDescription("Use this Cloud KMS key to encrypt your data") + @Nullable + public abstract String getKmsKey(); + + @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") + @Nullable + public abstract ErrorHandling getErrorHandling(); + + @SchemaFieldDescription( + "This option enables the use of BigQuery CDC functionality. The expected PCollection" + + " should contain Beam Rows with a schema wrapping the record to be inserted and" + + " adding the CDC info similar to: {row_mutation_info: {mutation_type:\"...\", " + + "change_sequence_number:\"...\"}, record: {...}}") + @Nullable + public abstract Boolean getUseCdcWrites(); + + @SchemaFieldDescription( + "If CREATE_IF_NEEDED disposition is set, BigQuery table(s) will be created with this" + + " columns as primary key. Required when CDC writes are enabled with CREATE_IF_NEEDED.") + @Nullable + public abstract List getPrimaryKey(); + + @SchemaFieldDescription( + "A list of field names to keep in the input record. All other fields are dropped before writing. 
" + + "Is mutually exclusive with 'drop' and 'only'.") + public abstract @Nullable List getKeep(); + + @SchemaFieldDescription( + "A list of field names to drop from the input record before writing. " + + "Is mutually exclusive with 'keep' and 'only'.") + public abstract @Nullable List getDrop(); + + @SchemaFieldDescription( + "The name of a single record field that should be written. " + + "Is mutually exclusive with 'keep' and 'drop'.") + public abstract @Nullable String getOnly(); + + /** Builder for {@link BigQueryWriteConfiguration}. */ + @AutoValue.Builder + public abstract static class Builder { + + public abstract Builder setTable(String table); + + public abstract Builder setCreateDisposition(String createDisposition); + + public abstract Builder setWriteDisposition(String writeDisposition); + + public abstract Builder setTriggeringFrequencySeconds(Long seconds); + + public abstract Builder setUseAtLeastOnceSemantics(Boolean use); + + public abstract Builder setAutoSharding(Boolean autoSharding); + + public abstract Builder setNumStreams(Integer numStreams); + + public abstract Builder setKmsKey(String kmsKey); + + public abstract Builder setErrorHandling(ErrorHandling errorHandling); + + public abstract Builder setUseCdcWrites(Boolean cdcWrites); + + public abstract Builder setPrimaryKey(List pkColumns); + + public abstract Builder setKeep(List keep); + + public abstract Builder setDrop(List drop); + + public abstract Builder setOnly(String only); + + /** Builds a {@link BigQueryWriteConfiguration} instance. 
*/ + public abstract BigQueryWriteConfiguration build(); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java new file mode 100644 index 000000000000..abab169d6932 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; + +import com.google.auto.service.AutoService; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; + +/** + * A BigQuery Write SchemaTransformProvider that routes to either {@link + * BigQueryFileLoadsSchemaTransformProvider} or {@link + * BigQueryStorageWriteApiSchemaTransformProvider}. + * + *

    Internal only. Used by the Managed Transform layer. + */ +@Internal +@AutoService(SchemaTransformProvider.class) +public class BigQueryWriteSchemaTransformProvider + extends TypedSchemaTransformProvider { + @Override + public String identifier() { + return getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE); + } + + @Override + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { + return new BigQueryWriteSchemaTransform(configuration); + } + + public static class BigQueryWriteSchemaTransform extends SchemaTransform { + private final BigQueryWriteConfiguration configuration; + + BigQueryWriteSchemaTransform(BigQueryWriteConfiguration configuration) { + configuration.validate(); + this.configuration = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + if (input.getSinglePCollection().isBounded().equals(PCollection.IsBounded.BOUNDED)) { + return input.apply(new BigQueryFileLoadsSchemaTransformProvider().from(configuration)); + } else { // UNBOUNDED + return input.apply( + new BigQueryStorageWriteApiSchemaTransformProvider().from(configuration)); + } + } + + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(BigQueryWriteConfiguration.class) + .apply(configuration) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java new file mode 100644 index 000000000000..54d125012eac --- /dev/null +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; +import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.google.api.services.bigquery.model.TableConstraints; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; +import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; +import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.RowFilter; +import org.apache.beam.sdk.util.RowStringInterpolator; +import org.apache.beam.sdk.values.Row; +import 
org.apache.beam.sdk.values.ValueInSingleWindow; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +@Internal +public class PortableBigQueryDestinations extends DynamicDestinations { + public static final String DESTINATION = "destination"; + public static final String RECORD = "record"; + private @MonotonicNonNull RowStringInterpolator interpolator = null; + private final @Nullable List primaryKey; + private final RowFilter rowFilter; + + public PortableBigQueryDestinations(Schema rowSchema, BigQueryWriteConfiguration configuration) { + // DYNAMIC_DESTINATIONS magic string is the old way of doing it for cross-language. + // In that case, we do no interpolation + if (!configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { + this.interpolator = new RowStringInterpolator(configuration.getTable(), rowSchema); + } + this.primaryKey = configuration.getPrimaryKey(); + RowFilter rf = new RowFilter(rowSchema); + if (configuration.getDrop() != null) { + rf = rf.drop(checkStateNotNull(configuration.getDrop())); + } + if (configuration.getKeep() != null) { + rf = rf.keep(checkStateNotNull(configuration.getKeep())); + } + if (configuration.getOnly() != null) { + rf = rf.only(checkStateNotNull(configuration.getOnly())); + } + this.rowFilter = rf; + } + + @Override + public String getDestination(@Nullable ValueInSingleWindow element) { + if (interpolator != null) { + return interpolator.interpolate(checkArgumentNotNull(element)); + } + return checkStateNotNull(checkStateNotNull(element).getValue().getString(DESTINATION)); + } + + @Override + public TableDestination getTable(String destination) { + return new TableDestination(destination, null); + } + + @Override + public @Nullable TableSchema getSchema(String destination) { + return BigQueryUtils.toTableSchema(rowFilter.outputSchema()); + } + + @Override + public @Nullable TableConstraints getTableConstraints(String destination) { + if (primaryKey 
!= null) { + return new TableConstraints() + .setPrimaryKey(new TableConstraints.PrimaryKey().setColumns(primaryKey)); + } + return null; + } + + public SerializableFunction getFilterFormatFunction(boolean fetchNestedRecord) { + return row -> { + if (fetchNestedRecord) { + row = checkStateNotNull(row.getRow(RECORD)); + } + Row filtered = rowFilter.filter(row); + return BigQueryUtils.toTableRow(filtered); + }; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java index 389d2e43c74e..932099e01763 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java @@ -2021,10 +2021,9 @@ public void close() throws IOException { reader.close(); reader = null; } - if (serviceEntry != null) { - serviceEntry.close(); - serviceEntry = null; - } + // Skipping closing the service entry on each bundle. + // In the future we'll close the Bigtable client in + // teardown. } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java index dca12db0c211..2b187039d6cb 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java @@ -200,11 +200,10 @@ interface RpcWriteAttempt extends RpcAttempt { * provided {@code instant}. 
* * @param instant The intended start time of the next rpc - * @param The type which will be sent in the request * @param The {@link Element} type which the returned buffer will contain * @return a new {@link FlushBuffer} which queued messages can be staged to before final flush */ - > FlushBuffer newFlushBuffer(Instant instant); + > FlushBuffer newFlushBuffer(Instant instant); /** Record the start time of sending the rpc. */ void recordRequestStart(Instant start, int numWrites); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java index c600ae4224b4..1c83e45acb95 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java @@ -386,7 +386,7 @@ public boolean awaitSafeToProceed(Instant instant) throws InterruptedException { } @Override - public > FlushBufferImpl newFlushBuffer( + public > FlushBufferImpl newFlushBuffer( Instant instantSinceEpoch) { state.checkActive(); int availableWriteCountBudget = writeRampUp.getAvailableWriteCountBudget(instantSinceEpoch); @@ -935,7 +935,7 @@ private static O11y create( } } - static class FlushBufferImpl> implements FlushBuffer { + static class FlushBufferImpl> implements FlushBuffer { final int nextBatchMaxCount; final long nextBatchMaxBytes; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java index 2964a29dbb6b..bd01604643e1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java 
@@ -507,6 +507,9 @@ public abstract void modifyAckDeadline( /** Return a list of topics for {@code project}. */ public abstract List listTopics(ProjectPath project) throws IOException; + /** Return {@literal true} if {@code topic} exists. */ + public abstract boolean isTopicExists(TopicPath topic) throws IOException; + /** Create {@code subscription} to {@code topic}. */ public abstract void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java index 93fdd5524007..0cfb06688108 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java @@ -54,6 +54,7 @@ import io.grpc.Channel; import io.grpc.ClientInterceptors; import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; import io.grpc.auth.ClientAuthInterceptor; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NegotiationType; @@ -372,6 +373,21 @@ public List listTopics(ProjectPath project) throws IOException { return topics; } + @Override + public boolean isTopicExists(TopicPath topic) throws IOException { + GetTopicRequest request = GetTopicRequest.newBuilder().setTopic(topic.getPath()).build(); + try { + publisherStub().getTopic(request); + return true; + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode() == io.grpc.Status.Code.NOT_FOUND) { + return false; + } + + throw e; + } + } + @Override public void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java index f59a68c40551..c6c8b3e71815 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java @@ -50,6 +50,7 @@ import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.SubscriptionPath; import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.TopicPath; import org.apache.beam.sdk.metrics.Lineage; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -860,6 +861,8 @@ public abstract static class Read extends PTransform> abstract ErrorHandler getBadRecordErrorHandler(); + abstract boolean getValidate(); + abstract Builder toBuilder(); static Builder newBuilder(SerializableFunction parseFn) { @@ -871,6 +874,7 @@ static Builder newBuilder(SerializableFunction parseFn) builder.setNeedsOrderingKey(false); builder.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER); builder.setBadRecordErrorHandler(new DefaultErrorHandler<>()); + builder.setValidate(false); return builder; } @@ -918,6 +922,8 @@ abstract static class Builder { abstract Builder setBadRecordErrorHandler( ErrorHandler badRecordErrorHandler); + abstract Builder setValidate(boolean validation); + abstract Read build(); } @@ -1097,6 +1103,11 @@ public Read withErrorHandler(ErrorHandler badRecordErrorHandler .build(); } + /** Enable validation of the PubSub Read. */ + public Read withValidation() { + return toBuilder().setValidate(true).build(); + } + @VisibleForTesting /** * Set's the internal Clock. 
@@ -1262,6 +1273,35 @@ public T apply(PubsubMessage input) { return read.setCoder(getCoder()); } + @Override + public void validate(PipelineOptions options) { + if (!getValidate()) { + return; + } + + PubsubOptions psOptions = options.as(PubsubOptions.class); + + // Validate the existence of the topic. + if (getTopicProvider() != null) { + PubsubTopic topic = getTopicProvider().get(); + boolean topicExists = true; + try (PubsubClient pubsubClient = + getPubsubClientFactory() + .newClient(getTimestampAttribute(), getIdAttribute(), psOptions)) { + topicExists = + pubsubClient.isTopicExists( + PubsubClient.topicPathFromName(topic.project, topic.topic)); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (!topicExists) { + throw new IllegalArgumentException( + String.format("Pubsub topic '%s' does not exist.", topic)); + } + } + } + @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); @@ -1341,6 +1381,8 @@ public abstract static class Write extends PTransform, PDone> abstract ErrorHandler getBadRecordErrorHandler(); + abstract boolean getValidate(); + abstract Builder toBuilder(); static Builder newBuilder( @@ -1350,6 +1392,7 @@ static Builder newBuilder( builder.setFormatFn(formatFn); builder.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER); builder.setBadRecordErrorHandler(new DefaultErrorHandler<>()); + builder.setValidate(false); return builder; } @@ -1386,6 +1429,8 @@ abstract Builder setFormatFn( abstract Builder setBadRecordErrorHandler( ErrorHandler badRecordErrorHandler); + abstract Builder setValidate(boolean validation); + abstract Write build(); } @@ -1396,11 +1441,14 @@ abstract Builder setBadRecordErrorHandler( * {@code topic} string. */ public Write to(String topic) { + ValueProvider topicProvider = StaticValueProvider.of(topic); + validateTopic(topicProvider); return to(StaticValueProvider.of(topic)); } /** Like {@code topic()} but with a {@link ValueProvider}. 
*/ public Write to(ValueProvider topic) { + validateTopic(topic); return toBuilder() .setTopicProvider(NestedValueProvider.of(topic, PubsubTopic::fromPath)) .setTopicFunction(null) @@ -1408,6 +1456,13 @@ public Write to(ValueProvider topic) { .build(); } + /** Handles validation of {@code topic}. */ + private static void validateTopic(ValueProvider topic) { + if (topic.isAccessible()) { + PubsubTopic.fromPath(topic.get()); + } + } + /** * Provides a function to dynamically specify the target topic per message. Not compatible with * any of the other to methods. If {@link #to} is called again specifying a topic, then this @@ -1497,6 +1552,11 @@ public Write withErrorHandler(ErrorHandler badRecordErrorHandle .build(); } + /** Enable validation of the PubSub Write. */ + public Write withValidation() { + return toBuilder().setValidate(true).build(); + } + @Override public PDone expand(PCollection input) { if (getTopicProvider() == null && !getDynamicDestinations()) { @@ -1566,6 +1626,35 @@ public PDone expand(PCollection input) { throw new RuntimeException(); // cases are exhaustive. } + @Override + public void validate(PipelineOptions options) { + if (!getValidate()) { + return; + } + + PubsubOptions psOptions = options.as(PubsubOptions.class); + + // Validate the existence of the topic. 
+ if (getTopicProvider() != null) { + PubsubTopic topic = getTopicProvider().get(); + boolean topicExists = true; + try (PubsubClient pubsubClient = + getPubsubClientFactory() + .newClient(getTimestampAttribute(), getIdAttribute(), psOptions)) { + topicExists = + pubsubClient.isTopicExists( + PubsubClient.topicPathFromName(topic.project, topic.topic)); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (!topicExists) { + throw new IllegalArgumentException( + String.format("Pubsub topic '%s' does not exist.", topic)); + } + } + } + @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java index 386febcf005b..0a838da66f69 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java @@ -19,6 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; import com.google.api.client.http.HttpRequestInitializer; import com.google.api.services.pubsub.Pubsub; import com.google.api.services.pubsub.Pubsub.Projects.Subscriptions; @@ -310,6 +311,19 @@ public List listTopics(ProjectPath project) throws IOException { return topics; } + @Override + public boolean isTopicExists(TopicPath topic) throws IOException { + try { + pubsub.projects().topics().get(topic.getPath()).execute(); + return true; + } catch (GoogleJsonResponseException e) { + if (e.getStatusCode() == 404) { + return false; + } + throw e; + } + } + @Override public void createSubscription( TopicPath topic, SubscriptionPath subscription, 
int ackDeadlineSeconds) throws IOException { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java index c1f6b2b31754..8a628817fe27 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java @@ -43,10 +43,7 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -313,19 +310,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsub_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java index a8109d05ec38..3d5a879fce15 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java @@ -605,6 +605,12 @@ public List listTopics(ProjectPath project) throws IOException { throw new UnsupportedOperationException(); } + @Override + public boolean isTopicExists(TopicPath topic) throws IOException { + // Always return true for testing purposes. + return true; + } + @Override public void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java index 6187f6f79d3e..2abd6f5fa95d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java @@ -44,9 +44,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -248,19 +245,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override 
- public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsub_write:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java index 8afe730f32ce..61b94aeee445 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java @@ -63,10 +63,7 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,11 +83,17 @@ public class PubsubLiteReadSchemaTransformProvider public static final TupleTag ERROR_TAG = new TupleTag() {}; @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return 
PubsubLiteReadSchemaTransformConfiguration.class; } + @Override + public String description() { + return "Performs a read from Google Pub/Sub Lite.\n" + + "\n" + + "**Note**: This provider is deprecated. See Pub/Sub Lite documentation for more information."; + } + public static class ErrorFn extends DoFn { private final SerializableFunction valueMapper; private final Counter errorCounter; @@ -192,8 +195,7 @@ public void finish(FinishBundleContext c) { } @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - PubsubLiteReadSchemaTransformConfiguration configuration) { + public SchemaTransform from(PubsubLiteReadSchemaTransformConfiguration configuration) { if (!VALID_DATA_FORMATS.contains(configuration.getFormat())) { throw new IllegalArgumentException( String.format( @@ -399,19 +401,17 @@ public Uuid apply(SequencedMessage input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsublite_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java index 8ba8176035da..54ed7ac495d9 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java @@ -60,10 +60,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -81,11 +78,17 @@ public class PubsubLiteWriteSchemaTransformProvider LoggerFactory.getLogger(PubsubLiteWriteSchemaTransformProvider.class); @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return PubsubLiteWriteSchemaTransformConfiguration.class; } + @Override + public String description() { + return "Performs a write to Google Pub/Sub Lite.\n" + + "\n" + + "**Note**: This provider is deprecated. 
See Pub/Sub Lite documentation for more information."; + } + public static class ErrorCounterFn extends DoFn { private final SerializableFunction toBytesFn; private final Counter errorCounter; @@ -172,8 +175,7 @@ public void finish() { } @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - PubsubLiteWriteSchemaTransformConfiguration configuration) { + public SchemaTransform from(PubsubLiteWriteSchemaTransformConfiguration configuration) { if (!SUPPORTED_FORMATS.contains(configuration.getFormat())) { throw new IllegalArgumentException( @@ -317,19 +319,17 @@ public byte[] apply(Row input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsublite_write:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java index 2b9f24f09541..933394982e30 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java @@ -135,7 +135,6 @@ String tryGetTableName() { return getTable(); } else if (getQuery() != null) { String query = getQuery().getSql(); - System.err.println(query); Matcher matcher = queryPattern.matcher(query); if (matcher.find()) { return matcher.group("table"); diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java index 2a2b01cca9bd..b37f2e585815 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java @@ -30,7 +30,6 @@ import com.google.cloud.spanner.DatabaseAdminClient; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.DatabaseId; -import com.google.cloud.spanner.SessionPoolOptions; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; import com.google.cloud.spanner.v1.stub.SpannerStubSettings; @@ -233,9 +232,7 @@ static SpannerOptions buildSpannerOptions(SpannerConfig spannerConfig) { if (credentials != null && credentials.get() != null) { builder.setCredentials(credentials.get()); } - SessionPoolOptions sessionPoolOptions = - SessionPoolOptions.newBuilder().setFailIfPoolExhausted().build(); - builder.setSessionPoolOption(sessionPoolOptions); + return builder.build(); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 435bbba9ae8e..a6cf7ebb12a5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -25,7 +25,6 @@ import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.DEFAULT_RPC_PRIORITY; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.MAX_INCLUSIVE_END_AT; import static 
org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.THROUGHPUT_WINDOW_SECONDS; -import static org.apache.beam.sdk.io.gcp.spanner.changestreams.NameGenerator.generatePartitionMetadataTableName; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -61,6 +60,7 @@ import java.util.HashMap; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -77,6 +77,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.action.ActionFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.CleanUpReadChangeStreamDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.DetectNewPartitionsDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.InitializeDoFn; @@ -290,8 +291,8 @@ * grouped into batches. The default maximum size of the batch is set to 1MB or 5000 mutated cells, * or 500 rows (whichever is reached first). To override this use {@link * Write#withBatchSizeBytes(long) withBatchSizeBytes()}, {@link Write#withMaxNumMutations(long) - * withMaxNumMutations()} or {@link Write#withMaxNumMutations(long) withMaxNumRows()}. Setting - * either to a small value or zero disables batching. + * withMaxNumMutations()} or {@link Write#withMaxNumRows(long) withMaxNumRows()}. Setting either to + * a small value or zero disables batching. * *

    Note that the maximum @@ -1772,9 +1773,13 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta + fullPartitionMetadataDatabaseId + " has dialect " + metadataDatabaseDialect); - final String partitionMetadataTableName = - MoreObjects.firstNonNull( - getMetadataTable(), generatePartitionMetadataTableName(partitionMetadataDatabaseId)); + PartitionMetadataTableNames partitionMetadataTableNames = + Optional.ofNullable(getMetadataTable()) + .map( + table -> + PartitionMetadataTableNames.fromExistingTable( + partitionMetadataDatabaseId, table)) + .orElse(PartitionMetadataTableNames.generateRandom(partitionMetadataDatabaseId)); final String changeStreamName = getChangeStreamName(); final Timestamp startTimestamp = getInclusiveStartAt(); // Uses (Timestamp.MAX - 1ns) at max for end timestamp, because we add 1ns to transform the @@ -1791,7 +1796,7 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta changeStreamSpannerConfig, changeStreamName, partitionMetadataSpannerConfig, - partitionMetadataTableName, + partitionMetadataTableNames, rpcPriority, input.getPipeline().getOptions().getJobName(), changeStreamDatabaseDialect, @@ -1807,7 +1812,9 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta final PostProcessingMetricsDoFn postProcessingMetricsDoFn = new PostProcessingMetricsDoFn(metrics); - LOG.info("Partition metadata table that will be used is " + partitionMetadataTableName); + LOG.info( + "Partition metadata table that will be used is " + + partitionMetadataTableNames.getTableName()); final PCollection impulseOut = input.apply(Impulse.create()); final PCollection partitionsOut = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java index 9820bb39d09d..76440b1ebf1a 
100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java @@ -40,10 +40,8 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +/** A provider for reading from Cloud Spanner using a Schema Transform Provider. */ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) @@ -57,43 +55,81 @@ * *

    The transformation leverages the {@link SpannerIO} to perform the read operation and maps the * results to Beam rows, preserving the schema. - * - *

    Example usage in a YAML pipeline using query: - * - *

    {@code
    - * pipeline:
    - *   transforms:
    - *     - type: ReadFromSpanner
    - *       name: ReadShipments
    - *       # Columns: shipment_id, customer_id, shipment_date, shipment_cost, customer_name, customer_email
    - *       config:
    - *         project_id: 'apache-beam-testing'
    - *         instance_id: 'shipment-test'
    - *         database_id: 'shipment'
    - *         query: 'SELECT * FROM shipments'
    - * }
    - * - *

    Example usage in a YAML pipeline using a table and columns: - * - *

    {@code
    - * pipeline:
    - *   transforms:
    - *     - type: ReadFromSpanner
    - *       name: ReadShipments
    - *       # Columns: shipment_id, customer_id, shipment_date, shipment_cost, customer_name, customer_email
    - *       config:
    - *         project_id: 'apache-beam-testing'
    - *         instance_id: 'shipment-test'
    - *         database_id: 'shipment'
    - *         table: 'shipments'
    - *         columns: ['customer_id', 'customer_name']
    - * }
    */ @AutoService(SchemaTransformProvider.class) public class SpannerReadSchemaTransformProvider extends TypedSchemaTransformProvider< SpannerReadSchemaTransformProvider.SpannerReadSchemaTransformConfiguration> { + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:spanner_read:v1"; + } + + @Override + public String description() { + return "Performs a Bulk read from Google Cloud Spanner using a specified SQL query or " + + "by directly accessing a single table and its columns.\n" + + "\n" + + "Both Query and Read APIs are supported. See more information about " + + "
    reading from Cloud Spanner.\n" + + "\n" + + "Example configuration for performing a read using a SQL query: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " query: 'SELECT * FROM table'\n" + + "\n" + + "It is also possible to read a table by specifying a table name and a list of columns. For " + + "example, the following configuration will perform a read on an entire table: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " columns: ['col1', 'col2']\n" + + "\n" + + "Additionally, to read using a " + + "Secondary Index, specify the index name: ::" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " index: 'my-index'\n" + + " columns: ['col1', 'col2']\n" + + "\n" + + "### Advanced Usage\n" + + "\n" + + "Reads by default use the " + + "PartitionQuery API which enforces some limitations on the type of queries that can be used so that " + + "the data can be read in parallel. 
If the query is not supported by the PartitionQuery API, then you " + + "can specify a non-partitioned read by setting batching to false.\n" + + "\n" + + "For example: ::" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " batching: false\n" + + " ...\n" + + "\n" + + "Note: See " + + "SpannerIO for more advanced information."; + } + static class SpannerSchemaTransformRead extends SchemaTransform implements Serializable { private final SpannerReadSchemaTransformConfiguration configuration; @@ -116,6 +152,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } else { read = read.withTable(configuration.getTableId()).withColumns(configuration.getColumns()); } + if (!Strings.isNullOrEmpty(configuration.getIndex())) { + read = read.withIndex(configuration.getIndex()); + } + if (Boolean.FALSE.equals(configuration.getBatching())) { + read = read.withBatching(false); + } PCollection spannerRows = input.getPipeline().apply(read); Schema schema = spannerRows.getSchema(); PCollection rows = @@ -128,19 +170,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:spanner_read:v1"; - } - - @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("output"); } @@ -162,6 +197,10 @@ public abstract static class Builder { public abstract Builder setColumns(List columns); + public abstract Builder setIndex(String index); + + public abstract Builder setBatching(Boolean batching); + public abstract 
SpannerReadSchemaTransformConfiguration build(); } @@ -198,16 +237,16 @@ public static Builder builder() { .Builder(); } - @SchemaFieldDescription("Specifies the GCP project ID.") - @Nullable - public abstract String getProjectId(); - @SchemaFieldDescription("Specifies the Cloud Spanner instance.") public abstract String getInstanceId(); @SchemaFieldDescription("Specifies the Cloud Spanner database.") public abstract String getDatabaseId(); + @SchemaFieldDescription("Specifies the GCP project ID.") + @Nullable + public abstract String getProjectId(); + @SchemaFieldDescription("Specifies the Cloud Spanner table.") @Nullable public abstract String getTableId(); @@ -216,20 +255,29 @@ public static Builder builder() { @Nullable public abstract String getQuery(); - @SchemaFieldDescription("Specifies the columns to read from the table.") + @SchemaFieldDescription( + "Specifies the columns to read from the table. This parameter is required when table is specified.") @Nullable public abstract List getColumns(); + + @SchemaFieldDescription( + "Specifies the Index to read from. This parameter can only be specified when using table.") + @Nullable + public abstract String getIndex(); + + @SchemaFieldDescription( + "Set to false to disable batching. Useful when using a query that is not compatible with the PartitionQuery API. 
Defaults to true.") + @Nullable + public abstract Boolean getBatching(); } @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return SpannerReadSchemaTransformConfiguration.class; } @Override - protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - SpannerReadSchemaTransformConfiguration configuration) { + protected SchemaTransform from(SpannerReadSchemaTransformConfiguration configuration) { return new SpannerSchemaTransformRead(configuration); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java index f50755d18155..8601da09ea09 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java @@ -51,9 +51,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -69,43 +67,6 @@ *

    The transformation uses the {@link SpannerIO} to perform the write operation and provides * options to handle failed mutations, either by throwing an error, or passing the failed mutation * further in the pipeline for dealing with accordingly. - * - *

    Example usage in a YAML pipeline without error handling: - * - *

    {@code
    - * pipeline:
    - *   transforms:
    - *     - type: WriteToSpanner
    - *       name: WriteShipments
    - *       config:
    - *         project_id: 'apache-beam-testing'
    - *         instance_id: 'shipment-test'
    - *         database_id: 'shipment'
    - *         table_id: 'shipments'
    - *
    - * }
    - * - *

    Example usage in a YAML pipeline using error handling: - * - *

    {@code
    - * pipeline:
    - *   transforms:
    - *     - type: WriteToSpanner
    - *       name: WriteShipments
    - *       config:
    - *         project_id: 'apache-beam-testing'
    - *         instance_id: 'shipment-test'
    - *         database_id: 'shipment'
    - *         table_id: 'shipments'
    - *         error_handling:
    - *           output: 'errors'
    - *
    - *     - type: WriteToJson
    - *       input: WriteSpanner.my_error_output
    - *       config:
    - *          path: errors.json
    - *
    - * }
    */ @AutoService(SchemaTransformProvider.class) public class SpannerWriteSchemaTransformProvider @@ -113,14 +74,37 @@ public class SpannerWriteSchemaTransformProvider SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration> { @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + public String identifier() { + return "beam:schematransform:org.apache.beam:spanner_write:v1"; + } + + @Override + public String description() { + return "Performs a bulk write to a Google Cloud Spanner table.\n" + + "\n" + + "Example configuration for performing a write to a single table: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: WriteToSpanner\n" + + " config:\n" + + " project_id: 'my-project-id'\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table_id: 'my-table'\n" + + "\n" + + "Note: See " + + "SpannerIO for more advanced information."; + } + + @Override + protected Class configurationClass() { return SpannerWriteSchemaTransformConfiguration.class; } @Override - protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - SpannerWriteSchemaTransformConfiguration configuration) { + protected SchemaTransform from(SpannerWriteSchemaTransformConfiguration configuration) { return new SpannerSchemaTransformWrite(configuration); } @@ -230,19 +214,12 @@ public PCollectionRowTuple expand(@NonNull PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:spanner_write:v1"; - } - - @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { 
return Arrays.asList("post-write", "errors"); } @@ -250,10 +227,6 @@ public PCollectionRowTuple expand(@NonNull PCollectionRowTuple input) { @DefaultSchema(AutoValueSchema.class) public abstract static class SpannerWriteSchemaTransformConfiguration implements Serializable { - @SchemaFieldDescription("Specifies the GCP project.") - @Nullable - public abstract String getProjectId(); - @SchemaFieldDescription("Specifies the Cloud Spanner instance.") public abstract String getInstanceId(); @@ -263,7 +236,11 @@ public abstract static class SpannerWriteSchemaTransformConfiguration implements @SchemaFieldDescription("Specifies the Cloud Spanner table.") public abstract String getTableId(); - @SchemaFieldDescription("Specifies how to handle errors.") + @SchemaFieldDescription("Specifies the GCP project.") + @Nullable + public abstract String getProjectId(); + + @SchemaFieldDescription("Whether and how to handle write errors.") @Nullable public abstract ErrorHandling getErrorHandling(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java deleted file mode 100644 index 322e85cb07a2..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; - -import java.util.UUID; - -/** - * This class generates a unique name for the partition metadata table, which is created when the - * Connector is initialized. - */ -public class NameGenerator { - - private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; - private static final int MAX_TABLE_NAME_LENGTH = 63; - - /** - * Generates an unique name for the partition metadata table in the form of {@code - * "Metadata__"}. - * - * @param databaseId The database id where the table will be created - * @return the unique generated name of the partition metadata table - */ - public static String generatePartitionMetadataTableName(String databaseId) { - // There are 11 characters in the name format. - // Maximum Spanner database ID length is 30 characters. - // UUID always generates a String with 36 characters. - // Since the Postgres table name length is 63, we may need to truncate the table name depending - // on the database length. 
- String fullString = - String.format(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, UUID.randomUUID()) - .replaceAll("-", "_"); - if (fullString.length() < MAX_TABLE_NAME_LENGTH) { - return fullString; - } - return fullString.substring(0, MAX_TABLE_NAME_LENGTH); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java index f3562e4cd917..e7bc064b1f33 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java @@ -66,10 +66,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.gson.Gson; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -80,8 +77,7 @@ public class SpannerChangestreamsReadSchemaTransformProvider extends TypedSchemaTransformProvider< SpannerChangestreamsReadSchemaTransformProvider.SpannerChangestreamsReadConfiguration> { @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return SpannerChangestreamsReadConfiguration.class; } @@ -94,7 +90,7 @@ public class 
SpannerChangestreamsReadSchemaTransformProvider Schema.builder().addStringField("error").addNullableStringField("row").build(); @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + public SchemaTransform from( SpannerChangestreamsReadSchemaTransformProvider.SpannerChangestreamsReadConfiguration configuration) { return new SchemaTransform() { @@ -142,19 +138,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:spanner_cdc_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java index b9718fdb675e..787abad02e02 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java @@ -44,7 +44,7 @@ public class DaoFactory implements Serializable { private final SpannerConfig metadataSpannerConfig; private final String changeStreamName; - private final String partitionMetadataTableName; + private final PartitionMetadataTableNames partitionMetadataTableNames; private final RpcPriority rpcPriority; private final String jobName; private final Dialect 
spannerChangeStreamDatabaseDialect; @@ -56,7 +56,7 @@ public class DaoFactory implements Serializable { * @param changeStreamSpannerConfig the configuration for the change streams DAO * @param changeStreamName the name of the change stream for the change streams DAO * @param metadataSpannerConfig the metadata tables configuration - * @param partitionMetadataTableName the name of the created partition metadata table + * @param partitionMetadataTableNames the names of the partition metadata ddl objects * @param rpcPriority the priority of the requests made by the DAO queries * @param jobName the name of the running job */ @@ -64,7 +64,7 @@ public DaoFactory( SpannerConfig changeStreamSpannerConfig, String changeStreamName, SpannerConfig metadataSpannerConfig, - String partitionMetadataTableName, + PartitionMetadataTableNames partitionMetadataTableNames, RpcPriority rpcPriority, String jobName, Dialect spannerChangeStreamDatabaseDialect, @@ -78,7 +78,7 @@ public DaoFactory( this.changeStreamSpannerConfig = changeStreamSpannerConfig; this.changeStreamName = changeStreamName; this.metadataSpannerConfig = metadataSpannerConfig; - this.partitionMetadataTableName = partitionMetadataTableName; + this.partitionMetadataTableNames = partitionMetadataTableNames; this.rpcPriority = rpcPriority; this.jobName = jobName; this.spannerChangeStreamDatabaseDialect = spannerChangeStreamDatabaseDialect; @@ -102,7 +102,7 @@ public synchronized PartitionMetadataAdminDao getPartitionMetadataAdminDao() { databaseAdminClient, metadataSpannerConfig.getInstanceId().get(), metadataSpannerConfig.getDatabaseId().get(), - partitionMetadataTableName, + partitionMetadataTableNames, this.metadataDatabaseDialect); } return partitionMetadataAdminDao; @@ -120,7 +120,7 @@ public synchronized PartitionMetadataDao getPartitionMetadataDao() { if (partitionMetadataDaoInstance == null) { partitionMetadataDaoInstance = new PartitionMetadataDao( - this.partitionMetadataTableName, + 
this.partitionMetadataTableNames.getTableName(), spannerAccessor.getDatabaseClient(), this.metadataDatabaseDialect); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java index 368cab7022b3..3e6045d8858b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java @@ -79,19 +79,13 @@ public class PartitionMetadataAdminDao { */ public static final String COLUMN_FINISHED_AT = "FinishedAt"; - /** Metadata table index for queries over the watermark column. */ - public static final String WATERMARK_INDEX = "WatermarkIndex"; - - /** Metadata table index for queries over the created at / start timestamp columns. 
*/ - public static final String CREATED_AT_START_TIMESTAMP_INDEX = "CreatedAtStartTimestampIndex"; - private static final int TIMEOUT_MINUTES = 10; private static final int TTL_AFTER_PARTITION_FINISHED_DAYS = 1; private final DatabaseAdminClient databaseAdminClient; private final String instanceId; private final String databaseId; - private final String tableName; + private final PartitionMetadataTableNames names; private final Dialect dialect; /** @@ -101,18 +95,18 @@ public class PartitionMetadataAdminDao { * table * @param instanceId the instance where the metadata table will reside * @param databaseId the database where the metadata table will reside - * @param tableName the name of the metadata table + * @param names the names of the metadata table ddl objects */ PartitionMetadataAdminDao( DatabaseAdminClient databaseAdminClient, String instanceId, String databaseId, - String tableName, + PartitionMetadataTableNames names, Dialect dialect) { this.databaseAdminClient = databaseAdminClient; this.instanceId = instanceId; this.databaseId = databaseId; - this.tableName = tableName; + this.names = names; this.dialect = dialect; } @@ -128,8 +122,8 @@ public void createPartitionMetadataTable() { if (this.isPostgres()) { // Literals need be added around literals to preserve casing. 
ddl.add( - "CREATE TABLE \"" - + tableName + "CREATE TABLE IF NOT EXISTS \"" + + names.getTableName() + "\"(\"" + COLUMN_PARTITION_TOKEN + "\" text NOT NULL,\"" @@ -163,20 +157,20 @@ public void createPartitionMetadataTable() { + COLUMN_FINISHED_AT + "\""); ddl.add( - "CREATE INDEX \"" - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getWatermarkIndexName() + "\" on \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_WATERMARK + "\") INCLUDE (\"" + COLUMN_STATE + "\")"); ddl.add( - "CREATE INDEX \"" - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getCreatedAtIndexName() + "\" ON \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_CREATED_AT + "\",\"" @@ -184,8 +178,8 @@ public void createPartitionMetadataTable() { + "\")"); } else { ddl.add( - "CREATE TABLE " - + tableName + "CREATE TABLE IF NOT EXISTS " + + names.getTableName() + " (" + COLUMN_PARTITION_TOKEN + " STRING(MAX) NOT NULL," @@ -218,20 +212,20 @@ public void createPartitionMetadataTable() { + TTL_AFTER_PARTITION_FINISHED_DAYS + " DAY))"); ddl.add( - "CREATE INDEX " - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getWatermarkIndexName() + " on " - + tableName + + names.getTableName() + " (" + COLUMN_WATERMARK + ") STORING (" + COLUMN_STATE + ")"); ddl.add( - "CREATE INDEX " - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getCreatedAtIndexName() + " ON " - + tableName + + names.getTableName() + " (" + COLUMN_CREATED_AT + "," @@ -261,16 +255,14 @@ public void createPartitionMetadataTable() { * Drops the metadata table. This operation should complete in {@link * PartitionMetadataAdminDao#TIMEOUT_MINUTES} minutes. 
*/ - public void deletePartitionMetadataTable() { + public void deletePartitionMetadataTable(List indexes) { List ddl = new ArrayList<>(); if (this.isPostgres()) { - ddl.add("DROP INDEX \"" + CREATED_AT_START_TIMESTAMP_INDEX + "\""); - ddl.add("DROP INDEX \"" + WATERMARK_INDEX + "\""); - ddl.add("DROP TABLE \"" + tableName + "\""); + indexes.forEach(index -> ddl.add("DROP INDEX \"" + index + "\"")); + ddl.add("DROP TABLE \"" + names.getTableName() + "\""); } else { - ddl.add("DROP INDEX " + CREATED_AT_START_TIMESTAMP_INDEX); - ddl.add("DROP INDEX " + WATERMARK_INDEX); - ddl.add("DROP TABLE " + tableName); + indexes.forEach(index -> ddl.add("DROP INDEX " + index)); + ddl.add("DROP TABLE " + names.getTableName()); } OperationFuture op = databaseAdminClient.updateDatabaseDdl(instanceId, databaseId, ddl, null); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java index 7867932cd1ad..654fd946663c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java @@ -96,6 +96,41 @@ public boolean tableExists() { } } + /** + * Finds all indexes for the metadata table. + * + * @return a list of index names for the metadata table. 
+ */ + public List findAllTableIndexes() { + String indexesStmt; + if (this.isPostgres()) { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = 'public'" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } else { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = ''" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } + + List result = new ArrayList<>(); + try (ResultSet queryResultSet = + databaseClient + .singleUseReadOnlyTransaction() + .executeQuery(Statement.of(indexesStmt), Options.tag("query=findAllTableIndexes"))) { + while (queryResultSet.next()) { + result.add(queryResultSet.getString("index_name")); + } + } + return result; + } + /** * Fetches the partition metadata row data for the given partition token. * diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java new file mode 100644 index 000000000000..07d7b80676de --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import java.io.Serializable; +import java.util.Objects; +import java.util.UUID; +import javax.annotation.Nullable; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** + * Configuration for a partition metadata table. It encapsulates the name of the metadata table and + * indexes. + */ +public class PartitionMetadataTableNames implements Serializable { + + private static final long serialVersionUID = 8848098877671834584L; + + /** PostgreSQL max table and index length is 63 bytes. */ + @VisibleForTesting static final int MAX_NAME_LENGTH = 63; + + private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; + private static final String WATERMARK_INDEX_NAME_FORMAT = "WatermarkIdx_%s_%s"; + private static final String CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT = "CreatedAtIdx_%s_%s"; + + /** + * Generates a unique name for the partition metadata table and its indexes. The table name will + * be in the form of {@code "Metadata_<databaseId>_<uuid>"}. The watermark index will be in the + * form of {@code "WatermarkIdx_<databaseId>_<uuid>"}. The createdAt / start timestamp index will + * be in the form of {@code "CreatedAtIdx_<databaseId>_<uuid>"}. 
+ * + * @param databaseId The database id where the table will be created + * @return the unique generated names of the partition metadata ddl + */ + public static PartitionMetadataTableNames generateRandom(String databaseId) { + UUID uuid = UUID.randomUUID(); + + String table = generateName(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, uuid); + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + /** + * Encapsulates a selected table name. Index names are generated, but will only be used if the + * given table does not exist. The watermark index will be in the form of {@code + * "WatermarkIdx_<databaseId>_<uuid>"}. The createdAt / start timestamp index will be in the form + * of {@code "CreatedAtIdx_<databaseId>_<uuid>"}. + * + * @param databaseId The database id for the table + * @param table The table name to be used + * @return an instance with the table name and generated index names + */ + public static PartitionMetadataTableNames fromExistingTable(String databaseId, String table) { + UUID uuid = UUID.randomUUID(); + + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + private static String generateName(String template, String databaseId, UUID uuid) { + String name = String.format(template, databaseId, uuid).replaceAll("-", "_"); + if (name.length() > MAX_NAME_LENGTH) { + return name.substring(0, MAX_NAME_LENGTH); + } + return name; + } + + private final String tableName; + private final String watermarkIndexName; + private final String createdAtIndexName; + + public PartitionMetadataTableNames( + String tableName, String watermarkIndexName, 
String createdAtIndexName) { + this.tableName = tableName; + this.watermarkIndexName = watermarkIndexName; + this.createdAtIndexName = createdAtIndexName; + } + + public String getTableName() { + return tableName; + } + + public String getWatermarkIndexName() { + return watermarkIndexName; + } + + public String getCreatedAtIndexName() { + return createdAtIndexName; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PartitionMetadataTableNames)) { + return false; + } + PartitionMetadataTableNames that = (PartitionMetadataTableNames) o; + return Objects.equals(tableName, that.tableName) + && Objects.equals(watermarkIndexName, that.watermarkIndexName) + && Objects.equals(createdAtIndexName, that.createdAtIndexName); + } + + @Override + public int hashCode() { + return Objects.hash(tableName, watermarkIndexName, createdAtIndexName); + } + + @Override + public String toString() { + return "PartitionMetadataTableNames{" + + "tableName='" + + tableName + + '\'' + + ", watermarkIndexName='" + + watermarkIndexName + + '\'' + + ", createdAtIndexName='" + + createdAtIndexName + + '\'' + + '}'; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java index a048c885a001..f8aa497292bf 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn; import java.io.Serializable; +import java.util.List; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; 
import org.apache.beam.sdk.transforms.DoFn; @@ -33,6 +34,7 @@ public CleanUpReadChangeStreamDoFn(DaoFactory daoFactory) { @ProcessElement public void processElement(OutputReceiver receiver) { - daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(); + List indexes = daoFactory.getPartitionMetadataDao().findAllTableIndexes(); + daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(indexes); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java index 387ffd603b14..ca93f34bf1ba 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java @@ -64,6 +64,7 @@ public InitializeDoFn( public void processElement(OutputReceiver receiver) { PartitionMetadataDao partitionMetadataDao = daoFactory.getPartitionMetadataDao(); if (!partitionMetadataDao.tableExists()) { + // Creates partition metadata table and associated indexes daoFactory.getPartitionMetadataAdminDao().createPartitionMetadataTable(); createFakeParentPartition(); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java index b50aa4d32d76..b44b9596cc12 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/TableContainer.java @@ -18,11 +18,13 @@ package org.apache.beam.sdk.io.gcp.testing; import com.google.api.services.bigquery.model.Table; +import 
com.google.api.services.bigquery.model.TableConstraints; import com.google.api.services.bigquery.model.TableRow; import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder; @@ -51,12 +53,24 @@ class TableContainer { this.keyedRows = Maps.newHashMap(); this.ids = new ArrayList<>(); this.sizeBytes = 0L; + // extract primary key information from Table if present + List pkColumns = primaryKeyColumns(table); + this.primaryKeyColumns = pkColumns; + this.primaryKeyColumnIndices = primaryColumnFieldIndices(pkColumns, table); } - // Only top-level columns supported. - void setPrimaryKeyColumns(List primaryKeyColumns) { - this.primaryKeyColumns = primaryKeyColumns; + static @Nullable List primaryKeyColumns(Table table) { + return Optional.ofNullable(table.getTableConstraints()) + .flatMap(constraints -> Optional.ofNullable(constraints.getPrimaryKey())) + .map(TableConstraints.PrimaryKey::getColumns) + .orElse(null); + } + static @Nullable List primaryColumnFieldIndices( + @Nullable List primaryKeyColumns, Table table) { + if (primaryKeyColumns == null) { + return null; + } Map indices = IntStream.range(0, table.getSchema().getFields().size()) .boxed() @@ -65,7 +79,13 @@ void setPrimaryKeyColumns(List primaryKeyColumns) { for (String columnName : primaryKeyColumns) { primaryKeyColumnIndices.add(Preconditions.checkStateNotNull(indices.get(columnName))); } - this.primaryKeyColumnIndices = primaryKeyColumnIndices; + return primaryKeyColumnIndices; + } + + // Only top-level columns supported. 
+ void setPrimaryKeyColumns(List primaryKeyColumns) { + this.primaryKeyColumns = primaryKeyColumns; + this.primaryKeyColumnIndices = primaryColumnFieldIndices(primaryKeyColumns, table); } @Nullable @@ -80,7 +100,7 @@ List getPrimaryKey(TableRow tableRow) { .stream() .map(cell -> Preconditions.checkStateNotNull(cell.get("v"))) .collect(Collectors.toList()); - ; + return Preconditions.checkStateNotNull(primaryKeyColumnIndices).stream() .map(cellValues::get) .collect(Collectors.toList()); @@ -91,7 +111,7 @@ List getPrimaryKey(TableRow tableRow) { long addRow(TableRow row, String id) { List primaryKey = getPrimaryKey(row); - if (primaryKey != null) { + if (primaryKey != null && !primaryKey.isEmpty()) { if (keyedRows.putIfAbsent(primaryKey, row) != null) { throw new RuntimeException( "Primary key validation error! Multiple inserts with the same primary key."); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java index 4013f0018553..d8c580a0cd18 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import com.google.protobuf.ByteString; import com.google.protobuf.DescriptorProtos.DescriptorProto; @@ -36,8 +37,11 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.sdk.schemas.Schema; import 
org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -284,12 +288,14 @@ public class BeamRowToStorageApiProtoTest { .addField("nested", FieldType.row(BASE_SCHEMA).withNullable(true)) .addField("nestedArray", FieldType.array(FieldType.row(BASE_SCHEMA))) .addField("nestedIterable", FieldType.iterable(FieldType.row(BASE_SCHEMA))) + .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.row(BASE_SCHEMA))) .build(); private static final Row NESTED_ROW = Row.withSchema(NESTED_SCHEMA) .withFieldValue("nested", BASE_ROW) .withFieldValue("nestedArray", ImmutableList.of(BASE_ROW, BASE_ROW)) .withFieldValue("nestedIterable", ImmutableList.of(BASE_ROW, BASE_ROW)) + .withFieldValue("nestedMap", ImmutableMap.of("key1", BASE_ROW, "key2", BASE_ROW)) .build(); @Test @@ -347,12 +353,12 @@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); - assertEquals(3, types.size()); + assertEquals(4, types.size()); Map nestedTypes = descriptor.getNestedTypeList().stream() .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); - assertEquals(3, nestedTypes.size()); + assertEquals(4, nestedTypes.size()); assertEquals(Type.TYPE_MESSAGE, types.get("nested")); assertEquals(Label.LABEL_OPTIONAL, typeLabels.get("nested")); String nestedTypeName1 = typeNames.get("nested"); @@ -379,6 +385,87 @@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); assertEquals(expectedBaseTypes, nestedTypes3); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmap")); + String nestedTypeName4 = typeNames.get("nestedmap"); + // expects 2 fields in the nested map, key and value + assertEquals(2, nestedTypes.get(nestedTypeName4).getFieldList().size()); + Supplier> stream = + () -> 
nestedTypes.get(nestedTypeName4).getFieldList().stream(); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("key"))); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("value"))); + + Map nestedTypes4 = + nestedTypes.get(nestedTypeName4).getNestedTypeList().stream() + .flatMap(vdesc -> vdesc.getFieldList().stream()) + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes4); + } + + @Test + public void testParticularMapsFromSchemas() { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.array(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + DescriptorProto descriptor = + TableRowToStorageApiProto.descriptorSchemaFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema((nestedMapSchemaVariations)), + true, + false); + + Map types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map typeNames = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName)); + Map typeLabels = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); + + Map nestedTypes = + descriptor.getNestedTypeList().stream() + .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); + assertEquals(2, nestedTypes.size()); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmultimap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmultimap")); + String nestedMultiMapName = typeNames.get("nestedmultimap"); + // expects 2 fields for the nested array of maps, key and value + assertEquals(2, 
nestedTypes.get(nestedMultiMapName).getFieldList().size()); + Supplier> stream = + () -> nestedTypes.get(nestedMultiMapName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); + assertTrue( + stream + .get() + .filter(fdp -> fdp.getName().equals("value")) + .filter(fdp -> fdp.getLabel().equals(Label.LABEL_REPEATED)) + .count() + == 1); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmapnullable")); + // even though the field is marked as optional in the row we should see repeated in proto + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmapnullable")); + String nestedMapNullableName = typeNames.get("nestedmapnullable"); + // expects 2 fields in the nullable maps, key and value + assertEquals(2, nestedTypes.get(nestedMapNullableName).getFieldList().size()); + stream = () -> nestedTypes.get(nestedMapNullableName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); } private void assertBaseRecord(DynamicMessage msg) { @@ -395,7 +482,7 @@ public void testMessageFromTableRow() throws Exception { BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(NESTED_SCHEMA), true, false); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, null, -1); - assertEquals(3, msg.getAllFields().size()); + assertEquals(4, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() @@ -404,6 +491,63 @@ public void testMessageFromTableRow() throws Exception { assertBaseRecord(nestedMsg); } + @Test + public void testMessageFromTableRowForArraysAndMaps() throws Exception { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField("nestedArrayNullable", FieldType.array(FieldType.STRING).withNullable(true)) + 
.addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.STRING)) + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.iterable(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + Row nestedRow = + Row.withSchema(nestedMapSchemaVariations) + .withFieldValue("nestedArrayNullable", null) + .withFieldValue("nestedMap", ImmutableMap.of("key1", "value1")) + .withFieldValue( + "nestedMultiMap", + ImmutableMap.of("multikey1", ImmutableList.of("multivalue1", "multivalue2"))) + .withFieldValue("nestedMapNullable", null) + .build(); + + Descriptor descriptor = + TableRowToStorageApiProto.getDescriptorFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(nestedMapSchemaVariations), + true, + false); + DynamicMessage msg = + BeamRowToStorageApiProto.messageFromBeamRow(descriptor, nestedRow, null, -1); + + Map fieldDescriptors = + descriptor.getFields().stream() + .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity())); + + DynamicMessage nestedMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmap"), 0); + String value = + (String) + nestedMapEntryMsg.getField( + fieldDescriptors.get("nestedmap").getMessageType().findFieldByName("value")); + assertEquals("value1", value); + + DynamicMessage nestedMultiMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmultimap"), 0); + List values = + (List) + nestedMultiMapEntryMsg.getField( + fieldDescriptors.get("nestedmultimap").getMessageType().findFieldByName("value")); + assertTrue(values.size() == 2); + assertEquals("multivalue1", values.get(0)); + + assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedarraynullable")) == 0); + assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedmapnullable")) == 0); + } + @Test public void testCdcFields() throws Exception { Descriptor descriptor = @@ 
-413,7 +557,7 @@ public void testCdcFields() throws Exception { assertNotNull(descriptor.findFieldByName(StorageApiCDC.CHANGE_SQN_COLUMN)); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, "UPDATE", 42); - assertEquals(5, msg.getAllFields().size()); + assertEquals(6, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java index 662f2658eb6b..2333278a11f5 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java @@ -28,23 +28,23 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import java.util.function.Function; import org.apache.avro.Conversions; import org.apache.avro.LogicalType; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; -import org.apache.avro.Schema.Field; -import org.apache.avro.Schema.Type; +import org.apache.avro.SchemaBuilder; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; -import org.apache.avro.reflect.AvroSchema; -import org.apache.avro.reflect.Nullable; -import org.apache.avro.reflect.ReflectData; +import org.apache.avro.generic.GenericRecordBuilder; import org.apache.avro.util.Utf8; -import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; -import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding; import org.junit.Test; @@ -54,363 +54,678 @@ /** Tests for {@link BigQueryAvroUtils}. */ @RunWith(JUnit4.class) public class BigQueryAvroUtilsTest { - private List subFields = - Lists.newArrayList( - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE")); - /* - * Note that the quality and quantity fields do not have their mode set, so they should default - * to NULLABLE. This is an important test of BigQuery semantics. - * - * All the other fields we set in this function are required on the Schema response. - * - * See https://cloud.google.com/bigquery/docs/reference/v2/tables#schema - */ - private List fields = - Lists.newArrayList( - new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"), - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE"), - new TableFieldSchema().setName("quality").setType("FLOAT") /* default to NULLABLE */, - new TableFieldSchema().setName("quantity").setType("INTEGER") /* default to NULLABLE */, - new TableFieldSchema().setName("birthday").setType("TIMESTAMP").setMode("NULLABLE"), - new TableFieldSchema().setName("birthdayMoney").setType("NUMERIC").setMode("NULLABLE"), - new TableFieldSchema() - .setName("lotteryWinnings") - .setType("BIGNUMERIC") - .setMode("NULLABLE"), - new TableFieldSchema().setName("flighted").setType("BOOLEAN").setMode("NULLABLE"), - new TableFieldSchema().setName("sound").setType("BYTES").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryDate").setType("DATE").setMode("NULLABLE"), - new TableFieldSchema() - .setName("anniversaryDatetime") - .setType("DATETIME") - .setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryTime").setType("TIME").setMode("NULLABLE"), - new 
TableFieldSchema() - .setName("scion") - .setType("RECORD") - .setMode("NULLABLE") - .setFields(subFields), - new TableFieldSchema() - .setName("associates") - .setType("RECORD") - .setMode("REPEATED") - .setFields(subFields), - new TableFieldSchema().setName("geoPositions").setType("GEOGRAPHY").setMode("NULLABLE")); - - private ByteBuffer convertToBytes(BigDecimal bigDecimal, int precision, int scale) { - LogicalType bigDecimalLogicalType = LogicalTypes.decimal(precision, scale); - return new Conversions.DecimalConversion().toBytes(bigDecimal, null, bigDecimalLogicalType); + + private TableSchema tableSchema(Function fn) { + TableFieldSchema column = new TableFieldSchema().setName("value"); + TableSchema tableSchema = new TableSchema(); + tableSchema.setFields(Lists.newArrayList(fn.apply(column))); + return tableSchema; + } + + private Schema avroSchema( + Function, SchemaBuilder.FieldAssembler> fn) { + return fn.apply( + SchemaBuilder.record("root") + .namespace("org.apache.beam.sdk.io.gcp.bigquery") + .doc("Translated Avro Schema for root") + .fields() + .name("value")) + .endRecord(); } + @SuppressWarnings("JavaInstantGetSecondsGetNano") @Test - public void testConvertGenericRecordToTableRow() throws Exception { - BigDecimal numeric = new BigDecimal("123456789.123456789"); - ByteBuffer numericBytes = convertToBytes(numeric, 38, 9); - BigDecimal bigNumeric = - new BigDecimal( - "578960446186580977117854925043439539266.34992332820282019728792003956564819967"); - ByteBuffer bigNumericBytes = convertToBytes(bigNumeric, 77, 38); - Schema avroSchema = ReflectData.get().getSchema(Bird.class); - - { - // Test nullable fields. 
- GenericRecord record = new GenericData.Record(avroSchema); - record.put("number", 5L); - TableRow convertedRow = BigQueryAvroUtils.convertGenericRecordToTableRow(record); - TableRow row = new TableRow().set("number", "5").set("associates", new ArrayList()); - assertEquals(row, convertedRow); - TableRow clonedRow = convertedRow.clone(); - assertEquals(convertedRow, clonedRow); - } - { - // Test type conversion for: - // INTEGER, FLOAT, NUMERIC, TIMESTAMP, BOOLEAN, BYTES, DATE, DATETIME, TIME. - GenericRecord record = new GenericData.Record(avroSchema); - byte[] soundBytes = "chirp,chirp".getBytes(StandardCharsets.UTF_8); - ByteBuffer soundByteBuffer = ByteBuffer.wrap(soundBytes); - soundByteBuffer.rewind(); - record.put("number", 5L); - record.put("quality", 5.0); - record.put("birthday", 5L); - record.put("birthdayMoney", numericBytes); - record.put("lotteryWinnings", bigNumericBytes); - record.put("flighted", Boolean.TRUE); - record.put("sound", soundByteBuffer); - record.put("anniversaryDate", new Utf8("2000-01-01")); - record.put("anniversaryDatetime", new String("2000-01-01 00:00:00.000005")); - record.put("anniversaryTime", new Utf8("00:00:00.000005")); - record.put("geoPositions", new String("LINESTRING(1 2, 3 4, 5 6, 7 8)")); - TableRow convertedRow = BigQueryAvroUtils.convertGenericRecordToTableRow(record); - TableRow row = - new TableRow() - .set("number", "5") - .set("birthday", "1970-01-01 00:00:00.000005 UTC") - .set("birthdayMoney", numeric.toString()) - .set("lotteryWinnings", bigNumeric.toString()) - .set("quality", 5.0) - .set("associates", new ArrayList()) - .set("flighted", Boolean.TRUE) - .set("sound", BaseEncoding.base64().encode(soundBytes)) - .set("anniversaryDate", "2000-01-01") - .set("anniversaryDatetime", "2000-01-01 00:00:00.000005") - .set("anniversaryTime", "00:00:00.000005") - .set("geoPositions", "LINESTRING(1 2, 3 4, 5 6, 7 8)"); - TableRow clonedRow = convertedRow.clone(); - assertEquals(convertedRow, clonedRow); - 
assertEquals(row, convertedRow); - } - { - // Test repeated fields. - Schema subBirdSchema = AvroCoder.of(Bird.SubBird.class).getSchema(); - GenericRecord nestedRecord = new GenericData.Record(subBirdSchema); - nestedRecord.put("species", "other"); - GenericRecord record = new GenericData.Record(avroSchema); - record.put("number", 5L); - record.put("associates", Lists.newArrayList(nestedRecord)); - record.put("birthdayMoney", numericBytes); - record.put("lotteryWinnings", bigNumericBytes); - TableRow convertedRow = BigQueryAvroUtils.convertGenericRecordToTableRow(record); - TableRow row = + public void testConvertGenericRecordToTableRow() { + { + // bool + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().booleanType().noDefault())) + .set("value", false) + .build(); + TableRow expected = new TableRow().set("value", false); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // int + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().intType().noDefault())) + .set("value", 5) + .build(); + TableRow expected = new TableRow().set("value", "5"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // long + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().longType().noDefault())) + .set("value", 5L) + .build(); + TableRow expected = new TableRow().set("value", "5"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // float + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().floatType().noDefault())) + .set("value", 5.5f) + .build(); + TableRow expected = new TableRow().set("value", 5.5); + TableRow row = 
BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // double + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().doubleType().noDefault())) + .set("value", 5.55) + .build(); + TableRow expected = new TableRow().set("value", 5.55); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // bytes + byte[] bytes = "chirp,chirp".getBytes(StandardCharsets.UTF_8); + ByteBuffer bb = ByteBuffer.wrap(bytes); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().bytesType().noDefault())) + .set("value", bb) + .build(); + TableRow expected = new TableRow().set("value", BaseEncoding.base64().encode(bytes)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // string + Schema schema = avroSchema(f -> f.type().stringType().noDefault()); + GenericRecord record = new GenericRecordBuilder(schema).set("value", "test").build(); + GenericRecord utf8Record = + new GenericRecordBuilder(schema).set("value", new Utf8("test")).build(); + TableRow expected = new TableRow().set("value", "test"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + TableRow utf8Row = BigQueryAvroUtils.convertGenericRecordToTableRow(utf8Record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + assertEquals(expected, utf8Row); + assertEquals(expected, utf8Row.clone()); + } + + { + // decimal + LogicalType lt = LogicalTypes.decimal(38, 9); + Schema decimalType = lt.addToSchema(SchemaBuilder.builder().bytesType()); + BigDecimal bd = new BigDecimal("123456789.123456789"); + ByteBuffer bytes = new Conversions.DecimalConversion().toBytes(bd, null, lt); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> 
f.type(decimalType).noDefault())) + .set("value", bytes) + .build(); + TableRow expected = new TableRow().set("value", bd.toString()); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // date + LogicalType lt = LogicalTypes.date(); + Schema dateType = lt.addToSchema(SchemaBuilder.builder().intType()); + LocalDate date = LocalDate.of(2000, 1, 1); + int days = (int) date.toEpochDay(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(dateType).noDefault())) + .set("value", days) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // time-millis + LogicalType lt = LogicalTypes.timeMillis(); + Schema timeType = lt.addToSchema(SchemaBuilder.builder().intType()); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + int millis = (int) (time.toNanoOfDay() / 1000000); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timeType).noDefault())) + .set("value", millis) + .build(); + TableRow expected = new TableRow().set("value", "01:02:03.123"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // time-micros + LogicalType lt = LogicalTypes.timeMicros(); + Schema timeType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + long micros = time.toNanoOfDay() / 1000; + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timeType).noDefault())) + .set("value", micros) + .build(); + TableRow expected = new TableRow().set("value", "01:02:03.123456"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + 
assertEquals(expected, row.clone()); + } + + { + // local-timestamp-millis + LogicalType lt = LogicalTypes.localTimestampMillis(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long millis = ts.toInstant(ZoneOffset.UTC).toEpochMilli(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timestampType).noDefault())) + .set("value", millis) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // local-timestamp-micros + LogicalType lt = LogicalTypes.localTimestampMicros(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long seconds = ts.toInstant(ZoneOffset.UTC).getEpochSecond(); + int nanos = ts.toInstant(ZoneOffset.UTC).getNano(); + long micros = seconds * 1000000 + (nanos / 1000); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timestampType).noDefault())) + .set("value", micros) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123456"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // timestamp-millis + LogicalType lt = LogicalTypes.timestampMillis(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long millis = 
ts.toInstant(ZoneOffset.UTC).toEpochMilli(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timestampType).noDefault())) + .set("value", millis) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123 UTC"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // timestamp-micros + LogicalType lt = LogicalTypes.timestampMicros(); + Schema timestampType = lt.addToSchema(SchemaBuilder.builder().longType()); + LocalDate date = LocalDate.of(2000, 1, 1); + LocalTime time = LocalTime.of(1, 2, 3, 123456789); + LocalDateTime ts = LocalDateTime.of(date, time); + long seconds = ts.toInstant(ZoneOffset.UTC).getEpochSecond(); + int nanos = ts.toInstant(ZoneOffset.UTC).getNano(); + long micros = seconds * 1000000 + (nanos / 1000); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(timestampType).noDefault())) + .set("value", micros) + .build(); + TableRow expected = new TableRow().set("value", "2000-01-01 01:02:03.123456 UTC"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // enum + Schema enumSchema = SchemaBuilder.enumeration("color").symbols("red", "green", "blue"); + GenericData.EnumSymbol symbol = new GenericData.EnumSymbol(enumSchema, "RED"); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(enumSchema).noDefault())) + .set("value", symbol) + .build(); + TableRow expected = new TableRow().set("value", "RED"); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // fixed + UUID uuid = UUID.randomUUID(); + ByteBuffer bb = ByteBuffer.allocate(16); + bb.putLong(uuid.getMostSignificantBits()); + bb.putLong(uuid.getLeastSignificantBits()); + 
bb.rewind(); + byte[] bytes = bb.array(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().fixed("uuid").size(16).noDefault())) + .set("value", bb) + .build(); + TableRow expected = new TableRow().set("value", BaseEncoding.base64().encode(bytes)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // null + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().optional().booleanType())).build(); + TableRow expected = new TableRow(); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // array + GenericRecord record = + new GenericRecordBuilder( + avroSchema(f -> f.type().array().items().booleanType().noDefault())) + .set("value", Lists.newArrayList(true, false)) + .build(); + TableRow expected = new TableRow().set("value", Lists.newArrayList(true, false)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // map + Map map = new HashMap<>(); + map.put("left", 1); + map.put("right", -1); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type().map().values().intType().noDefault())) + .set("value", map) + .build(); + TableRow expected = new TableRow() - .set("associates", Lists.newArrayList(new TableRow().set("species", "other"))) - .set("number", "5") - .set("birthdayMoney", numeric.toString()) - .set("lotteryWinnings", bigNumeric.toString()); - assertEquals(row, convertedRow); - TableRow clonedRow = convertedRow.clone(); - assertEquals(convertedRow, clonedRow); + .set( + "value", + Lists.newArrayList( + new TableRow().set("key", "left").set("value", "1"), + new TableRow().set("key", "right").set("value", "-1"))); + TableRow row = 
BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); + } + + { + // record + Schema subSchema = + SchemaBuilder.builder() + .record("record") + .fields() + .name("int") + .type() + .intType() + .noDefault() + .name("float") + .type() + .floatType() + .noDefault() + .endRecord(); + GenericRecord subRecord = + new GenericRecordBuilder(subSchema).set("int", 5).set("float", 5.5f).build(); + GenericRecord record = + new GenericRecordBuilder(avroSchema(f -> f.type(subSchema).noDefault())) + .set("value", subRecord) + .build(); + TableRow expected = + new TableRow().set("value", new TableRow().set("int", "5").set("float", 5.5)); + TableRow row = BigQueryAvroUtils.convertGenericRecordToTableRow(record); + + assertEquals(expected, row); + assertEquals(expected, row.clone()); } } @Test public void testConvertBigQuerySchemaToAvroSchema() { - TableSchema tableSchema = new TableSchema(); - tableSchema.setFields(fields); - Schema avroSchema = BigQueryAvroUtils.toGenericAvroSchema(tableSchema); + { + // REQUIRED + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().booleanType().noDefault()); - assertThat(avroSchema.getField("number").schema(), equalTo(Schema.create(Type.LONG))); - assertThat( - avroSchema.getField("species").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); - assertThat( - avroSchema.getField("quality").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.DOUBLE)))); - assertThat( - avroSchema.getField("quantity").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.LONG)))); - assertThat( - avroSchema.getField("birthday").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - LogicalTypes.timestampMicros().addToSchema(Schema.create(Type.LONG))))); - assertThat( - 
avroSchema.getField("birthdayMoney").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - LogicalTypes.decimal(38, 9).addToSchema(Schema.create(Type.BYTES))))); - assertThat( - avroSchema.getField("lotteryWinnings").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - LogicalTypes.decimal(77, 38).addToSchema(Schema.create(Type.BYTES))))); - assertThat( - avroSchema.getField("flighted").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BOOLEAN)))); - assertThat( - avroSchema.getField("sound").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BYTES)))); - Schema dateSchema = Schema.create(Type.INT); - LogicalTypes.date().addToSchema(dateSchema); - assertThat( - avroSchema.getField("anniversaryDate").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), dateSchema))); - Schema dateTimeSchema = Schema.create(Type.STRING); - BigQueryAvroUtils.DATETIME_LOGICAL_TYPE.addToSchema(dateTimeSchema); - assertThat( - avroSchema.getField("anniversaryDatetime").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), dateTimeSchema))); - Schema timeSchema = Schema.create(Type.LONG); - LogicalTypes.timeMicros().addToSchema(timeSchema); - assertThat( - avroSchema.getField("anniversaryTime").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), timeSchema))); - Schema geoSchema = Schema.create(Type.STRING); - geoSchema.addProp("sqlType", "GEOGRAPHY"); - assertThat( - avroSchema.getField("geoPositions").schema(), - equalTo(Schema.createUnion(Schema.create(Type.NULL), geoSchema))); - assertThat( - avroSchema.getField("scion").schema(), - equalTo( - Schema.createUnion( - Schema.create(Type.NULL), - Schema.createRecord( - "scion", - "Translated Avro Schema for scion", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Field( - "species", - Schema.createUnion( - Schema.create(Type.NULL), Schema.create(Type.STRING)), - null, - 
(Object) null)))))); - assertThat( - avroSchema.getField("associates").schema(), - equalTo( - Schema.createArray( - Schema.createRecord( - "associates", - "Translated Avro Schema for associates", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Field( - "species", - Schema.createUnion( - Schema.create(Type.NULL), Schema.create(Type.STRING)), - null, - (Object) null)))))); - } + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } - @Test - public void testConvertBigQuerySchemaToAvroSchemaWithoutLogicalTypes() { - TableSchema tableSchema = new TableSchema(); - tableSchema.setFields(fields); - Schema avroSchema = BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false); + { + // NULLABLE + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN").setMode("NULLABLE")); + Schema expected = + avroSchema(f -> f.type().unionOf().nullType().and().booleanType().endUnion().noDefault()); - assertThat(avroSchema.getField("number").schema(), equalTo(Schema.create(Schema.Type.LONG))); - assertThat( - avroSchema.getField("species").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.STRING)))); - assertThat( - avroSchema.getField("quality").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.DOUBLE)))); - assertThat( - avroSchema.getField("quantity").schema(), - equalTo( - Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.LONG)))); - assertThat( - avroSchema.getField("birthday").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - LogicalTypes.timestampMicros().addToSchema(Schema.create(Schema.Type.LONG))))); - assertThat( - avroSchema.getField("birthdayMoney").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - LogicalTypes.decimal(38, 9).addToSchema(Schema.create(Schema.Type.BYTES))))); - assertThat( - 
avroSchema.getField("lotteryWinnings").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - LogicalTypes.decimal(77, 38).addToSchema(Schema.create(Schema.Type.BYTES))))); - assertThat( - avroSchema.getField("flighted").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.BOOLEAN)))); - assertThat( - avroSchema.getField("sound").schema(), - equalTo( - Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.BYTES)))); - Schema dateSchema = Schema.create(Schema.Type.STRING); - dateSchema.addProp("sqlType", "DATE"); - assertThat( - avroSchema.getField("anniversaryDate").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), dateSchema))); - Schema dateTimeSchema = Schema.create(Schema.Type.STRING); - dateTimeSchema.addProp("sqlType", "DATETIME"); - assertThat( - avroSchema.getField("anniversaryDatetime").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), dateTimeSchema))); - Schema timeSchema = Schema.create(Schema.Type.STRING); - timeSchema.addProp("sqlType", "TIME"); - assertThat( - avroSchema.getField("anniversaryTime").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), timeSchema))); - Schema geoSchema = Schema.create(Type.STRING); - geoSchema.addProp("sqlType", "GEOGRAPHY"); - assertThat( - avroSchema.getField("geoPositions").schema(), - equalTo(Schema.createUnion(Schema.create(Schema.Type.NULL), geoSchema))); - assertThat( - avroSchema.getField("scion").schema(), - equalTo( - Schema.createUnion( - Schema.create(Schema.Type.NULL), - Schema.createRecord( - "scion", - "Translated Avro Schema for scion", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Schema.Field( - "species", - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.STRING)), - null, - (Object) null)))))); - assertThat( - avroSchema.getField("associates").schema(), - equalTo( - 
Schema.createArray( - Schema.createRecord( - "associates", - "Translated Avro Schema for associates", - "org.apache.beam.sdk.io.gcp.bigquery", - false, - ImmutableList.of( - new Schema.Field( - "species", - Schema.createUnion( - Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.STRING)), - null, - (Object) null)))))); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // default mode -> NULLABLE + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN")); + Schema expected = + avroSchema(f -> f.type().unionOf().nullType().and().booleanType().endUnion().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // REPEATED + TableSchema tableSchema = tableSchema(f -> f.setType("BOOLEAN").setMode("REPEATED")); + Schema expected = avroSchema(f -> f.type().array().items().booleanType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // INTEGER + TableSchema tableSchema = tableSchema(f -> f.setType("INTEGER").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().longType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + } + + { + // FLOAT + TableSchema tableSchema = tableSchema(f -> f.setType("FLOAT").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().doubleType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BYTES + TableSchema tableSchema = tableSchema(f -> f.setType("BYTES").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().bytesType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, 
BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // STRING + TableSchema tableSchema = tableSchema(f -> f.setType("STRING").setMode("REQUIRED")); + Schema expected = avroSchema(f -> f.type().stringType().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // NUMERIC + TableSchema tableSchema = tableSchema(f -> f.setType("NUMERIC").setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(38, 9).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // NUMERIC with precision + TableSchema tableSchema = + tableSchema(f -> f.setType("NUMERIC").setPrecision(29L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(29, 0).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // NUMERIC with precision and scale + TableSchema tableSchema = + tableSchema(f -> f.setType("NUMERIC").setPrecision(10L).setScale(9L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(10, 9).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BIGNUMERIC + TableSchema tableSchema = tableSchema(f -> f.setType("BIGNUMERIC").setMode("REQUIRED")); + Schema decimalType = + 
LogicalTypes.decimal(77, 38).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BIGNUMERIC with precision + TableSchema tableSchema = + tableSchema(f -> f.setType("BIGNUMERIC").setPrecision(38L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(38, 0).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // BIGNUMERIC with precision and scale + TableSchema tableSchema = + tableSchema( + f -> f.setType("BIGNUMERIC").setPrecision(39L).setScale(38L).setMode("REQUIRED")); + Schema decimalType = + LogicalTypes.decimal(39, 38).addToSchema(SchemaBuilder.builder().bytesType()); + Schema expected = avroSchema(f -> f.type(decimalType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // DATE + TableSchema tableSchema = tableSchema(f -> f.setType("DATE").setMode("REQUIRED")); + Schema dateType = LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType()); + Schema expected = avroSchema(f -> f.type(dateType).noDefault()); + Schema expectedExport = + avroSchema(f -> f.type().stringBuilder().prop("sqlType", "DATE").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expectedExport, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // TIME + TableSchema tableSchema = tableSchema(f -> f.setType("TIME").setMode("REQUIRED")); + Schema timeType = 
LogicalTypes.timeMicros().addToSchema(SchemaBuilder.builder().longType()); + Schema expected = avroSchema(f -> f.type(timeType).noDefault()); + Schema expectedExport = + avroSchema(f -> f.type().stringBuilder().prop("sqlType", "TIME").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expectedExport, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // DATETIME + TableSchema tableSchema = tableSchema(f -> f.setType("DATETIME").setMode("REQUIRED")); + Schema timeType = + BigQueryAvroUtils.DATETIME_LOGICAL_TYPE.addToSchema(SchemaBuilder.builder().stringType()); + Schema expected = avroSchema(f -> f.type(timeType).noDefault()); + Schema expectedExport = + avroSchema( + f -> f.type().stringBuilder().prop("sqlType", "DATETIME").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expectedExport, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // TIMESTAMP + TableSchema tableSchema = tableSchema(f -> f.setType("TIMESTAMP").setMode("REQUIRED")); + Schema timestampType = + LogicalTypes.timestampMicros().addToSchema(SchemaBuilder.builder().longType()); + Schema expected = avroSchema(f -> f.type(timestampType).noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // GEOGRAPHY + TableSchema tableSchema = tableSchema(f -> f.setType("GEOGRAPHY").setMode("REQUIRED")); + Schema expected = + avroSchema( + f -> f.type().stringBuilder().prop("sqlType", "GEOGRAPHY").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // JSON + TableSchema tableSchema = tableSchema(f -> f.setType("JSON").setMode("REQUIRED")); + Schema 
expected = + avroSchema(f -> f.type().stringBuilder().prop("sqlType", "JSON").endString().noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(tableSchema, false)); + } + + { + // STRUCT/RECORD + TableFieldSchema subInteger = + new TableFieldSchema().setName("int").setType("INTEGER").setMode("NULLABLE"); + TableFieldSchema subFloat = + new TableFieldSchema().setName("float").setType("FLOAT").setMode("REQUIRED"); + TableSchema structTableSchema = + tableSchema( + f -> + f.setType("STRUCT") + .setMode("REQUIRED") + .setFields(Lists.newArrayList(subInteger, subFloat))); + TableSchema recordTableSchema = + tableSchema( + f -> + f.setType("RECORD") + .setMode("REQUIRED") + .setFields(Lists.newArrayList(subInteger, subFloat))); + + Schema expected = + avroSchema( + f -> + f.type() + .record("value") + .fields() + .name("int") + .type() + .unionOf() + .nullType() + .and() + .longType() + .endUnion() + .noDefault() + .name("float") + .type() + .doubleType() + .noDefault() + .endRecord() + .noDefault()); + + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(structTableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(structTableSchema, false)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(recordTableSchema)); + assertEquals(expected, BigQueryAvroUtils.toGenericAvroSchema(recordTableSchema, false)); + } } @Test public void testFormatTimestamp() { - assertThat( - BigQueryAvroUtils.formatTimestamp(1452062291123456L), - equalTo("2016-01-06 06:38:11.123456 UTC")); + long micros = 1452062291123456L; + String expected = "2016-01-06 06:38:11.123456"; + assertThat(BigQueryAvroUtils.formatDatetime(micros), equalTo(expected)); + assertThat(BigQueryAvroUtils.formatTimestamp(micros), equalTo(expected + " UTC")); } @Test - public void testFormatTimestampLeadingZeroesOnMicros() { - assertThat( - 
BigQueryAvroUtils.formatTimestamp(1452062291000456L), - equalTo("2016-01-06 06:38:11.000456 UTC")); + public void testFormatTimestampMillis() { + long millis = 1452062291123L; + long micros = millis * 1000L; + String expected = "2016-01-06 06:38:11.123"; + assertThat(BigQueryAvroUtils.formatDatetime(micros), equalTo(expected)); + assertThat(BigQueryAvroUtils.formatTimestamp(micros), equalTo(expected + " UTC")); } @Test - public void testFormatTimestampTrailingZeroesOnMicros() { - assertThat( - BigQueryAvroUtils.formatTimestamp(1452062291123000L), - equalTo("2016-01-06 06:38:11.123000 UTC")); + public void testFormatTimestampSeconds() { + long seconds = 1452062291L; + long micros = seconds * 1000L * 1000L; + String expected = "2016-01-06 06:38:11"; + assertThat(BigQueryAvroUtils.formatDatetime(micros), equalTo(expected)); + assertThat(BigQueryAvroUtils.formatTimestamp(micros), equalTo(expected + " UTC")); } @Test public void testFormatTimestampNegative() { - assertThat(BigQueryAvroUtils.formatTimestamp(-1L), equalTo("1969-12-31 23:59:59.999999 UTC")); - assertThat( - BigQueryAvroUtils.formatTimestamp(-100_000L), equalTo("1969-12-31 23:59:59.900000 UTC")); - assertThat(BigQueryAvroUtils.formatTimestamp(-1_000_000L), equalTo("1969-12-31 23:59:59 UTC")); + assertThat(BigQueryAvroUtils.formatDatetime(-1L), equalTo("1969-12-31 23:59:59.999999")); + assertThat(BigQueryAvroUtils.formatDatetime(-100_000L), equalTo("1969-12-31 23:59:59.900")); + assertThat(BigQueryAvroUtils.formatDatetime(-1_000_000L), equalTo("1969-12-31 23:59:59")); // No leap seconds before 1972. 477 leap years from 1 through 1969. 
assertThat( - BigQueryAvroUtils.formatTimestamp(-(1969L * 365 + 477) * 86400 * 1_000_000), - equalTo("0001-01-01 00:00:00 UTC")); + BigQueryAvroUtils.formatDatetime(-(1969L * 365 + 477) * 86400 * 1_000_000), + equalTo("0001-01-01 00:00:00")); } @Test @@ -501,48 +816,4 @@ public void testSchemaCollisionsInAvroConversion() { String output = BigQueryAvroUtils.toGenericAvroSchema(schema, false).toString(); assertThat(output.length(), greaterThan(0)); } - - /** Pojo class used as the record type in tests. */ - @SuppressWarnings("unused") // Used by Avro reflection. - static class Bird { - long number; - @Nullable String species; - @Nullable Double quality; - @Nullable Long quantity; - - @AvroSchema(value = "[\"null\", {\"type\": \"long\", \"logicalType\": \"timestamp-micros\"}]") - Instant birthday; - - @AvroSchema( - value = - "[\"null\", {\"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 38, \"scale\": 9}]") - BigDecimal birthdayMoney; - - @AvroSchema( - value = - "[\"null\", {\"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 77, \"scale\": 38}]") - BigDecimal lotteryWinnings; - - @AvroSchema(value = "[\"null\", {\"type\": \"string\", \"sqlType\": \"GEOGRAPHY\"}]") - String geoPositions; - - @Nullable Boolean flighted; - @Nullable ByteBuffer sound; - @Nullable Utf8 anniversaryDate; - @Nullable String anniversaryDatetime; - @Nullable Utf8 anniversaryTime; - @Nullable SubBird scion; - SubBird[] associates; - - static class SubBird { - @Nullable String species; - - public SubBird() {} - } - - public Bird() { - associates = new SubBird[1]; - associates[0] = new SubBird(); - } - } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java deleted file mode 100644 index dd8bb9fc8664..000000000000 --- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryFileLoadsWriteSchemaTransformProvider.INPUT_TAG; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; - -import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import java.io.IOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; -import 
org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; -import org.apache.beam.sdk.io.gcp.testing.FakeJobService; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.Field; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.io.InvalidConfigurationException; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.display.DisplayData.Identifier; -import org.apache.beam.sdk.transforms.display.DisplayData.Item; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Test for {@link BigQueryFileLoadsWriteSchemaTransformProvider}. 
*/ -@RunWith(JUnit4.class) -public class BigQueryFileLoadsWriteSchemaTransformProviderTest { - - private static final String PROJECT = "fakeproject"; - private static final String DATASET = "fakedataset"; - private static final String TABLE_ID = "faketable"; - - private static final TableReference TABLE_REFERENCE = - new TableReference().setProjectId(PROJECT).setDatasetId(DATASET).setTableId(TABLE_ID); - - private static final Schema SCHEMA = - Schema.of(Field.of("name", FieldType.STRING), Field.of("number", FieldType.INT64)); - - private static final TableSchema TABLE_SCHEMA = BigQueryUtils.toTableSchema(SCHEMA); - - private static final List ROWS = - Arrays.asList( - Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 1L).build(), - Row.withSchema(SCHEMA).withFieldValue("name", "b").withFieldValue("number", 2L).build(), - Row.withSchema(SCHEMA).withFieldValue("name", "c").withFieldValue("number", 3L).build()); - - private static final BigQueryOptions OPTIONS = - TestPipeline.testingPipelineOptions().as(BigQueryOptions.class); - private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); - private final FakeJobService fakeJobService = new FakeJobService(); - private final TemporaryFolder temporaryFolder = new TemporaryFolder(); - private final FakeBigQueryServices fakeBigQueryServices = - new FakeBigQueryServices() - .withJobService(fakeJobService) - .withDatasetService(fakeDatasetService); - - @Before - public void setUp() throws IOException, InterruptedException { - FakeDatasetService.setUp(); - fakeDatasetService.createDataset(PROJECT, DATASET, "", "", null); - temporaryFolder.create(); - OPTIONS.setProject(PROJECT); - OPTIONS.setTempLocation(temporaryFolder.getRoot().getAbsolutePath()); - } - - @After - public void tearDown() { - temporaryFolder.delete(); - } - - @Rule public transient TestPipeline p = TestPipeline.fromOptions(OPTIONS); - - @Test - public void testLoad() throws IOException, InterruptedException { - 
BigQueryFileLoadsWriteSchemaTransformProvider provider = - new BigQueryFileLoadsWriteSchemaTransformProvider(); - BigQueryFileLoadsWriteSchemaTransformConfiguration configuration = - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .build(); - BigQueryWriteSchemaTransform schemaTransform = - (BigQueryWriteSchemaTransform) provider.from(configuration); - schemaTransform.setTestBigQueryServices(fakeBigQueryServices); - String tag = provider.inputCollectionNames().get(0); - PCollectionRowTuple input = - PCollectionRowTuple.of(tag, p.apply(Create.of(ROWS).withRowSchema(SCHEMA))); - input.apply(schemaTransform); - - p.run(); - - assertNotNull(fakeDatasetService.getTable(TABLE_REFERENCE)); - assertEquals(ROWS.size(), fakeDatasetService.getAllRows(PROJECT, DATASET, TABLE_ID).size()); - } - - @Test - public void testValidatePipelineOptions() { - List< - Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, - Class>> - cases = - Arrays.asList( - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec("project.doesnot.exist") - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()), - InvalidConfigurationException.class), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(String.format("%s.%s.%s", PROJECT, DATASET, "doesnotexist")) - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_EMPTY.name()), - InvalidConfigurationException.class), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec("project.doesnot.exist") - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - 
.setWriteDisposition(WriteDisposition.WRITE_APPEND.name()), - null)); - for (Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, Class> - caze : cases) { - BigQueryWriteSchemaTransform transform = transformFrom(caze.getLeft().build()); - if (caze.getRight() != null) { - assertThrows(caze.getRight(), () -> transform.validate(p.getOptions())); - } else { - transform.validate(p.getOptions()); - } - } - } - - @Test - public void testToWrite() { - List< - Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, - BigQueryIO.Write>> - cases = - Arrays.asList( - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_EMPTY.name()), - BigQueryIO.writeTableRows() - .to(TABLE_REFERENCE) - .withCreateDisposition(CreateDisposition.CREATE_NEVER) - .withWriteDisposition(WriteDisposition.WRITE_EMPTY) - .withSchema(TABLE_SCHEMA)), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()), - BigQueryIO.writeTableRows() - .to(TABLE_REFERENCE) - .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) - .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE) - .withSchema(TABLE_SCHEMA))); - for (Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, BigQueryIO.Write> - caze : cases) { - BigQueryWriteSchemaTransform transform = transformFrom(caze.getLeft().build()); - Map gotDisplayData = DisplayData.from(transform.toWrite(SCHEMA)).asMap(); - Map wantDisplayData = DisplayData.from(caze.getRight()).asMap(); - Set keys = new HashSet<>(); - keys.addAll(gotDisplayData.keySet()); - keys.addAll(wantDisplayData.keySet()); - for (Identifier key : keys) { - Item got 
= null; - Item want = null; - if (gotDisplayData.containsKey(key)) { - got = gotDisplayData.get(key); - } - if (wantDisplayData.containsKey(key)) { - want = wantDisplayData.get(key); - } - assertEquals(want, got); - } - } - } - - @Test - public void validatePCollectionRowTupleInput() { - PCollectionRowTuple empty = PCollectionRowTuple.empty(p); - PCollectionRowTuple valid = - PCollectionRowTuple.of( - INPUT_TAG, p.apply("CreateRowsWithValidSchema", Create.of(ROWS)).setRowSchema(SCHEMA)); - - PCollectionRowTuple invalid = - PCollectionRowTuple.of( - INPUT_TAG, - p.apply( - "CreateRowsWithInvalidSchema", - Create.of( - Row.nullRow( - Schema.builder().addNullableField("name", FieldType.STRING).build())))); - - BigQueryWriteSchemaTransform transform = - transformFrom( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()) - .build()); - - assertThrows(IllegalArgumentException.class, () -> transform.validate(empty)); - - assertThrows(IllegalStateException.class, () -> transform.validate(invalid)); - - transform.validate(valid); - - p.run(); - } - - private BigQueryWriteSchemaTransform transformFrom( - BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - BigQueryFileLoadsWriteSchemaTransformProvider provider = - new BigQueryFileLoadsWriteSchemaTransformProvider(); - BigQueryWriteSchemaTransform transform = - (BigQueryWriteSchemaTransform) provider.from(configuration); - - transform.setTestBigQueryServices(fakeBigQueryServices); - - return transform; - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java index e26348b7b478..8b65e58a4601 100644 --- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java @@ -698,6 +698,18 @@ public void testToTableSchema_map() { assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); } + @Test + public void testToTableSchema_map_array() { + TableSchema schema = toTableSchema(MAP_ARRAY_TYPE); + + assertThat(schema.getFields().size(), equalTo(1)); + TableFieldSchema field = schema.getFields().get(0); + assertThat(field.getName(), equalTo("map")); + assertThat(field.getType(), equalTo(StandardSQLTypeName.STRUCT.toString())); + assertThat(field.getMode(), equalTo(Mode.REPEATED.toString())); + assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); + } + @Test public void testToTableRow_flat() { TableRow row = toTableRow().apply(FLAT_ROW); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java new file mode 100644 index 000000000000..168febea9d88 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import com.google.api.services.bigquery.model.TableReference; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; +import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider.BigQueryFileLoadsSchemaTransform; +import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; +import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; +import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.RowFilter; +import 
org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test for {@link BigQueryFileLoadsSchemaTransformProvider}. */ +@RunWith(JUnit4.class) +public class BigQueryFileLoadsSchemaTransformProviderTest { + + private static final String PROJECT = "fakeproject"; + private static final String DATASET = "fakedataset"; + private static final String TABLE_ID = "faketable"; + + private static final TableReference TABLE_REFERENCE = + new TableReference().setProjectId(PROJECT).setDatasetId(DATASET).setTableId(TABLE_ID); + + private static final Schema SCHEMA = + Schema.of(Field.of("name", FieldType.STRING), Field.of("number", FieldType.INT64)); + + private static final List ROWS = + Arrays.asList( + Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 1L).build(), + Row.withSchema(SCHEMA).withFieldValue("name", "b").withFieldValue("number", 2L).build(), + Row.withSchema(SCHEMA).withFieldValue("name", "c").withFieldValue("number", 3L).build()); + + private static final BigQueryOptions OPTIONS = + TestPipeline.testingPipelineOptions().as(BigQueryOptions.class); + private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); + private final FakeJobService fakeJobService = new FakeJobService(); + private final TemporaryFolder temporaryFolder = new TemporaryFolder(); + private final FakeBigQueryServices fakeBigQueryServices = + new FakeBigQueryServices() + .withJobService(fakeJobService) + .withDatasetService(fakeDatasetService); + + @Before + public void setUp() throws IOException, 
InterruptedException { + FakeDatasetService.setUp(); + fakeDatasetService.createDataset(PROJECT, DATASET, "", "", null); + temporaryFolder.create(); + OPTIONS.setProject(PROJECT); + OPTIONS.setTempLocation(temporaryFolder.getRoot().getAbsolutePath()); + } + + @After + public void tearDown() { + temporaryFolder.delete(); + } + + @Rule public transient TestPipeline p = TestPipeline.fromOptions(OPTIONS); + + @Test + public void testLoad() throws IOException, InterruptedException { + BigQueryFileLoadsSchemaTransformProvider provider = + new BigQueryFileLoadsSchemaTransformProvider(); + BigQueryWriteConfiguration configuration = + BigQueryWriteConfiguration.builder() + .setTable(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) + .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()) + .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) + .build(); + BigQueryFileLoadsSchemaTransform schemaTransform = + (BigQueryFileLoadsSchemaTransform) provider.from(configuration); + schemaTransform.setTestBigQueryServices(fakeBigQueryServices); + String tag = provider.inputCollectionNames().get(0); + PCollectionRowTuple input = + PCollectionRowTuple.of(tag, p.apply(Create.of(ROWS).withRowSchema(SCHEMA))); + input.apply(schemaTransform); + + p.run(); + + assertNotNull(fakeDatasetService.getTable(TABLE_REFERENCE)); + assertEquals(ROWS.size(), fakeDatasetService.getAllRows(PROJECT, DATASET, TABLE_ID).size()); + } + + @Test + public void testWriteToPortableDynamicDestinations() throws Exception { + String destinationTemplate = + String.format("%s:%s.dynamic_write_{name}_{number}", PROJECT, DATASET); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() + .setTable(destinationTemplate) + .setDrop(Collections.singletonList("number")) + .build(); + BigQueryFileLoadsSchemaTransform write = + (BigQueryFileLoadsSchemaTransform) + new BigQueryFileLoadsSchemaTransformProvider().from(config); + write.setTestBigQueryServices(fakeBigQueryServices); + + 
PCollection inputRows = p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + PCollectionRowTuple.of("input", inputRows).apply(write); + p.run().waitUntilFinish(); + + RowFilter rowFilter = new RowFilter(SCHEMA).drop(Collections.singletonList("number")); + assertEquals( + rowFilter.filter(ROWS.get(0)), + BigQueryUtils.toBeamRow( + rowFilter.outputSchema(), + fakeDatasetService.getAllRows(PROJECT, DATASET, "dynamic_write_a_1").get(0))); + assertEquals( + rowFilter.filter(ROWS.get(1)), + BigQueryUtils.toBeamRow( + rowFilter.outputSchema(), + fakeDatasetService.getAllRows(PROJECT, DATASET, "dynamic_write_b_2").get(0))); + assertEquals( + rowFilter.filter(ROWS.get(2)), + BigQueryUtils.toBeamRow( + rowFilter.outputSchema(), + fakeDatasetService.getAllRows(PROJECT, DATASET, "dynamic_write_c_3").get(0))); + } + + @Test + public void testManagedChoosesFileLoadsForBoundedWrites() { + PCollection batchInput = p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + batchInput.apply( + Managed.write(Managed.BIGQUERY) + .withConfig(ImmutableMap.of("table", "project.dataset.table"))); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> + tr.getUniqueName() + .contains(BigQueryFileLoadsSchemaTransform.class.getSimpleName())) + .collect(Collectors.toList()); + assertThat(writeTransformProto.size(), greaterThan(0)); + p.enableAbandonedNodeEnforcement(false); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java new file mode 100644 index 000000000000..6a422f1832d8 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; +import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PeriodicImpulse; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.RowFilter; +import org.apache.beam.sdk.values.PCollection; +import 
org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** This class tests the execution of {@link Managed} BigQueryIO. */ +@RunWith(JUnit4.class) +public class BigQueryManagedIT { + @Rule public TestName testName = new TestName(); + @Rule public transient TestPipeline writePipeline = TestPipeline.create(); + @Rule public transient TestPipeline readPipeline = TestPipeline.create(); + + private static final Schema SCHEMA = + Schema.of( + Schema.Field.of("str", Schema.FieldType.STRING), + Schema.Field.of("number", Schema.FieldType.INT64), + Schema.Field.of("dest", Schema.FieldType.INT64)); + + private static final SerializableFunction ROW_FUNC = + l -> Row.withSchema(SCHEMA).addValue(Long.toString(l)).addValue(l).addValue(l % 3).build(); + + private static final List ROWS = + LongStream.range(0, 20).mapToObj(ROW_FUNC::apply).collect(Collectors.toList()); + + private static final BigqueryClient BQ_CLIENT = new BigqueryClient("BigQueryManagedIT"); + + private static final String PROJECT = + TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + private static final String BIG_QUERY_DATASET_ID = "bigquery_managed_" + System.nanoTime(); + + @BeforeClass + public static void setUpTestEnvironment() throws IOException, InterruptedException { + // Create one BQ dataset for all test cases. 
+ BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null); + } + + @AfterClass + public static void cleanup() { + BQ_CLIENT.deleteDataset(PROJECT, BIG_QUERY_DATASET_ID); + } + + @Test + public void testBatchFileLoadsWriteRead() { + String table = + String.format("%s.%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); + Map writeConfig = ImmutableMap.of("table", table); + + // file loads requires a GCS temp location + String tempLocation = writePipeline.getOptions().as(TestPipelineOptions.class).getTempRoot(); + writePipeline.getOptions().setTempLocation(tempLocation); + + // batch write + PCollectionRowTuple.of("input", getInput(writePipeline, false)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(writeConfig)); + writePipeline.run().waitUntilFinish(); + + Map readConfig = + ImmutableMap.of("query", String.format("SELECT * FROM `%s`", table)); + // read and validate + PCollection outputRows = + readPipeline + .apply(Managed.read(Managed.BIGQUERY).withConfig(readConfig)) + .getSinglePCollection(); + PAssert.that(outputRows).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } + + @Test + public void testStreamingStorageWriteRead() { + String table = + String.format("%s:%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); + Map config = ImmutableMap.of("table", table); + + if (writePipeline.getOptions().getRunner().getName().contains("DataflowRunner")) { + // Need to manually enable streaming engine for legacy dataflow runner + ExperimentalOptions.addExperiment( + writePipeline.getOptions().as(ExperimentalOptions.class), + GcpOptions.STREAMING_ENGINE_EXPERIMENT); + } + + // streaming write + PCollectionRowTuple.of("input", getInput(writePipeline, true)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); + writePipeline.run().waitUntilFinish(); + + // read and validate + PCollection outputRows = + readPipeline + .apply(Managed.read(Managed.BIGQUERY).withConfig(config)) + .getSinglePCollection(); + 
PAssert.that(outputRows).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } + + public void testDynamicDestinations(boolean streaming) throws IOException, InterruptedException { + String baseTableName = + String.format("%s:%s.dynamic_" + System.nanoTime(), PROJECT, BIG_QUERY_DATASET_ID); + String destinationTemplate = baseTableName + "_{dest}"; + Map config = + ImmutableMap.of("table", destinationTemplate, "drop", Collections.singletonList("dest")); + + if (!streaming) { + // file loads requires a GCS temp location + String tempLocation = writePipeline.getOptions().as(TestPipelineOptions.class).getTempRoot(); + writePipeline.getOptions().setTempLocation(tempLocation); + } + + // write + PCollectionRowTuple.of("input", getInput(writePipeline, streaming)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); + writePipeline.run().waitUntilFinish(); + + List destinations = + Arrays.asList(baseTableName + "_0", baseTableName + "_1", baseTableName + "_2"); + + // read and validate each table destination + RowFilter rowFilter = new RowFilter(SCHEMA).drop(Collections.singletonList("dest")); + for (int i = 0; i < destinations.size(); i++) { + long mod = i; + String dest = destinations.get(i); + List writtenRows = + BQ_CLIENT + .queryUnflattened(String.format("SELECT * FROM [%s]", dest), PROJECT, true, false) + .stream() + .map(tableRow -> BigQueryUtils.toBeamRow(rowFilter.outputSchema(), tableRow)) + .collect(Collectors.toList()); + + List expectedRecords = + ROWS.stream() + .filter(row -> row.getInt64("dest") == mod) + .map(rowFilter::filter) + .collect(Collectors.toList()); + + assertThat(writtenRows, containsInAnyOrder(expectedRecords.toArray())); + } + } + + @Test + public void testStreamingDynamicDestinations() throws IOException, InterruptedException { + if (writePipeline.getOptions().getRunner().getName().contains("DataflowRunner")) { + // Need to manually enable streaming engine for legacy dataflow runner + 
ExperimentalOptions.addExperiment( + writePipeline.getOptions().as(ExperimentalOptions.class), + GcpOptions.STREAMING_ENGINE_EXPERIMENT); + } + testDynamicDestinations(true); + } + + @Test + public void testBatchDynamicDestinations() throws IOException, InterruptedException { + testDynamicDestinations(false); + } + + public PCollection getInput(Pipeline p, boolean isStreaming) { + if (isStreaming) { + return p.apply( + PeriodicImpulse.create() + .startAt(new Instant(0)) + .stopAt(new Instant(19)) + .withInterval(Duration.millis(1))) + .apply(MapElements.into(TypeDescriptors.rows()).via(i -> ROW_FUNC.apply(i.getMillis()))) + .setRowSchema(SCHEMA); + } + return p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java new file mode 100644 index 000000000000..822c607aa3c9 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryStorageReadSchemaTransformTranslator; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryWriteSchemaTransformTranslator; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.BeamUrns; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BigQuerySchemaTransformTranslationTest { + static final 
BigQueryWriteSchemaTransformProvider WRITE_PROVIDER = + new BigQueryWriteSchemaTransformProvider(); + static final BigQueryDirectReadSchemaTransformProvider READ_PROVIDER = + new BigQueryDirectReadSchemaTransformProvider(); + static final Row WRITE_CONFIG_ROW = + Row.withSchema(WRITE_PROVIDER.configurationSchema()) + .withFieldValue("table", "project:dataset.table") + .withFieldValue("create_disposition", "create_never") + .withFieldValue("write_disposition", "write_append") + .withFieldValue("triggering_frequency_seconds", 5L) + .withFieldValue("use_at_least_once_semantics", false) + .withFieldValue("auto_sharding", false) + .withFieldValue("num_streams", 5) + .withFieldValue("error_handling", null) + .build(); + static final Row READ_CONFIG_ROW = + Row.withSchema(READ_PROVIDER.configurationSchema()) + .withFieldValue("query", null) + .withFieldValue("table_spec", "apache-beam-testing.samples.weather_stations") + .withFieldValue("row_restriction", "col < 5") + .withFieldValue("selected_fields", Arrays.asList("col1", "col2", "col3")) + .build(); + + @Test + public void testRecreateWriteTransformFromRow() { + BigQueryWriteSchemaTransform writeTransform = + (BigQueryWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW); + + BigQueryWriteSchemaTransformTranslator translator = + new BigQueryWriteSchemaTransformTranslator(); + Row translatedRow = translator.toConfigRow(writeTransform); + + BigQueryWriteSchemaTransform writeTransformFromRow = + translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG_ROW, writeTransformFromRow.getConfigurationRow()); + } + + @Test + public void testWriteTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + Schema inputSchema = Schema.builder().addByteArrayField("b").build(); + PCollection input = + p.apply( + Create.of( + Collections.singletonList( + Row.withSchema(inputSchema).addValue(new 
byte[] {1, 2, 3}).build()))) + .setRowSchema(inputSchema); + + BigQueryWriteSchemaTransform writeTransform = + (BigQueryWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW); + PCollectionRowTuple.of("input", input).apply(writeTransform); + + // Then translate the pipeline to a proto and extract KafkaWriteSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(WRITE_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, writeTransformProto.size()); + RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + + assertEquals(WRITE_CONFIG_ROW, rowFromSpec); + + // Use the information in the proto to recreate the KafkaWriteSchemaTransform + BigQueryWriteSchemaTransformTranslator translator = + new BigQueryWriteSchemaTransformTranslator(); + BigQueryWriteSchemaTransform writeTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG_ROW, writeTransformFromSpec.getConfigurationRow()); + } + + @Test + public void testReCreateReadTransformFromRow() { + BigQueryDirectReadSchemaTransform readTransform = + (BigQueryDirectReadSchemaTransform) 
READ_PROVIDER.from(READ_CONFIG_ROW); + + BigQueryStorageReadSchemaTransformTranslator translator = + new BigQueryStorageReadSchemaTransformTranslator(); + Row row = translator.toConfigRow(readTransform); + + BigQueryDirectReadSchemaTransform readTransformFromRow = + translator.fromConfigRow(row, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG_ROW, readTransformFromRow.getConfigurationRow()); + } + + @Test + public void testReadTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + + BigQueryDirectReadSchemaTransform readTransform = + (BigQueryDirectReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG_ROW); + + PCollectionRowTuple.empty(p).apply(readTransform); + + // Then translate the pipeline to a proto and extract KafkaReadSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List readTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(READ_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, readTransformProto.size()); + RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + assertEquals(READ_CONFIG_ROW, rowFromSpec); + + // Use the information in 
the proto to recreate the KafkaReadSchemaTransform + BigQueryStorageReadSchemaTransformTranslator translator = + new BigQueryStorageReadSchemaTransformTranslator(); + BigQueryDirectReadSchemaTransform readTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG_ROW, readTransformFromSpec.getConfigurationRow()); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java index 64ea0b11d1b9..584309778286 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java @@ -17,6 +17,10 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.DESTINATION; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.RECORD; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; @@ -30,13 +34,17 @@ import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import 
org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform; -import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.metrics.MetricNameFilter; import org.apache.beam.sdk.metrics.MetricQueryResults; import org.apache.beam.sdk.metrics.MetricResult; @@ -48,12 +56,17 @@ import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.util.RowFilter; +import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -105,15 +118,13 @@ public void setUp() throws Exception { @Test public void testInvalidConfig() { - List invalidConfigs = + List invalidConfigs = Arrays.asList( - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() - .setTable("not_a_valid_table_spec"), - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration.builder() .setTable("project:dataset.table") .setCreateDisposition("INVALID_DISPOSITION")); - for 
(BigQueryStorageWriteApiSchemaTransformConfiguration.Builder config : invalidConfigs) { + for (BigQueryWriteConfiguration.Builder config : invalidConfigs) { assertThrows( Exception.class, () -> { @@ -122,13 +133,11 @@ public void testInvalidConfig() { } } - public PCollectionRowTuple runWithConfig( - BigQueryStorageWriteApiSchemaTransformConfiguration config) { + public PCollectionRowTuple runWithConfig(BigQueryWriteConfiguration config) { return runWithConfig(config, ROWS); } - public PCollectionRowTuple runWithConfig( - BigQueryStorageWriteApiSchemaTransformConfiguration config, List inputRows) { + public PCollectionRowTuple runWithConfig(BigQueryWriteConfiguration config, List inputRows) { BigQueryStorageWriteApiSchemaTransformProvider provider = new BigQueryStorageWriteApiSchemaTransformProvider(); @@ -164,17 +173,14 @@ public Boolean rowsEquals(List expectedRows, List actualRows) { } public boolean rowEquals(Row expectedRow, TableRow actualRow) { - return expectedRow.getValue("name").equals(actualRow.get("name")) - && expectedRow - .getValue("number") - .equals(Long.parseLong(actualRow.get("number").toString())); + return expectedRow.equals(BigQueryUtils.toBeamRow(expectedRow.getSchema(), actualRow)); } @Test public void testSimpleWrite() throws Exception { String tableSpec = "project:dataset.simple_write"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(tableSpec).build(); runWithConfig(config, ROWS); p.run().waitUntilFinish(); @@ -186,21 +192,21 @@ public void testSimpleWrite() throws Exception { @Test public void testWriteToDynamicDestinations() throws Exception { - String dynamic = BigQueryStorageWriteApiSchemaTransformProvider.DYNAMIC_DESTINATIONS; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - 
BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(dynamic).build(); + String dynamic = BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(dynamic).build(); String baseTableSpec = "project:dataset.dynamic_write_"; Schema schemaWithDestinations = - Schema.builder().addStringField("destination").addRowField("record", SCHEMA).build(); + Schema.builder().addStringField(DESTINATION).addRowField(RECORD, SCHEMA).build(); List rowsWithDestinations = ROWS.stream() .map( row -> Row.withSchema(schemaWithDestinations) - .withFieldValue("destination", baseTableSpec + row.getInt64("number")) - .withFieldValue("record", row) + .withFieldValue(DESTINATION, baseTableSpec + row.getInt64("number")) + .withFieldValue(RECORD, row) .build()) .collect(Collectors.toList()); @@ -221,11 +227,149 @@ public void testWriteToDynamicDestinations() throws Exception { fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_3").get(0))); } + @Test + public void testWriteToPortableDynamicDestinations() throws Exception { + String destinationTemplate = "project:dataset.dynamic_write_{name}_{number}"; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() + .setTable(destinationTemplate) + .setKeep(Arrays.asList("number", "dt")) + .build(); + + runWithConfig(config); + p.run().waitUntilFinish(); + + RowFilter rowFilter = new RowFilter(SCHEMA).keep(Arrays.asList("number", "dt")); + assertTrue( + rowEquals( + rowFilter.filter(ROWS.get(0)), + fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_a_1").get(0))); + assertTrue( + rowEquals( + rowFilter.filter(ROWS.get(1)), + fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_b_2").get(0))); + assertTrue( + rowEquals( + rowFilter.filter(ROWS.get(2)), + fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_c_3").get(0))); + } + + List createCDCUpsertRows(List rows, boolean 
dynamicDestination, String tablePrefix) { + + Schema.Builder schemaBuilder = + Schema.builder() + .addRowField(RECORD, SCHEMA) + .addRowField( + BigQueryStorageWriteApiSchemaTransformProvider.ROW_PROPERTY_MUTATION_INFO, + BigQueryStorageWriteApiSchemaTransformProvider.ROW_SCHEMA_MUTATION_INFO); + + if (dynamicDestination) { + schemaBuilder = schemaBuilder.addStringField(DESTINATION); + } + + Schema schemaWithCDC = schemaBuilder.build(); + return IntStream.range(0, rows.size()) + .mapToObj( + idx -> { + Row row = rows.get(idx); + Row.FieldValueBuilder rowBuilder = + Row.withSchema(schemaWithCDC) + .withFieldValue( + BigQueryStorageWriteApiSchemaTransformProvider.ROW_PROPERTY_MUTATION_INFO, + Row.withSchema( + BigQueryStorageWriteApiSchemaTransformProvider + .ROW_SCHEMA_MUTATION_INFO) + .withFieldValue( + BigQueryStorageWriteApiSchemaTransformProvider + .ROW_PROPERTY_MUTATION_TYPE, + "UPSERT") + .withFieldValue( + BigQueryStorageWriteApiSchemaTransformProvider + .ROW_PROPERTY_MUTATION_SQN, + "AAA" + idx) + .build()) + .withFieldValue(RECORD, row); + if (dynamicDestination) { + rowBuilder = + rowBuilder.withFieldValue(DESTINATION, tablePrefix + row.getInt64("number")); + } + return rowBuilder.build(); + }) + .collect(Collectors.toList()); + } + + @Test + public void testCDCWrites() throws Exception { + String tableSpec = "project:dataset.cdc_write"; + List primaryKeyColumns = ImmutableList.of("name"); + + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() + .setUseAtLeastOnceSemantics(true) + .setTable(tableSpec) + .setUseCdcWrites(true) + .setPrimaryKey(primaryKeyColumns) + .build(); + + List rowsDuplicated = + Stream.concat(ROWS.stream(), ROWS.stream()).collect(Collectors.toList()); + + runWithConfig(config, createCDCUpsertRows(rowsDuplicated, false, "")); + p.run().waitUntilFinish(); + + assertTrue( + rowEquals( + rowsDuplicated.get(3), + fakeDatasetService.getAllRows("project", "dataset", "cdc_write").get(0))); + assertTrue( + rowEquals( + 
rowsDuplicated.get(4), + fakeDatasetService.getAllRows("project", "dataset", "cdc_write").get(1))); + assertTrue( + rowEquals( + rowsDuplicated.get(5), + fakeDatasetService.getAllRows("project", "dataset", "cdc_write").get(2))); + } + + @Test + public void testCDCWriteToDynamicDestinations() throws Exception { + List primaryKeyColumns = ImmutableList.of("name"); + String dynamic = BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() + .setUseAtLeastOnceSemantics(true) + .setTable(dynamic) + .setUseCdcWrites(true) + .setPrimaryKey(primaryKeyColumns) + .build(); + + String baseTableSpec = "project:dataset.cdc_dynamic_write_"; + + List rowsDuplicated = + Stream.concat(ROWS.stream(), ROWS.stream()).collect(Collectors.toList()); + + runWithConfig(config, createCDCUpsertRows(rowsDuplicated, true, baseTableSpec)); + p.run().waitUntilFinish(); + + assertTrue( + rowEquals( + rowsDuplicated.get(3), + fakeDatasetService.getAllRows("project", "dataset", "cdc_dynamic_write_1").get(0))); + assertTrue( + rowEquals( + rowsDuplicated.get(4), + fakeDatasetService.getAllRows("project", "dataset", "cdc_dynamic_write_2").get(0))); + assertTrue( + rowEquals( + rowsDuplicated.get(5), + fakeDatasetService.getAllRows("project", "dataset", "cdc_dynamic_write_3").get(0))); + } + @Test public void testInputElementCount() throws Exception { String tableSpec = "project:dataset.input_count"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(tableSpec).build(); runWithConfig(config); PipelineResult result = p.run(); @@ -254,13 +398,11 @@ public void testInputElementCount() throws Exception { @Test public void testFailedRows() throws Exception { String tableSpec = "project:dataset.write_with_fail"; - 
BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setTable(tableSpec) .setErrorHandling( - BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() - .setOutput("FailedRows") - .build()) + BigQueryWriteConfiguration.ErrorHandling.builder().setOutput("FailedRows").build()) .build(); String failValue = "fail_me"; @@ -292,7 +434,6 @@ public void testFailedRows() throws Exception { MapElements.into(TypeDescriptors.rows()) .via((rowAndError) -> rowAndError.getValue("failed_row"))) .setRowSchema(SCHEMA); - ; PAssert.that(failedRows).containsInAnyOrder(expectedFailedRows); p.run().waitUntilFinish(); @@ -307,13 +448,11 @@ public void testFailedRows() throws Exception { @Test public void testErrorCount() throws Exception { String tableSpec = "project:dataset.error_count"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setTable(tableSpec) .setErrorHandling( - BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() - .setOutput("FailedRows") - .build()) + BigQueryWriteConfiguration.ErrorHandling.builder().setOutput("FailedRows").build()) .build(); Function shouldFailRow = @@ -343,4 +482,24 @@ public void testErrorCount() throws Exception { assertEquals(expectedCount, count.getAttempted()); } } + + @Test + public void testManagedChoosesStorageApiForUnboundedWrites() { + PCollection batchInput = + p.apply(TestStream.create(SCHEMA).addElements(ROWS.get(0)).advanceWatermarkToInfinity()); + batchInput.apply( + Managed.write(Managed.BIGQUERY) + .withConfig(ImmutableMap.of("table", "project.dataset.table"))); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + 
pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> + tr.getUniqueName() + .contains(BigQueryStorageWriteApiSchemaTransform.class.getSimpleName())) + .collect(Collectors.toList()); + assertThat(writeTransformProto.size(), greaterThan(0)); + p.enableAbandonedNodeEnforcement(false); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java index d4fcf6153e47..73328afb397b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java @@ -137,7 +137,7 @@ public final void attemptsExhaustedForRetryableError() throws Exception { FlushBuffer> flushBuffer = spy(newFlushBuffer(rpcQosOptions)); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); when(flushBuffer.offer(element1)).thenReturn(true); when(flushBuffer.iterator()).thenReturn(newArrayList(element1).iterator()); when(flushBuffer.getBufferedElementsCount()).thenReturn(1); @@ -224,7 +224,7 @@ public final void endToEnd_success() throws Exception { FlushBuffer> flushBuffer = spy(newFlushBuffer(options)); when(processContext.element()).thenReturn(write); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(BatchWriteRequest.class); when(callable.call(requestCaptor.capture())).thenReturn(response); @@ -267,7 +267,7 @@ public final void 
endToEnd_exhaustingAttemptsResultsInException() throws Excepti FlushBuffer> flushBuffer = spy(newFlushBuffer(rpcQosOptions)); when(processContext.element()).thenReturn(write); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); when(flushBuffer.isFull()).thenReturn(true); when(flushBuffer.offer(element1)).thenReturn(true); when(flushBuffer.iterator()).thenReturn(newArrayList(element1).iterator()); @@ -324,14 +324,14 @@ public final void endToEnd_awaitSafeToProceed_falseIsTerminalForAttempt() throws when(attempt2.awaitSafeToProceed(any())) .thenReturn(true) .thenThrow(new IllegalStateException("too many attempt2#awaitSafeToProceed")); - when(attempt2.>newFlushBuffer(any())) + when(attempt2.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); // finish bundle attempt RpcQos.RpcWriteAttempt finishBundleAttempt = mock(RpcWriteAttempt.class); when(finishBundleAttempt.awaitSafeToProceed(any())) .thenReturn(true, true) .thenThrow(new IllegalStateException("too many finishBundleAttempt#awaitSafeToProceed")); - when(finishBundleAttempt.>newFlushBuffer(any())) + when(finishBundleAttempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); when(rpcQos.newWriteAttempt(any())).thenReturn(attempt, attempt2, finishBundleAttempt); when(callable.call(requestCaptor.capture())).thenReturn(response); @@ -519,20 +519,15 @@ public final void endToEnd_maxBatchSizeRespected() throws Exception { when(attempt.awaitSafeToProceed(any())).thenReturn(true); when(attempt2.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(enqueue0)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue1)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue2)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue3)) - 
.thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue4)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(enqueue0)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue1)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue2)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue3)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue4)).thenReturn(flushBuffer); when(callable.call(expectedGroup1Request)).thenReturn(group1Response); - when(attempt2.>newFlushBuffer(enqueue5)) - .thenReturn(newFlushBuffer(options)); - when(attempt2.>newFlushBuffer(finalFlush)).thenReturn(flushBuffer2); + when(attempt2.>newFlushBuffer(enqueue5)).thenReturn(newFlushBuffer(options)); + when(attempt2.>newFlushBuffer(finalFlush)).thenReturn(flushBuffer2); when(callable.call(expectedGroup2Request)).thenReturn(group2Response); runFunction( @@ -603,7 +598,7 @@ public final void endToEnd_partialSuccessReturnsWritesToQueue() throws Exception when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); when(attempt.isCodeRetryable(Code.INVALID_ARGUMENT)).thenReturn(true); when(attempt.isCodeRetryable(Code.FAILED_PRECONDITION)).thenReturn(true); @@ -673,9 +668,9 @@ public final void writesRemainInQueueWhenFlushIsNotReadyAndThenFlushesInFinishBu .thenThrow(new IllegalStateException("too many attempt calls")); when(attempt.awaitSafeToProceed(any())).thenReturn(true); when(attempt2.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); - when(attempt2.>newFlushBuffer(any())) + when(attempt2.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); FnT fn = 
getFn(clock, ff, options, CounterFactory.DEFAULT, DistributionFactory.DEFAULT); @@ -723,7 +718,7 @@ public final void queuedWritesMaintainPriorityIfNotFlushed() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); FnT fn = getFn(clock, ff, options, CounterFactory.DEFAULT, DistributionFactory.DEFAULT); @@ -779,7 +774,7 @@ protected final void processElementsAndFinishBundle(FnT fn, int processElementCo } } - protected FlushBufferImpl> newFlushBuffer(RpcQosOptions options) { + protected FlushBufferImpl> newFlushBuffer(RpcQosOptions options) { return new FlushBufferImpl<>(options.getBatchMaxCount(), options.getBatchMaxBytes()); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java index d59b9354bd8b..2948be7658a9 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java @@ -177,7 +177,7 @@ public void nonRetryableWriteIsOutput() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenReturn(newFlushBuffer(options)) .thenReturn(newFlushBuffer(options)) .thenThrow(new IllegalStateException("too many attempt#newFlushBuffer calls")); diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java index 9acc3707e3ba..70c4ce5046a5 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java @@ -190,7 +190,7 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenReturn(newFlushBuffer(options)) .thenReturn(newFlushBuffer(options)) .thenThrow(new IllegalStateException("too many attempt#newFlushBuffer calls")); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java index bbf3e135e43f..7e24888ace43 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java @@ -236,7 +236,7 @@ private void safeToProceedAndWithBudgetAndWrite( assertTrue( msg(description, t, "awaitSafeToProceed was false, expected true"), attempt.awaitSafeToProceed(t)); - FlushBufferImpl> buffer = attempt.newFlushBuffer(t); + FlushBufferImpl> buffer = attempt.newFlushBuffer(t); assertEquals( msg(description, t, "unexpected batchMaxCount"), expectedBatchMaxCount, diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java index 2f3724d6bae7..9dff65bf2f63 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java @@ -455,7 +455,7 @@ public void offerOfElementWhichWouldCrossMaxBytesReturnFalse() { @Test public void flushBuffer_doesNotErrorWhenMaxIsOne() { - FlushBufferImpl> buffer = new FlushBufferImpl<>(1, 1000); + FlushBufferImpl> buffer = new FlushBufferImpl<>(1, 1000); assertTrue(buffer.offer(new FixedSerializationSize<>("a", 1))); assertFalse(buffer.offer(new FixedSerializationSize<>("b", 1))); assertEquals(1, buffer.getBufferedElementsCount()); @@ -463,7 +463,7 @@ public void flushBuffer_doesNotErrorWhenMaxIsOne() { @Test public void flushBuffer_doesNotErrorWhenMaxIsZero() { - FlushBufferImpl> buffer = new FlushBufferImpl<>(0, 1000); + FlushBufferImpl> buffer = new FlushBufferImpl<>(0, 1000); assertFalse(buffer.offer(new FixedSerializationSize<>("a", 1))); assertEquals(0, buffer.getBufferedElementsCount()); assertFalse(buffer.isFull()); @@ -703,7 +703,7 @@ private void doTest_initialBatchSizeRelativeToWorkerCount( .build(); RpcQosImpl qos = new RpcQosImpl(options, random, sleeper, counterFactory, distributionFactory); RpcWriteAttemptImpl attempt = qos.newWriteAttempt(RPC_ATTEMPT_CONTEXT); - FlushBufferImpl> buffer = attempt.newFlushBuffer(Instant.EPOCH); + FlushBufferImpl> buffer = attempt.newFlushBuffer(Instant.EPOCH); assertEquals(expectedBatchMaxCount, buffer.nextBatchMaxCount); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java 
index 3724e169c612..6c4625f2e077 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java @@ -40,6 +40,7 @@ import com.google.pubsub.v1.Topic; import io.grpc.ManagedChannel; import io.grpc.Server; +import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; @@ -432,4 +433,43 @@ public void getSchema(GetSchemaRequest request, StreamObserver responseO server.shutdownNow(); } } + + @Test + public void isTopicExists() throws IOException { + initializeClient(null, null); + TopicPath topicDoesNotExist = + PubsubClient.topicPathFromPath("projects/testProject/topics/dontexist"); + TopicPath topicExists = PubsubClient.topicPathFromPath("projects/testProject/topics/exist"); + + PublisherImplBase publisherImplBase = + new PublisherImplBase() { + @Override + public void getTopic(GetTopicRequest request, StreamObserver responseObserver) { + String topicPath = request.getTopic(); + if (topicPath.equals(topicDoesNotExist.getPath())) { + responseObserver.onError( + new StatusRuntimeException(Status.fromCode(Status.Code.NOT_FOUND))); + } + if (topicPath.equals(topicExists.getPath())) { + responseObserver.onNext( + Topic.newBuilder() + .setName(topicPath) + .setSchemaSettings( + SchemaSettings.newBuilder().setSchema(SCHEMA.getPath()).build()) + .build()); + responseObserver.onCompleted(); + } + } + }; + Server server = + InProcessServerBuilder.forName(channelName).addService(publisherImplBase).build().start(); + try { + assertEquals(false, client.isTopicExists(topicDoesNotExist)); + + assertEquals(true, client.isTopicExists(topicExists)); + + } finally { + server.shutdownNow(); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java index d4effbae40a4..bec157ae83cc 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java @@ -83,6 +83,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.commons.lang3.RandomStringUtils; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -97,6 +98,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.model.Statement; +import org.mockito.Mockito; /** Tests for PubsubIO Read and Write transforms. 
*/ @RunWith(JUnit4.class) @@ -928,4 +930,172 @@ public void testBigMessageBounded() throws IOException { pipeline.run(); } } + + @Test + public void testReadValidate() throws IOException { + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath existingTopic = PubsubClient.topicPathFromName("test-project", "testTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(existingTopic)).thenReturn(true); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + Read read = + Read.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/testTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + .setPubsubClientFactory(mockFactory) + .setCoder(PubsubMessagePayloadOnlyCoder.of()) + .setValidate(true) + .build(); + + read.validate(options); + } + + @Test + public void testReadValidateTopicIsNotExists() throws Exception { + thrown.expect(IllegalArgumentException.class); + + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + Read read = + Read.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/nonExistingTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + 
.setPubsubClientFactory(mockFactory) + .setCoder(PubsubMessagePayloadOnlyCoder.of()) + .setValidate(true) + .build(); + + read.validate(options); + } + + @Test + public void testReadWithoutValidation() throws IOException { + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + Read read = + PubsubIO.readMessages().fromTopic("projects/test-project/topics/nonExistingTopic"); + + read.validate(options); + } + + @Test + public void testWriteTopicValidationSuccess() throws Exception { + PubsubIO.writeStrings().to("projects/my-project/topics/abc"); + PubsubIO.writeStrings().to("projects/my-project/topics/ABC"); + PubsubIO.writeStrings().to("projects/my-project/topics/AbC-DeF"); + PubsubIO.writeStrings().to("projects/my-project/topics/AbC-1234"); + PubsubIO.writeStrings().to("projects/my-project/topics/AbC-1234-_.~%+-_.~%+-_.~%+-abc"); + PubsubIO.writeStrings() + .to( + new StringBuilder() + .append("projects/my-project/topics/A-really-long-one-") + .append(RandomStringUtils.randomAlphanumeric(100)) + .toString()); + } + + @Test + public void testWriteTopicValidationBadCharacter() throws Exception { + thrown.expect(IllegalArgumentException.class); + PubsubIO.writeStrings().to("projects/my-project/topics/abc-*-abc"); + } + + @Test + public void testWriteValidationTooLong() throws Exception { + thrown.expect(IllegalArgumentException.class); + PubsubIO.writeStrings() + .to( + new StringBuilder() + .append("projects/my-project/topics/A-really-long-one-") + .append(RandomStringUtils.randomAlphanumeric(1000)) + 
.toString()); + } + + @Test + public void testWriteValidate() throws IOException { + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath existingTopic = PubsubClient.topicPathFromName("test-project", "testTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(existingTopic)).thenReturn(true); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + PubsubIO.Write write = + PubsubIO.Write.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/testTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + .setDynamicDestinations(false) + .setPubsubClientFactory(mockFactory) + .setValidate(true) + .build(); + + write.validate(options); + } + + @Test + public void testWriteValidateTopicIsNotExists() throws Exception { + thrown.expect(IllegalArgumentException.class); + + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + PubsubIO.Write write = + PubsubIO.Write.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/nonExistingTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + .setDynamicDestinations(false) + .setPubsubClientFactory(mockFactory) + .setValidate(true) + .build(); + + 
write.validate(options); + } + + @Test + public void testWithoutValidation() throws IOException { + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + PubsubIO.Write write = + PubsubIO.writeMessages().to("projects/test-project/topics/nonExistingTopic"); + + write.validate(options); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java index 634ad42c937a..49681c86257e 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java @@ -23,6 +23,10 @@ import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.when; +import com.google.api.client.googleapis.json.GoogleJsonError; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpResponseException; import com.google.api.services.pubsub.Pubsub; import com.google.api.services.pubsub.Pubsub.Projects.Subscriptions; import com.google.api.services.pubsub.Pubsub.Projects.Topics; @@ -425,4 +429,24 @@ public void getProtoSchema() throws IOException { IllegalArgumentException.class, () -> client.getSchema(SCHEMA)); } + + @Test + public void isTopicExists() throws Exception { + 
TopicPath topicExists = + PubsubClient.topicPathFromPath("projects/testProject/topics/topicExists"); + TopicPath topicDoesNotExist = + PubsubClient.topicPathFromPath("projects/testProject/topics/topicDoesNotExist"); + HttpResponseException.Builder builder = + new HttpResponseException.Builder(404, "topic is not found", new HttpHeaders()); + GoogleJsonError error = new GoogleJsonError(); + when(mockPubsub.projects().topics().get(topicExists.getPath()).execute()) + .thenReturn(new Topic().setName(topicExists.getName())); + when(mockPubsub.projects().topics().get(topicDoesNotExist.getPath()).execute()) + .thenThrow(new GoogleJsonResponseException(builder, error)); + + client = new PubsubJsonClient(null, null, mockPubsub); + + assertEquals(true, client.isTopicExists(topicExists)); + assertEquals(false, client.isTopicExists(topicDoesNotExist)); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java index 70105f820536..b80fba31d3a2 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessorTest.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io.gcp.spanner; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -164,6 +163,5 @@ public void testBuildSpannerOptionsWithCredential() { assertEquals("project", options.getProjectId()); assertEquals("test-role", options.getDatabaseRole()); assertEquals(testCredential, options.getCredentials()); - assertNotNull(options.getSessionPoolOptions()); } } diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java index 3752c2fb3afc..02b9d111583b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java @@ -33,7 +33,9 @@ import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Iterator; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -58,6 +60,8 @@ public class PartitionMetadataAdminDaoTest { private static final String DATABASE_ID = "SPANNER_DATABASE"; private static final String TABLE_NAME = "SPANNER_TABLE"; + private static final String WATERMARK_INDEX_NAME = "WATERMARK_INDEX"; + private static final String CREATED_AT_INDEX_NAME = "CREATED_AT_INDEX"; private static final int TIMEOUT_MINUTES = 10; @@ -68,12 +72,14 @@ public class PartitionMetadataAdminDaoTest { @Before public void setUp() { databaseAdminClient = mock(DatabaseAdminClient.class); + PartitionMetadataTableNames names = + new PartitionMetadataTableNames(TABLE_NAME, WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME); partitionMetadataAdminDao = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, TABLE_NAME, Dialect.GOOGLE_STANDARD_SQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.GOOGLE_STANDARD_SQL); partitionMetadataAdminDaoPostgres = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, 
TABLE_NAME, Dialect.POSTGRESQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.POSTGRESQL); op = (OperationFuture) mock(OperationFuture.class); statements = ArgumentCaptor.forClass(Iterable.class); when(databaseAdminClient.updateDatabaseDdl( @@ -89,9 +95,9 @@ public void testCreatePartitionMetadataTable() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE")); - assertTrue(it.next().contains("CREATE INDEX")); - assertTrue(it.next().contains("CREATE INDEX")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); } @Test @@ -102,9 +108,9 @@ public void testCreatePartitionMetadataTablePostgres() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); } @Test @@ -133,7 +139,8 @@ public void testCreatePartitionMetadataTableWithInterruptedException() throws Ex @Test public void testDeletePartitionMetadataTable() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, 
times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -143,10 +150,22 @@ public void testDeletePartitionMetadataTable() throws Exception { assertTrue(it.next().contains("DROP TABLE")); } + @Test + public void testDeletePartitionMetadataTableWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDao.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE")); + } + @Test public void testDeletePartitionMetadataTablePostgres() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -156,11 +175,23 @@ public void testDeletePartitionMetadataTablePostgres() throws Exception { assertTrue(it.next().contains("DROP TABLE \"")); } + @Test + public void testDeletePartitionMetadataTablePostgresWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = 
statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE \"")); + } + @Test public void testDeletePartitionMetadataTableWithTimeoutException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new TimeoutException(TIMED_OUT)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertTrue(e.getMessage().contains(TIMED_OUT)); @@ -171,7 +202,8 @@ public void testDeletePartitionMetadataTableWithTimeoutException() throws Except public void testDeletePartitionMetadataTableWithInterruptedException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new InterruptedException(INTERRUPTED)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertEquals(ErrorCode.CANCELLED, e.getErrorCode()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java new file mode 100644 index 000000000000..2aae5b26a2cb --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import static org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames.MAX_NAME_LENGTH; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class PartitionMetadataTableNamesTest { + @Test + public void testGeneratePartitionMetadataNamesRemovesHyphens() { + String databaseId = "my-database-id-12345"; + + PartitionMetadataTableNames names1 = PartitionMetadataTableNames.generateRandom(databaseId); + assertFalse(names1.getTableName().contains("-")); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = PartitionMetadataTableNames.generateRandom(databaseId); + assertNotEquals(names1.getTableName(), names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } + + @Test + public void testGeneratePartitionMetadataNamesIsShorterThan64Characters() { + PartitionMetadataTableNames names = + PartitionMetadataTableNames.generateRandom( + 
"my-database-id-larger-than-maximum-length-1234567890-1234567890-1234567890"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + + names = PartitionMetadataTableNames.generateRandom("d"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + } + + @Test + public void testPartitionMetadataNamesFromExistingTable() { + PartitionMetadataTableNames names1 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names1.getTableName()); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } +} diff --git a/sdks/java/io/hadoop-common/build.gradle b/sdks/java/io/hadoop-common/build.gradle index 466aa8fb6730..4375001ffa81 100644 --- a/sdks/java/io/hadoop-common/build.gradle +++ b/sdks/java/io/hadoop-common/build.gradle @@ -25,10 +25,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop Common" ext.summary = "Library to add shared Hadoop classes among Beam IOs." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + // "341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hadoop-file-system/build.gradle b/sdks/java/io/hadoop-file-system/build.gradle index 3fc872bb5d02..b4ebbfa08c5e 100644 --- a/sdks/java/io/hadoop-file-system/build.gradle +++ b/sdks/java/io/hadoop-file-system/build.gradle @@ -26,10 +26,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop File System" ext.summary = "Library to read and write Hadoop/HDFS file formats from Beam." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + // "341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hadoop-format/build.gradle b/sdks/java/io/hadoop-format/build.gradle index dbb9f8fdd73d..73fc44a0f311 100644 --- a/sdks/java/io/hadoop-format/build.gradle +++ b/sdks/java/io/hadoop-format/build.gradle @@ -30,10 +30,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop Format" ext.summary = "IO to read data from sources and to write data to sinks that implement Hadoop MapReduce Format." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + // "341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hbase/build.gradle b/sdks/java/io/hbase/build.gradle index d85c0fc610bb..07014f2d5e3b 100644 --- a/sdks/java/io/hbase/build.gradle +++ b/sdks/java/io/hbase/build.gradle @@ -34,7 +34,7 @@ test { jvmArgs "-Dtest.build.data.basedirectory=build/test-data" } -def hbase_version = "2.5.5" +def hbase_version = "2.6.1-hadoop3" dependencies { implementation library.java.vendored_guava_32_1_2_jre @@ -46,12 +46,7 @@ dependencies { testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation library.java.junit testImplementation library.java.hamcrest - testImplementation library.java.hadoop_minicluster - testImplementation library.java.hadoop_hdfs - testImplementation library.java.hadoop_common + // shaded-testing-utils has shaded all Hadoop/HBase dependencies testImplementation("org.apache.hbase:hbase-shaded-testing-util:$hbase_version") - testImplementation "org.apache.hbase:hbase-hadoop-compat:$hbase_version:tests" - testImplementation "org.apache.hbase:hbase-hadoop2-compat:$hbase_version:tests" testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") } - diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle index c4f1b76ec390..d07904f3465e 100644 --- a/sdks/java/io/hcatalog/build.gradle +++ b/sdks/java/io/hcatalog/build.gradle @@ -30,9 +30,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: HCatalog" ext.summary = "IO to read and write for HCatalog source." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", + "324": "3.2.4", + "336": "3.3.6", + // "341": "3.4.1", // tests already exercised on the default version ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} @@ -70,13 +71,21 @@ dependencies { testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") hadoopVersions.each {kv -> "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-common:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-hdfs:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-hdfs-client:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" } } hadoopVersions.each {kv -> configurations."hadoopVersion$kv.key" { resolutionStrategy { + force "org.apache.hadoop:hadoop-client:$kv.value" force "org.apache.hadoop:hadoop-common:$kv.value" + force "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" + force "org.apache.hadoop:hadoop-minicluster:$kv.value" + force "org.apache.hadoop:hadoop-hdfs:$kv.value" + force "org.apache.hadoop:hadoop-hdfs-client:$kv.value" } } } diff --git a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatToRow.java b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatToRow.java index 8e29650f3fc3..e5bdf18ecbcf 100644 --- a/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatToRow.java +++ b/sdks/java/io/hcatalog/src/main/java/org/apache/beam/sdk/io/hcatalog/HCatToRow.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.hcatalog; +import java.util.List; +import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; @@ -25,6 +27,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.hive.hcatalog.data.HCatRecord; +import org.joda.time.Instant; /** Utilities to convert {@link 
HCatRecord HCatRecords} to {@link Row Rows}. */ @SuppressWarnings({ @@ -74,6 +77,18 @@ public PCollection expand(PBegin input) { private static class HCatToRowFn extends DoFn { private final Schema schema; + private Object maybeCastHDate(Object obj) { + if (obj instanceof org.apache.hadoop.hive.common.type.Date) { + return new Instant(((org.apache.hadoop.hive.common.type.Date) obj).toEpochMilli()); + } + return obj; + } + + /** Cast objects of the types that aren't supported by {@link Row}. */ + private List castTypes(List values) { + return values.stream().map(this::maybeCastHDate).collect(Collectors.toList()); + } + HCatToRowFn(Schema schema) { this.schema = schema; } @@ -81,7 +96,7 @@ private static class HCatToRowFn extends DoFn { @ProcessElement public void processElement(ProcessContext c) { HCatRecord hCatRecord = c.element(); - c.output(Row.withSchema(schema).addValues(hCatRecord.getAll()).build()); + c.output(Row.withSchema(schema).addValues(castTypes(hCatRecord.getAll())).build()); } } } diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java index 4bb7e1bd7044..3d97a2ccc1d9 100644 --- a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/HCatalogIOTest.java @@ -22,6 +22,7 @@ import static org.apache.beam.sdk.io.hcatalog.test.HCatalogIOTestUtils.TEST_RECORDS_COUNT; import static org.apache.beam.sdk.io.hcatalog.test.HCatalogIOTestUtils.TEST_TABLE; import static org.apache.beam.sdk.io.hcatalog.test.HCatalogIOTestUtils.buildHCatRecords; +import static org.apache.beam.sdk.io.hcatalog.test.HCatalogIOTestUtils.buildHCatRecordsWithDate; import static org.apache.beam.sdk.io.hcatalog.test.HCatalogIOTestUtils.getConfigPropertiesAsMap; import static org.apache.beam.sdk.io.hcatalog.test.HCatalogIOTestUtils.getExpectedRecords; 
import static org.apache.beam.sdk.io.hcatalog.test.HCatalogIOTestUtils.getReaderContext; @@ -54,12 +55,14 @@ import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Distinct; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Watch; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.hadoop.hive.metastore.api.NoSuchObjectException; import org.apache.hive.hcatalog.data.DefaultHCatRecord; @@ -230,6 +233,44 @@ public void processElement(ProcessContext c) { readAfterWritePipeline.run(); } + /** Perform test for reading Date column type from an hcatalog. 
*/ + @Test + public void testReadHCatalogDateType() throws Exception { + service.executeQuery("drop table if exists " + TEST_TABLE); + service.executeQuery("create table " + TEST_TABLE + "(mycol1 string, mycol2 date)"); + + defaultPipeline + .apply(Create.of(buildHCatRecordsWithDate(TEST_RECORDS_COUNT))) + .apply( + HCatalogIO.write() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withDatabase(TEST_DATABASE) + .withTable(TEST_TABLE) + .withPartition(new java.util.HashMap<>())); + defaultPipeline.run().waitUntilFinish(); + + final PCollection output = + readAfterWritePipeline + .apply( + HCatToRow.fromSpec( + HCatalogIO.read() + .withConfigProperties(getConfigPropertiesAsMap(service.getHiveConf())) + .withDatabase(TEST_DATABASE) + .withTable(TEST_TABLE) + .withFilter(TEST_FILTER))) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().getDateTime("mycol2").toString("yyyy-MM-dd HH:mm:ss")); + } + })) + .apply(Distinct.create()); + PAssert.that(output).containsInAnyOrder(ImmutableList.of("2014-01-20 00:00:00")); + readAfterWritePipeline.run(); + } + /** Test of Write to a non-existent table. 
*/ @Test public void testWriteFailureTableDoesNotExist() { diff --git a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/test/HCatalogIOTestUtils.java b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/test/HCatalogIOTestUtils.java index d0d1d850a6cb..c09c2c906d64 100644 --- a/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/test/HCatalogIOTestUtils.java +++ b/sdks/java/io/hcatalog/src/test/java/org/apache/beam/sdk/io/hcatalog/test/HCatalogIOTestUtils.java @@ -26,6 +26,7 @@ import java.util.Map.Entry; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.values.KV; +import org.apache.hadoop.hive.common.type.Date; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hive.hcatalog.common.HCatException; import org.apache.hive.hcatalog.data.DefaultHCatRecord; @@ -120,4 +121,13 @@ public static Map getConfigPropertiesAsMap(HiveConf hiveConf) { private static DefaultHCatRecord toHCatRecord(int value) { return new DefaultHCatRecord(Arrays.asList("record " + value, value)); } + + /** Returns a list of HCatRecords of passed size with some dummy date as a field. */ + public static List buildHCatRecordsWithDate(int size) { + List expected = new ArrayList<>(); + for (int i = 0; i < size; i++) { + expected.add(new DefaultHCatRecord(Arrays.asList("record " + i, Date.valueOf("2014-01-20")))); + } + return expected; + } } diff --git a/sdks/java/io/iceberg/build.gradle b/sdks/java/io/iceberg/build.gradle index 3d653d6b276e..0cfa8da4eb7d 100644 --- a/sdks/java/io/iceberg/build.gradle +++ b/sdks/java/io/iceberg/build.gradle @@ -29,10 +29,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Iceberg" ext.summary = "Integration with Iceberg data warehouses." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", - "2102": "2.10.2", - "324": "3.2.4", + "2102": "2.10.2", + "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} @@ -44,7 +44,7 @@ def orc_version = "1.9.2" dependencies { implementation library.java.vendored_guava_32_1_2_jre implementation project(path: ":sdks:java:core", configuration: "shadow") - implementation project(":sdks:java:managed") + implementation project(path: ":model:pipeline", configuration: "shadow") implementation library.java.slf4j_api implementation library.java.joda_time implementation "org.apache.parquet:parquet-column:$parquet_version" @@ -54,12 +54,13 @@ dependencies { implementation "org.apache.iceberg:iceberg-parquet:$iceberg_version" implementation "org.apache.iceberg:iceberg-orc:$iceberg_version" implementation library.java.hadoop_common + runtimeOnly "org.apache.iceberg:iceberg-gcp:$iceberg_version" + testImplementation project(":sdks:java:managed") testImplementation library.java.hadoop_client testImplementation library.java.bigdataoss_gcsio testImplementation library.java.bigdataoss_gcs_connector testImplementation library.java.bigdataoss_util_hadoop - testImplementation "org.apache.iceberg:iceberg-gcp:$iceberg_version" testImplementation "org.apache.iceberg:iceberg-data:$iceberg_version" testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation project(":sdks:java:extensions:google-cloud-platform-core") @@ -69,6 +70,9 @@ dependencies { testRuntimeOnly project(path: ":runners:google-cloud-dataflow-java") hadoopVersions.each {kv -> "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-client:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-minicluster:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-hdfs-client:$kv.value" + "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" } } @@ -76,6 +80,11 @@ 
hadoopVersions.each {kv -> configurations."hadoopVersion$kv.key" { resolutionStrategy { force "org.apache.hadoop:hadoop-client:$kv.value" + force "org.apache.hadoop:hadoop-common:$kv.value" + force "org.apache.hadoop:hadoop-mapreduce-client-core:$kv.value" + force "org.apache.hadoop:hadoop-minicluster:$kv.value" + force "org.apache.hadoop:hadoop-hdfs:$kv.value" + force "org.apache.hadoop:hadoop-hdfs-client:$kv.value" } } } diff --git a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/IcebergHiveCatalogIT.java b/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/IcebergHiveCatalogIT.java index 54a4998d37fb..ca4d862c2c72 100644 --- a/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/IcebergHiveCatalogIT.java +++ b/sdks/java/io/iceberg/hive/src/test/java/org/apache/beam/sdk/io/iceberg/hive/IcebergHiveCatalogIT.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.io.iceberg.hive.testutils.HiveMetastoreExtension; import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; @@ -64,7 +65,10 @@ import org.apache.iceberg.io.InputFile; import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.util.DateTimeUtil; import org.apache.thrift.TException; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; @@ -100,6 +104,10 @@ public class IcebergHiveCatalogIT { .addArrayField("arr_long", Schema.FieldType.INT64) .addRowField("row", NESTED_ROW_SCHEMA) .addNullableRowField("nullable_row", NESTED_ROW_SCHEMA) + .addDateTimeField("datetime_tz") + .addLogicalTypeField("datetime", SqlTypes.DATETIME) + .addLogicalTypeField("date", SqlTypes.DATE) + 
.addLogicalTypeField("time", SqlTypes.TIME) .build(); private static final SimpleFunction ROW_FUNC = @@ -127,6 +135,10 @@ public Row apply(Long num) { .addValue(LongStream.range(1, num % 10).boxed().collect(Collectors.toList())) .addValue(nestedRow) .addValue(num % 2 == 0 ? null : nestedRow) + .addValue(new DateTime(num).withZone(DateTimeZone.forOffsetHoursMinutes(3, 25))) + .addValue(DateTimeUtil.timestampFromMicros(num)) + .addValue(DateTimeUtil.dateFromDays(Integer.parseInt(strNum))) + .addValue(DateTimeUtil.timeFromMicros(num)) .build(); } }; diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java index b91253cf3c12..d9768114e7c6 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java @@ -17,9 +17,16 @@ */ package org.apache.beam.sdk.io.iceberg; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Distribution; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; @@ -28,13 +35,21 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.iceberg.AppendFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import 
org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Snapshot; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,9 +58,11 @@ class AppendFilesToTables extends PTransform, PCollection>> { private static final Logger LOG = LoggerFactory.getLogger(AppendFilesToTables.class); private final IcebergCatalogConfig catalogConfig; + private final String manifestFilePrefix; - AppendFilesToTables(IcebergCatalogConfig catalogConfig) { + AppendFilesToTables(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) { this.catalogConfig = catalogConfig; + this.manifestFilePrefix = manifestFilePrefix; } @Override @@ -65,7 +82,7 @@ public String apply(FileWriteResult input) { .apply("Group metadata updates by table", GroupByKey.create()) .apply( "Append metadata updates to tables", - ParDo.of(new AppendFilesToTablesDoFn(catalogConfig))) + ParDo.of(new AppendFilesToTablesDoFn(catalogConfig, manifestFilePrefix))) .setCoder(KvCoder.of(StringUtf8Coder.of(), SnapshotInfo.CODER)); } @@ -73,13 +90,19 @@ private static class AppendFilesToTablesDoFn extends DoFn>, KV> { private final Counter snapshotsCreated = Metrics.counter(AppendFilesToTables.class, "snapshotsCreated"); + private final Distribution committedDataFileByteSize = + Metrics.distribution(RecordWriter.class, "committedDataFileByteSize"); + private final Distribution committedDataFileRecordCount = + Metrics.distribution(RecordWriter.class, "committedDataFileRecordCount"); private final IcebergCatalogConfig catalogConfig; + private final String manifestFilePrefix; private transient @MonotonicNonNull Catalog catalog; - private 
AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig) { + private AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) { this.catalogConfig = catalogConfig; + this.manifestFilePrefix = manifestFilePrefix; } private Catalog getCatalog() { @@ -89,26 +112,104 @@ private Catalog getCatalog() { return catalog; } + private boolean containsMultiplePartitionSpecs(Iterable fileWriteResults) { + int id = fileWriteResults.iterator().next().getSerializableDataFile().getPartitionSpecId(); + for (FileWriteResult result : fileWriteResults) { + if (id != result.getSerializableDataFile().getPartitionSpecId()) { + return true; + } + } + return false; + } + @ProcessElement public void processElement( @Element KV> element, OutputReceiver> out, - BoundedWindow window) { - if (!element.getValue().iterator().hasNext()) { + BoundedWindow window) + throws IOException { + String tableStringIdentifier = element.getKey(); + Iterable fileWriteResults = element.getValue(); + if (!fileWriteResults.iterator().hasNext()) { return; } Table table = getCatalog().loadTable(TableIdentifier.parse(element.getKey())); - AppendFiles update = table.newAppend(); - for (FileWriteResult writtenFile : element.getValue()) { - update.appendManifest(writtenFile.getManifestFile()); + + // vast majority of the time, we will simply append data files. + // in the rare case we get a batch that contains multiple partition specs, we will group + // data into manifest files and append. + // note: either way, we must use a single commit operation for atomicity. 
+ if (containsMultiplePartitionSpecs(fileWriteResults)) { + appendManifestFiles(table, fileWriteResults); + } else { + appendDataFiles(table, fileWriteResults); } - update.commit(); + Snapshot snapshot = table.currentSnapshot(); - LOG.info("Created new snapshot for table '{}': {}", element.getKey(), snapshot); + LOG.info("Created new snapshot for table '{}': {}", tableStringIdentifier, snapshot); snapshotsCreated.inc(); out.outputWithTimestamp( KV.of(element.getKey(), SnapshotInfo.fromSnapshot(snapshot)), window.maxTimestamp()); } + + // This works only when all files are using the same partition spec. + private void appendDataFiles(Table table, Iterable fileWriteResults) { + AppendFiles update = table.newAppend(); + for (FileWriteResult result : fileWriteResults) { + DataFile dataFile = result.getDataFile(table.specs()); + update.appendFile(dataFile); + committedDataFileByteSize.update(dataFile.fileSizeInBytes()); + committedDataFileRecordCount.update(dataFile.recordCount()); + } + update.commit(); + } + + // When a user updates their table partition spec during runtime, we can end up with + // a batch of files where some are written with the old spec and some are written with the new + // spec. + // A table commit is limited to a single partition spec. + // To handle this, we create a manifest file for each partition spec, and group data files + // accordingly. + // Afterward, we append all manifests using a single commit operation. 
+ private void appendManifestFiles(Table table, Iterable fileWriteResults) + throws IOException { + String uuid = UUID.randomUUID().toString(); + Map specs = table.specs(); + + Map> dataFilesBySpec = new HashMap<>(); + for (FileWriteResult result : fileWriteResults) { + DataFile dataFile = result.getDataFile(specs); + dataFilesBySpec.computeIfAbsent(dataFile.specId(), i -> new ArrayList<>()).add(dataFile); + } + + AppendFiles update = table.newAppend(); + for (Map.Entry> entry : dataFilesBySpec.entrySet()) { + int specId = entry.getKey(); + List files = entry.getValue(); + PartitionSpec spec = Preconditions.checkStateNotNull(specs.get(specId)); + ManifestWriter writer = + createManifestWriter(table.location(), uuid, spec, table.io()); + for (DataFile file : files) { + writer.add(file); + committedDataFileByteSize.update(file.fileSizeInBytes()); + committedDataFileRecordCount.update(file.recordCount()); + } + writer.close(); + update.appendManifest(writer.toManifestFile()); + } + update.commit(); + } + + private ManifestWriter createManifestWriter( + String tableLocation, String uuid, PartitionSpec spec, FileIO io) { + String location = + FileFormat.AVRO.addExtension( + String.format( + "%s/metadata/%s-%s-%s.manifest", + tableLocation, manifestFilePrefix, uuid, spec.specId())); + OutputFile outputFile = io.newOutputFile(location); + return ManifestFiles.write(spec, outputFile); + } } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java index 2459c0befde1..bf00bf8519fc 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java @@ -18,12 +18,12 @@ package org.apache.beam.sdk.io.iceberg; import com.google.auto.value.AutoValue; -import java.io.IOException; +import java.util.Map; import 
org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; -import org.apache.iceberg.ManifestFile; -import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.catalog.TableIdentifier; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; @@ -32,12 +32,11 @@ abstract class FileWriteResult { private transient @MonotonicNonNull TableIdentifier cachedTableIdentifier; - private transient @MonotonicNonNull ManifestFile cachedManifestFile; + private transient @MonotonicNonNull DataFile cachedDataFile; abstract String getTableIdentifierString(); - @SuppressWarnings("mutable") - abstract byte[] getManifestFileBytes(); + abstract SerializableDataFile getSerializableDataFile(); @SchemaIgnore public TableIdentifier getTableIdentifier() { @@ -48,15 +47,11 @@ public TableIdentifier getTableIdentifier() { } @SchemaIgnore - public ManifestFile getManifestFile() { - if (cachedManifestFile == null) { - try { - cachedManifestFile = ManifestFiles.decode(getManifestFileBytes()); - } catch (IOException exc) { - throw new RuntimeException("Error decoding manifest file bytes"); - } + public DataFile getDataFile(Map specs) { + if (cachedDataFile == null) { + cachedDataFile = getSerializableDataFile().createDataFile(specs); } - return cachedManifestFile; + return cachedDataFile; } public static Builder builder() { @@ -68,18 +63,13 @@ abstract static class Builder { abstract Builder setTableIdentifierString(String tableIdString); - abstract Builder setManifestFileBytes(byte[] manifestFileBytes); + abstract Builder setSerializableDataFile(SerializableDataFile dataFile); @SchemaIgnore public Builder setTableIdentifier(TableIdentifier tableId) { return setTableIdentifierString(tableId.toString()); } - @SchemaIgnore - public Builder setManifestFile(ManifestFile manifestFile) throws 
IOException { - return setManifestFileBytes(ManifestFiles.encode(manifestFile)); - } - public abstract FileWriteResult build(); } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java index 6321f9006e2a..1d4b36585237 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergIO.java @@ -24,7 +24,6 @@ import java.util.List; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.io.Read; -import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PBegin; @@ -34,6 +33,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.data.Record; import org.apache.iceberg.types.Type; @@ -44,8 +44,8 @@ * A connector that reads and writes to Apache Iceberg * tables. * - *

    {@link IcebergIO} is offered as a {@link Managed} transform. This class is subject to change - * and should not be used directly. Instead, use it via {@link Managed#ICEBERG} like so: + *

    {@link IcebergIO} is offered as a Managed transform. This class is subject to change and + * should not be used directly. Instead, use it like so: * *

    {@code
      * Map config = Map.of(
    @@ -106,6 +106,14 @@
      * 

    Additional configuration options are provided in the `Pre-filtering Options` section below, * for Iceberg writes. * + *

    Creating Tables

    + * + *

    If an Iceberg table does not exist at the time of writing, this connector will automatically + * create one with the data's schema. + * + *

    Note that this is a best-effort operation that depends on the {@link Catalog} implementation. + * Some implementations may not support creating a table using the Iceberg API. + * *

    Beam Rows

    * *

    Being a Managed transform, this IO exclusively writes and reads using Beam {@link Row}s. @@ -141,7 +149,16 @@ *

  • * * - * + * + * + * + * + * + * + * + * + * + * * * * @@ -157,6 +174,29 @@ * *
    KeyTokenQueuedActive ForStateState Active For
    KeyTokenQueuedActive ForStateState Active ForProcessing ThreadBackend
    "); activeWorkStatus.append(elapsedString(activeWork.getStateStartTime(), now)); + activeWorkStatus.append(""); + activeWorkStatus.append(activeWork.getProcessingThreadName()); + activeWorkStatus.append(""); + activeWorkStatus.append(activeWork.backendWorkerToken()); activeWorkStatus.append("
    DOUBLE DOUBLE
    DATETIME STRING SqlTypes.DATETIME TIMESTAMP
    DATETIME TIMESTAMPTZ
    SqlTypes.DATE DATE
    SqlTypes.TIME TIME
    ITERABLE LIST
    * + *

    Note: {@code SqlTypes} are Beam logical types. + * + *

    Note on timestamps

    + * + *

    For an existing table, the following Beam types are supported for both {@code timestamp} and + * {@code timestamptz}: + * + *

      + *
    • {@code SqlTypes.DATETIME} --> Using a {@link java.time.LocalDateTime} object + *
    • {@code DATETIME} --> Using a {@link org.joda.time.DateTime} object + *
    • {@code INT64} --> Using a {@link Long} representing micros since EPOCH + *
    • {@code STRING} --> Using a timestamp {@link String} representation (e.g. {@code + * "2024-10-08T13:18:20.053+03:27"}) + *
    + * + *

    Note: If you expect Beam to create the Iceberg table at runtime, please provide {@code + * SqlTypes.DATETIME} for a {@code timestamp} column and {@code DATETIME} for a {@code timestamptz} + * column. If the table does not exist, Beam will treat {@code STRING} and {@code INT64} at + * face-value and create equivalent column types. + * + *

    For Iceberg reads, the connector will produce Beam {@code SqlTypes.DATETIME} types for + * Iceberg's {@code timestamp} and {@code DATETIME} types for {@code timestamptz}. + * *

    Dynamic Destinations

    * *

    Managed Iceberg supports writing to dynamic destinations. To do so, please provide an @@ -318,7 +358,7 @@ public WriteRows to(DynamicDestinations destinations) { * org.apache.iceberg.Snapshot} is produced. * *

    Roughly every triggeringFrequency duration, records are written to data files and appended - * to the respective table. Each append operation created a new table snapshot. + * to the respective table. Each append operation creates a new table snapshot. * *

    Generally speaking, increasing this duration will result in fewer, larger data files and * fewer snapshots. diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java index df7bda4560dd..d44149fda08e 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProvider.java @@ -17,10 +17,12 @@ */ package org.apache.beam.sdk.io.iceberg; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; + import com.google.auto.service.AutoService; import java.util.Collections; import java.util.List; -import org.apache.beam.sdk.managed.ManagedTransformConstants; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; @@ -53,7 +55,7 @@ public List outputCollectionNames() { @Override public String identifier() { - return ManagedTransformConstants.ICEBERG_READ; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_READ); } static class IcebergReadSchemaTransform extends SchemaTransform { diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java index acd9b25a6a5e..ef19a5881366 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergUtils.java @@ -20,12 +20,18 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; import java.nio.ByteBuffer; +import java.time.LocalDate; +import 
java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -34,15 +40,13 @@ import org.apache.iceberg.data.Record; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DateTimeUtil; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; -import org.joda.time.DateTimeZone; +import org.joda.time.Instant; -/** Utilities for converting between Beam and Iceberg types. */ +/** Utilities for converting between Beam and Iceberg types, made public for user's convenience. */ public class IcebergUtils { - // This is made public for users convenience, as many may have more experience working with - // Iceberg types. 
- private IcebergUtils() {} private static final Map BEAM_TYPES_TO_ICEBERG_TYPES = @@ -54,6 +58,14 @@ private IcebergUtils() {} .put(Schema.TypeName.DOUBLE, Types.DoubleType.get()) .put(Schema.TypeName.STRING, Types.StringType.get()) .put(Schema.TypeName.BYTES, Types.BinaryType.get()) + .put(Schema.TypeName.DATETIME, Types.TimestampType.withZone()) + .build(); + + private static final Map BEAM_LOGICAL_TYPES_TO_ICEBERG_TYPES = + ImmutableMap.builder() + .put(SqlTypes.DATE.getIdentifier(), Types.DateType.get()) + .put(SqlTypes.TIME.getIdentifier(), Types.TimeType.get()) + .put(SqlTypes.DATETIME.getIdentifier(), Types.TimestampType.withoutZone()) .build(); private static Schema.FieldType icebergTypeToBeamFieldType(final Type type) { @@ -69,9 +81,15 @@ private static Schema.FieldType icebergTypeToBeamFieldType(final Type type) { case DOUBLE: return Schema.FieldType.DOUBLE; case DATE: + return Schema.FieldType.logicalType(SqlTypes.DATE); case TIME: - case TIMESTAMP: // TODO: Logical types? - return Schema.FieldType.DATETIME; + return Schema.FieldType.logicalType(SqlTypes.TIME); + case TIMESTAMP: + Types.TimestampType ts = (Types.TimestampType) type.asPrimitiveType(); + if (ts.shouldAdjustToUTC()) { + return Schema.FieldType.DATETIME; + } + return Schema.FieldType.logicalType(SqlTypes.DATETIME); case STRING: return Schema.FieldType.STRING; case UUID: @@ -151,6 +169,14 @@ static TypeAndMaxId beamFieldTypeToIcebergFieldType( // other types. 
return new TypeAndMaxId( --nestedFieldId, BEAM_TYPES_TO_ICEBERG_TYPES.get(beamType.getTypeName())); + } else if (beamType.getTypeName().isLogicalType()) { + String logicalTypeIdentifier = + checkArgumentNotNull(beamType.getLogicalType()).getIdentifier(); + @Nullable Type type = BEAM_LOGICAL_TYPES_TO_ICEBERG_TYPES.get(logicalTypeIdentifier); + if (type == null) { + throw new RuntimeException("Unsupported Beam logical type " + logicalTypeIdentifier); + } + return new TypeAndMaxId(--nestedFieldId, type); } else if (beamType.getTypeName().isCollectionType()) { // ARRAY or ITERABLE Schema.FieldType beamCollectionType = Preconditions.checkArgumentNotNull(beamType.getCollectionElementType()); @@ -227,8 +253,6 @@ static TypeAndMaxId beamFieldTypeToIcebergFieldType( * *

    The following unsupported Beam types will be defaulted to {@link Types.StringType}: *

  • {@link Schema.TypeName.DECIMAL} - *
  • {@link Schema.TypeName.DATETIME} - *
  • {@link Schema.TypeName.LOGICAL_TYPE} */ public static org.apache.iceberg.Schema beamSchemaToIcebergSchema(final Schema schema) { List fields = new ArrayList<>(schema.getFieldCount()); @@ -282,12 +306,20 @@ private static void copyFieldIntoRecord(Record rec, Types.NestedField field, Row Optional.ofNullable(value.getDouble(name)).ifPresent(v -> rec.setField(name, v)); break; case DATE: - throw new UnsupportedOperationException("Date fields not yet supported"); + Optional.ofNullable(value.getLogicalTypeValue(name, LocalDate.class)) + .ifPresent(v -> rec.setField(name, v)); + break; case TIME: - throw new UnsupportedOperationException("Time fields not yet supported"); + Optional.ofNullable(value.getLogicalTypeValue(name, LocalTime.class)) + .ifPresent(v -> rec.setField(name, v)); + break; case TIMESTAMP: - Optional.ofNullable(value.getDateTime(name)) - .ifPresent(v -> rec.setField(name, v.getMillis())); + Object val = value.getValue(name); + if (val == null) { + break; + } + Types.TimestampType ts = (Types.TimestampType) field.type().asPrimitiveType(); + rec.setField(name, getIcebergTimestampValue(val, ts.shouldAdjustToUTC())); break; case STRING: Optional.ofNullable(value.getString(name)).ifPresent(v -> rec.setField(name, v)); @@ -322,6 +354,55 @@ private static void copyFieldIntoRecord(Record rec, Types.NestedField field, Row } } + /** + * Returns the appropriate value for an Iceberg timestamp field + * + *

    If `timestamp`, we resolve incoming values to a {@link LocalDateTime}. + * + *

    If `timestamptz`, we resolve to a UTC {@link OffsetDateTime}. Iceberg already resolves all + * incoming timestamps to UTC, so there is no harm in doing it from our side. + * + *

    Valid types are: + * + *

      + *
    • {@link SqlTypes.DATETIME} --> {@link LocalDateTime} + *
    • {@link Schema.FieldType.DATETIME} --> {@link Instant} + *
    • {@link Schema.FieldType.INT64} --> {@link Long} + *
    • {@link Schema.FieldType.STRING} --> {@link String} + *
    + */ + private static Object getIcebergTimestampValue(Object beamValue, boolean shouldAdjustToUtc) { + // timestamptz + if (shouldAdjustToUtc) { + if (beamValue instanceof LocalDateTime) { // SqlTypes.DATETIME + return OffsetDateTime.of((LocalDateTime) beamValue, ZoneOffset.UTC); + } else if (beamValue instanceof Instant) { // FieldType.DATETIME + return DateTimeUtil.timestamptzFromMicros(((Instant) beamValue).getMillis() * 1000L); + } else if (beamValue instanceof Long) { // FieldType.INT64 + return DateTimeUtil.timestamptzFromMicros((Long) beamValue); + } else if (beamValue instanceof String) { // FieldType.STRING + return OffsetDateTime.parse((String) beamValue).withOffsetSameInstant(ZoneOffset.UTC); + } else { + throw new UnsupportedOperationException( + "Unsupported Beam type for Iceberg timestamp with timezone: " + beamValue.getClass()); + } + } + + // timestamp + if (beamValue instanceof LocalDateTime) { // SqlType.DATETIME + return beamValue; + } else if (beamValue instanceof Instant) { // FieldType.DATETIME + return DateTimeUtil.timestampFromMicros(((Instant) beamValue).getMillis() * 1000L); + } else if (beamValue instanceof Long) { // FieldType.INT64 + return DateTimeUtil.timestampFromMicros((Long) beamValue); + } else if (beamValue instanceof String) { // FieldType.STRING + return LocalDateTime.parse((String) beamValue); + } else { + throw new UnsupportedOperationException( + "Unsupported Beam type for Iceberg timestamp with timezone: " + beamValue.getClass()); + } + } + /** Converts an Iceberg {@link Record} to a Beam {@link Row}. 
*/ public static Row icebergRecordToBeamRow(Schema schema, Record record) { Row.Builder rowBuilder = Row.withSchema(schema); @@ -345,16 +426,17 @@ public static Row icebergRecordToBeamRow(Schema schema, Record record) { case FLOAT: // Iceberg and Beam both use float case DOUBLE: // Iceberg and Beam both use double case STRING: // Iceberg and Beam both use String - case BOOLEAN: // Iceberg and Beam both use String + case BOOLEAN: // Iceberg and Beam both use boolean case ARRAY: case ITERABLE: case MAP: rowBuilder.addValue(icebergValue); break; case DATETIME: - // Iceberg uses a long for millis; Beam uses joda time DateTime - long millis = (long) icebergValue; - rowBuilder.addValue(new DateTime(millis, DateTimeZone.UTC)); + // Iceberg uses a long for micros. + // Beam DATETIME uses joda's DateTime, which only supports millis, + // so we do lose some precision here + rowBuilder.addValue(getBeamDateTimeValue(icebergValue)); break; case BYTES: // Iceberg uses ByteBuffer; Beam uses byte[] @@ -369,8 +451,8 @@ public static Row icebergRecordToBeamRow(Schema schema, Record record) { rowBuilder.addValue(icebergRecordToBeamRow(nestedSchema, nestedRecord)); break; case LOGICAL_TYPE: - throw new UnsupportedOperationException( - "Cannot convert iceberg field to Beam logical type"); + rowBuilder.addValue(getLogicalTypeValue(icebergValue, field.getType())); + break; default: throw new UnsupportedOperationException( "Unsupported Beam type: " + field.getType().getTypeName()); @@ -378,4 +460,50 @@ public static Row icebergRecordToBeamRow(Schema schema, Record record) { } return rowBuilder.build(); } + + private static DateTime getBeamDateTimeValue(Object icebergValue) { + long micros; + if (icebergValue instanceof OffsetDateTime) { + micros = DateTimeUtil.microsFromTimestamptz((OffsetDateTime) icebergValue); + } else if (icebergValue instanceof LocalDateTime) { + micros = DateTimeUtil.microsFromTimestamp((LocalDateTime) icebergValue); + } else if (icebergValue instanceof Long) { + 
micros = (long) icebergValue; + } else if (icebergValue instanceof String) { + return DateTime.parse((String) icebergValue); + } else { + throw new UnsupportedOperationException( + "Unsupported Iceberg type for Beam type DATETIME: " + icebergValue.getClass()); + } + return new DateTime(micros / 1000L); + } + + private static Object getLogicalTypeValue(Object icebergValue, Schema.FieldType type) { + if (icebergValue instanceof String) { + String strValue = (String) icebergValue; + if (type.isLogicalType(SqlTypes.DATE.getIdentifier())) { + return LocalDate.parse(strValue); + } else if (type.isLogicalType(SqlTypes.TIME.getIdentifier())) { + return LocalTime.parse(strValue); + } else if (type.isLogicalType(SqlTypes.DATETIME.getIdentifier())) { + return LocalDateTime.parse(strValue); + } + } else if (icebergValue instanceof Long) { + if (type.isLogicalType(SqlTypes.TIME.getIdentifier())) { + return DateTimeUtil.timeFromMicros((Long) icebergValue); + } else if (type.isLogicalType(SqlTypes.DATETIME.getIdentifier())) { + return DateTimeUtil.timestampFromMicros((Long) icebergValue); + } + } else if (icebergValue instanceof Integer + && type.isLogicalType(SqlTypes.DATE.getIdentifier())) { + return DateTimeUtil.dateFromDays((Integer) icebergValue); + } else if (icebergValue instanceof OffsetDateTime + && type.isLogicalType(SqlTypes.DATETIME.getIdentifier())) { + return ((OffsetDateTime) icebergValue) + .withOffsetSameInstant(ZoneOffset.UTC) + .toLocalDateTime(); + } + // LocalDateTime, LocalDate, LocalTime + return icebergValue; + } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java index ea46e8560815..6aa830e7fbc6 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java +++ 
b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProvider.java @@ -18,13 +18,14 @@ package org.apache.beam.sdk.io.iceberg; import static org.apache.beam.sdk.io.iceberg.IcebergWriteSchemaTransformProvider.Configuration; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.List; import java.util.Map; -import org.apache.beam.sdk.managed.ManagedTransformConstants; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; @@ -151,7 +152,7 @@ public List outputCollectionNames() { @Override public String identifier() { - return ManagedTransformConstants.ICEBERG_WRITE; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_WRITE); } static class IcebergWriteSchemaTransform extends SchemaTransform { diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java index 92b5dd58b51e..7941c13b0dfe 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java @@ -19,7 +19,6 @@ import java.io.IOException; import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Distribution; import org.apache.beam.sdk.metrics.Metrics; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileFormat; @@ -38,9 +37,8 @@ class RecordWriter { private static final Logger LOG = LoggerFactory.getLogger(RecordWriter.class); private final Counter activeIcebergWriters = - Metrics.counter(RecordWriterManager.class, "activeIcebergWriters"); - private final Distribution 
dataFileByteSize = - Metrics.distribution(RecordWriter.class, "dataFileByteSize"); + Metrics.counter(RecordWriter.class, "activeIcebergWriters"); + private final Counter dataFilesWritten = Metrics.counter(RecordWriter.class, "dataFilesWritten"); private final DataWriter icebergDataWriter; private final Table table; private final String absoluteFilename; @@ -128,7 +126,7 @@ public void close() throws IOException { dataFile.recordCount(), dataFile.fileSizeInBytes(), absoluteFilename); - dataFileByteSize.update(dataFile.fileSizeInBytes()); + dataFilesWritten.inc(); } public long bytesWritten() { @@ -138,4 +136,8 @@ public long bytesWritten() { public DataFile getDataFile() { return icebergDataWriter.toDataFile(); } + + public String path() { + return absoluteFilename; + } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java index 5979e2a60131..255fce9ece4e 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java @@ -21,13 +21,12 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.Row; @@ -38,17 +37,18 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.iceberg.DataFile; -import org.apache.iceberg.FileFormat; import org.apache.iceberg.ManifestFile; -import org.apache.iceberg.ManifestFiles; -import org.apache.iceberg.ManifestWriter; import org.apache.iceberg.PartitionKey; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.data.Record; -import org.apache.iceberg.io.FileIO; -import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A writer that manages multiple {@link RecordWriter}s to write to multiple tables and partitions. @@ -66,19 +66,14 @@ * *
      *
    1. Close all underlying {@link RecordWriter}s - *
    2. Collect all {@link DataFile}s - *
    3. Create a new {@link ManifestFile} referencing these {@link DataFile}s + *
    4. Collect all {@link DataFile}s as {@link SerializableDataFile}s (a more Beam-friendly type) *
    * - *

    After closing, the resulting {@link ManifestFile}s can be retrieved using {@link - * #getManifestFiles()}. + *

    After closing, the resulting {@link SerializableDataFile}s can be retrieved using {@link + * #getSerializableDataFiles()}. */ class RecordWriterManager implements AutoCloseable { - private final Counter dataFilesWritten = - Metrics.counter(RecordWriterManager.class, "dataFilesWritten"); - private final Counter manifestFilesWritten = - Metrics.counter(RecordWriterManager.class, "manifestFilesWritten"); - + private static final Logger LOG = LoggerFactory.getLogger(RecordWriterManager.class); /** * Represents the state of one Iceberg table destination. Creates one {@link RecordWriter} per * partition and manages them in a {@link Cache}. @@ -90,21 +85,18 @@ class DestinationState { private final PartitionSpec spec; private final org.apache.iceberg.Schema schema; private final PartitionKey partitionKey; - private final String tableLocation; - private final FileIO fileIO; private final Table table; private final String stateToken = UUID.randomUUID().toString(); - private final List dataFiles = Lists.newArrayList(); - @VisibleForTesting final Cache writers; + final Cache writers; + private final List dataFiles = Lists.newArrayList(); @VisibleForTesting final Map writerCounts = Maps.newHashMap(); + private final List exceptions = Lists.newArrayList(); DestinationState(IcebergDestination icebergDestination, Table table) { this.icebergDestination = icebergDestination; this.schema = table.schema(); this.spec = table.spec(); this.partitionKey = new PartitionKey(spec, schema); - this.tableLocation = table.location(); - this.fileIO = table.io(); this.table = table; // build a cache of RecordWriters. 
@@ -121,15 +113,17 @@ class DestinationState { try { recordWriter.close(); } catch (IOException e) { - throw new RuntimeException( - String.format( - "Encountered an error when closing data writer for table '%s', partition %s", - icebergDestination.getTableIdentifier(), pk), - e); + RuntimeException rethrow = + new RuntimeException( + String.format( + "Encountered an error when closing data writer for table '%s', path: %s", + icebergDestination.getTableIdentifier(), recordWriter.path()), + e); + exceptions.add(rethrow); + throw rethrow; } openWriters--; - dataFiles.add(recordWriter.getDataFile()); - dataFilesWritten.inc(); + dataFiles.add(SerializableDataFile.from(recordWriter.getDataFile(), pk)); }) .build(); } @@ -191,13 +185,6 @@ private RecordWriter createWriter(PartitionKey partitionKey) { e); } } - - private String getManifestFileLocation(PaneInfo paneInfo) { - return FileFormat.AVRO.addExtension( - String.format( - "%s/metadata/%s-%s-%s.manifest", - tableLocation, filePrefix, stateToken, paneInfo.getIndex())); - } } private final Catalog catalog; @@ -209,8 +196,12 @@ private String getManifestFileLocation(PaneInfo paneInfo) { @VisibleForTesting final Map, DestinationState> destinations = Maps.newHashMap(); - private final Map, List> totalManifestFiles = - Maps.newHashMap(); + private final Map, List> + totalSerializableDataFiles = Maps.newHashMap(); + + @VisibleForTesting + static final Cache TABLE_CACHE = + CacheBuilder.newBuilder().expireAfterAccess(10, TimeUnit.MINUTES).build(); private boolean isClosed = false; @@ -221,6 +212,46 @@ private String getManifestFileLocation(PaneInfo paneInfo) { this.maxNumWriters = maxNumWriters; } + /** + * Returns an Iceberg {@link Table}. + * + *

    First attempts to fetch the table from the {@link #TABLE_CACHE}. If it's not there, we + * attempt to load it using the Iceberg API. If the table doesn't exist at all, we attempt to + * create it, inferring the table schema from the record schema. + * + *

    Note that this is a best-effort operation that depends on the {@link Catalog} + * implementation. Although it is expected, some implementations may not support creating a table + * using the Iceberg API. + */ + private Table getOrCreateTable(TableIdentifier identifier, Schema dataSchema) { + @Nullable Table table = TABLE_CACHE.getIfPresent(identifier); + if (table == null) { + synchronized (TABLE_CACHE) { + try { + table = catalog.loadTable(identifier); + } catch (NoSuchTableException e) { + try { + org.apache.iceberg.Schema tableSchema = + IcebergUtils.beamSchemaToIcebergSchema(dataSchema); + // TODO(ahmedabu98): support creating a table with a specified partition spec + table = catalog.createTable(identifier, tableSchema); + LOG.info("Created Iceberg table '{}' with schema: {}", identifier, tableSchema); + } catch (AlreadyExistsException alreadyExistsException) { + // handle race condition where workers are concurrently creating the same table. + // if running into already exists exception, we perform one last load + table = catalog.loadTable(identifier); + } + } + TABLE_CACHE.put(identifier, table); + } + } else { + // If fetching from cache, refresh the table to avoid working with stale metadata + // (e.g. partition spec) + table.refresh(); + } + return table; + } + /** * Fetches the appropriate {@link RecordWriter} for this destination and partition and writes the * record. 
@@ -233,7 +264,8 @@ public boolean write(WindowedValue icebergDestination, Row r destinations.computeIfAbsent( icebergDestination, destination -> { - Table table = catalog.loadTable(destination.getValue().getTableIdentifier()); + TableIdentifier identifier = destination.getValue().getTableIdentifier(); + Table table = getOrCreateTable(identifier, row.getSchema()); return new DestinationState(destination.getValue(), table); }); @@ -249,31 +281,28 @@ public boolean write(WindowedValue icebergDestination, Row r public void close() throws IOException { for (Map.Entry, DestinationState> windowedDestinationAndState : destinations.entrySet()) { - WindowedValue windowedDestination = windowedDestinationAndState.getKey(); DestinationState state = windowedDestinationAndState.getValue(); // removing writers from the state's cache will trigger the logic to collect each writer's // data file. state.writers.invalidateAll(); - if (state.dataFiles.isEmpty()) { - continue; + // first check for any exceptions swallowed by the cache + if (!state.exceptions.isEmpty()) { + IllegalStateException exception = + new IllegalStateException( + String.format("Encountered %s failed writer(s).", state.exceptions.size())); + for (Exception e : state.exceptions) { + exception.addSuppressed(e); + } + throw exception; } - OutputFile outputFile = - state.fileIO.newOutputFile(state.getManifestFileLocation(windowedDestination.getPane())); - - ManifestWriter manifestWriter; - try (ManifestWriter openWriter = ManifestFiles.write(state.spec, outputFile)) { - openWriter.addAll(state.dataFiles); - manifestWriter = openWriter; + if (state.dataFiles.isEmpty()) { + continue; } - ManifestFile manifestFile = manifestWriter.toManifestFile(); - manifestFilesWritten.inc(); - - totalManifestFiles - .computeIfAbsent(windowedDestination, dest -> Lists.newArrayList()) - .add(manifestFile); + totalSerializableDataFiles.put( + windowedDestinationAndState.getKey(), new ArrayList<>(state.dataFiles)); 
state.dataFiles.clear(); } destinations.clear(); @@ -285,15 +314,16 @@ public void close() throws IOException { } /** - * Returns a list of accumulated windowed {@link ManifestFile}s for each windowed {@link + * Returns a list of accumulated serializable {@link DataFile}s for each windowed {@link * IcebergDestination}. The {@link RecordWriterManager} must first be closed before this is * called. */ - public Map, List> getManifestFiles() { + public Map, List> + getSerializableDataFiles() { checkState( isClosed, - "Please close this %s before retrieving its manifest files.", + "Please close this %s before retrieving its data files.", getClass().getSimpleName()); - return totalManifestFiles; + return totalSerializableDataFiles; } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java new file mode 100644 index 000000000000..59b456162008 --- /dev/null +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.iceberg; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.google.auto.value.AutoValue; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.Metrics; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Serializable version of an Iceberg {@link DataFile}. + * + *

    {@link DataFile} is not serializable and the Iceberg API doesn't offer an easy way to + * encode/decode it. This class is an identical version that can be used as a PCollection element + * type. + * + *

    Use {@link #from(DataFile, PartitionKey)} to create a {@link SerializableDataFile} and {@link + * #createDataFile(PartitionSpec)} to reconstruct the original {@link DataFile}. + */ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +abstract class SerializableDataFile { + public static Builder builder() { + return new AutoValue_SerializableDataFile.Builder(); + } + + abstract String getPath(); + + abstract String getFileFormat(); + + abstract long getRecordCount(); + + abstract long getFileSizeInBytes(); + + abstract String getPartitionPath(); + + abstract int getPartitionSpecId(); + + abstract @Nullable ByteBuffer getKeyMetadata(); + + abstract @Nullable List getSplitOffsets(); + + abstract @Nullable Map getColumnSizes(); + + abstract @Nullable Map getValueCounts(); + + abstract @Nullable Map getNullValueCounts(); + + abstract @Nullable Map getNanValueCounts(); + + abstract @Nullable Map getLowerBounds(); + + abstract @Nullable Map getUpperBounds(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setPath(String path); + + abstract Builder setFileFormat(String fileFormat); + + abstract Builder setRecordCount(long recordCount); + + abstract Builder setFileSizeInBytes(long fileSizeInBytes); + + abstract Builder setPartitionPath(String partitionPath); + + abstract Builder setPartitionSpecId(int partitionSpec); + + abstract Builder setKeyMetadata(ByteBuffer keyMetadata); + + abstract Builder setSplitOffsets(List splitOffsets); + + abstract Builder setColumnSizes(Map columnSizes); + + abstract Builder setValueCounts(Map valueCounts); + + abstract Builder setNullValueCounts(Map nullValueCounts); + + abstract Builder setNanValueCounts(Map nanValueCounts); + + abstract Builder setLowerBounds(@Nullable Map lowerBounds); + + abstract Builder setUpperBounds(@Nullable Map upperBounds); + + abstract SerializableDataFile build(); + } + + /** + * Create a {@link SerializableDataFile} from a {@link DataFile} and its associated {@link + * 
PartitionKey}. + */ + static SerializableDataFile from(DataFile f, PartitionKey key) { + return SerializableDataFile.builder() + .setPath(f.path().toString()) + .setFileFormat(f.format().toString()) + .setRecordCount(f.recordCount()) + .setFileSizeInBytes(f.fileSizeInBytes()) + .setPartitionPath(key.toPath()) + .setPartitionSpecId(f.specId()) + .setKeyMetadata(f.keyMetadata()) + .setSplitOffsets(f.splitOffsets()) + .setColumnSizes(f.columnSizes()) + .setValueCounts(f.valueCounts()) + .setNullValueCounts(f.nullValueCounts()) + .setNanValueCounts(f.nanValueCounts()) + .setLowerBounds(toByteArrayMap(f.lowerBounds())) + .setUpperBounds(toByteArrayMap(f.upperBounds())) + .build(); + } + + /** + * Reconstructs the original {@link DataFile} from this {@link SerializableDataFile}. + * + *

    We require an input {@link PartitionSpec} as well because there's no easy way to reconstruct + * it from Beam-compatible types. + */ + @SuppressWarnings("nullness") + DataFile createDataFile(Map partitionSpecs) { + PartitionSpec partitionSpec = + checkStateNotNull( + partitionSpecs.get(getPartitionSpecId()), + "This DataFile was originally created with spec id '%s'. Could not find " + + "this among table's partition specs: %s.", + getPartitionSpecId(), + partitionSpecs.keySet()); + + Metrics dataFileMetrics = + new Metrics( + getRecordCount(), + getColumnSizes(), + getValueCounts(), + getNullValueCounts(), + getNanValueCounts(), + toByteBufferMap(getLowerBounds()), + toByteBufferMap(getUpperBounds())); + + return DataFiles.builder(partitionSpec) + .withFormat(FileFormat.fromString(getFileFormat())) + .withPath(getPath()) + .withPartitionPath(getPartitionPath()) + .withEncryptionKeyMetadata(getKeyMetadata()) + .withFileSizeInBytes(getFileSizeInBytes()) + .withMetrics(dataFileMetrics) + .withSplitOffsets(getSplitOffsets()) + .build(); + } + + // ByteBuddyUtils has trouble converting Map value type ByteBuffer + // to byte[] and back to ByteBuffer, so we perform these conversions manually + // TODO(https://github.com/apache/beam/issues/32701) + private static @Nullable Map toByteArrayMap( + @Nullable Map input) { + if (input == null) { + return null; + } + Map output = new HashMap<>(input.size()); + for (Map.Entry e : input.entrySet()) { + output.put(e.getKey(), e.getValue().array()); + } + return output; + } + + private static @Nullable Map toByteBufferMap( + @Nullable Map input) { + if (input == null) { + return null; + } + Map output = new HashMap<>(input.size()); + for (Map.Entry e : input.entrySet()) { + output.put(e.getKey(), ByteBuffer.wrap(e.getValue())); + } + return output; + } +} diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java 
b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java index 1926a769a6da..6a61aafbe8b9 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteGroupedRowsToFiles.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io.iceberg; import java.util.List; -import java.util.UUID; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -30,7 +29,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.iceberg.ManifestFile; import org.apache.iceberg.catalog.Catalog; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; @@ -42,11 +40,15 @@ class WriteGroupedRowsToFiles private final DynamicDestinations dynamicDestinations; private final IcebergCatalogConfig catalogConfig; + private final String filePrefix; WriteGroupedRowsToFiles( - IcebergCatalogConfig catalogConfig, DynamicDestinations dynamicDestinations) { + IcebergCatalogConfig catalogConfig, + DynamicDestinations dynamicDestinations, + String filePrefix) { this.catalogConfig = catalogConfig; this.dynamicDestinations = dynamicDestinations; + this.filePrefix = filePrefix; } @Override @@ -55,7 +57,7 @@ public PCollection expand( return input.apply( ParDo.of( new WriteGroupedRowsToFilesDoFn( - catalogConfig, dynamicDestinations, DEFAULT_MAX_BYTES_PER_FILE))); + catalogConfig, dynamicDestinations, DEFAULT_MAX_BYTES_PER_FILE, filePrefix))); } private static class WriteGroupedRowsToFilesDoFn @@ -70,10 +72,11 @@ private static class WriteGroupedRowsToFilesDoFn WriteGroupedRowsToFilesDoFn( IcebergCatalogConfig catalogConfig, DynamicDestinations dynamicDestinations, - long maxFileSize) { + long maxFileSize, + String filePrefix) 
{ this.catalogConfig = catalogConfig; this.dynamicDestinations = dynamicDestinations; - this.filePrefix = UUID.randomUUID().toString(); + this.filePrefix = filePrefix; this.maxFileSize = maxFileSize; } @@ -105,13 +108,13 @@ public void processElement( } } - List manifestFiles = - Preconditions.checkNotNull(writer.getManifestFiles().get(windowedDestination)); - for (ManifestFile manifestFile : manifestFiles) { + List serializableDataFiles = + Preconditions.checkNotNull(writer.getSerializableDataFiles().get(windowedDestination)); + for (SerializableDataFile dataFile : serializableDataFiles) { c.output( FileWriteResult.builder() .setTableIdentifier(destination.getTableIdentifier()) - .setManifestFile(manifestFile) + .setSerializableDataFile(dataFile) .build()); } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java index 4d03f3a3bc58..fb3bf43f3515 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java @@ -19,6 +19,7 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import java.util.UUID; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.RowCoder; @@ -50,6 +51,7 @@ class WriteToDestinations extends PTransform>, Icebe private final IcebergCatalogConfig catalogConfig; private final DynamicDestinations dynamicDestinations; private final @Nullable Duration triggeringFrequency; + private final String filePrefix; WriteToDestinations( IcebergCatalogConfig catalogConfig, @@ -58,6 +60,8 @@ class WriteToDestinations extends PTransform>, Icebe this.dynamicDestinations = dynamicDestinations; this.catalogConfig = catalogConfig; this.triggeringFrequency = triggeringFrequency; + // single 
unique prefix per write transform + this.filePrefix = UUID.randomUUID().toString(); } @Override @@ -70,7 +74,7 @@ public IcebergWriteResult expand(PCollection> input) { // Commit files to tables PCollection> snapshots = - writtenFiles.apply(new AppendFilesToTables(catalogConfig)); + writtenFiles.apply(new AppendFilesToTables(catalogConfig, filePrefix)); return new IcebergWriteResult(input.getPipeline(), snapshots); } @@ -97,7 +101,9 @@ private PCollection writeTriggered(PCollection> IterableCoder.of(RowCoder.of(dynamicDestinations.getDataSchema())))); return groupedRecords - .apply("WriteGroupedRows", new WriteGroupedRowsToFiles(catalogConfig, dynamicDestinations)) + .apply( + "WriteGroupedRows", + new WriteGroupedRowsToFiles(catalogConfig, dynamicDestinations, filePrefix)) // Respect user's triggering frequency before committing snapshots .apply( "ApplyUserTrigger", @@ -120,7 +126,7 @@ private PCollection writeUntriggered(PCollection writeGroupedResult = @@ -129,7 +135,7 @@ private PCollection writeUntriggered(PCollection, List> destinationAndFiles : - Preconditions.checkNotNull(recordWriterManager).getManifestFiles().entrySet()) { + + for (Map.Entry, List> + destinationAndFiles : + Preconditions.checkNotNull(recordWriterManager) + .getSerializableDataFiles() + .entrySet()) { WindowedValue windowedDestination = destinationAndFiles.getKey(); - for (ManifestFile manifestFile : destinationAndFiles.getValue()) { + for (SerializableDataFile dataFile : destinationAndFiles.getValue()) { c.output( FileWriteResult.builder() - .setManifestFile(manifestFile) + .setSerializableDataFile(dataFile) .setTableIdentifier(windowedDestination.getValue().getTableIdentifier()) .build(), windowedDestination.getTimestamp(), diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java index 84f2146275f0..c79b0a550051 100644 --- 
a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java @@ -17,11 +17,14 @@ */ package org.apache.beam.sdk.io.iceberg; +import static org.apache.beam.sdk.schemas.Schema.FieldType; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertTrue; +import com.google.api.services.storage.model.Objects; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; @@ -34,7 +37,11 @@ import java.util.stream.LongStream; import java.util.stream.Stream; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.extensions.gcp.options.GcsOptions; +import org.apache.beam.sdk.extensions.gcp.util.GcsUtil; +import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; @@ -71,9 +78,13 @@ import org.apache.iceberg.io.InputFile; import org.apache.iceberg.io.OutputFile; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.util.DateTimeUtil; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; @@ -81,10 +92,14 @@ import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Integration tests for {@link IcebergIO} source 
and sink. */ @RunWith(JUnit4.class) public class IcebergIOIT implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(IcebergIOIT.class); + private static final org.apache.beam.sdk.schemas.Schema DOUBLY_NESTED_ROW_SCHEMA = org.apache.beam.sdk.schemas.Schema.builder() .addStringField("doubly_nested_str") @@ -106,9 +121,13 @@ public class IcebergIOIT implements Serializable { .addBooleanField("bool") .addInt32Field("int") .addRowField("row", NESTED_ROW_SCHEMA) - .addArrayField("arr_long", org.apache.beam.sdk.schemas.Schema.FieldType.INT64) + .addArrayField("arr_long", FieldType.INT64) .addNullableRowField("nullable_row", NESTED_ROW_SCHEMA) .addNullableInt64Field("nullable_long") + .addDateTimeField("datetime_tz") + .addLogicalTypeField("datetime", SqlTypes.DATETIME) + .addLogicalTypeField("date", SqlTypes.DATE) + .addLogicalTypeField("time", SqlTypes.TIME) .build(); private static final SimpleFunction ROW_FUNC = @@ -138,6 +157,10 @@ public Row apply(Long num) { .addValue(LongStream.range(0, num % 10).boxed().collect(Collectors.toList())) .addValue(num % 2 == 0 ? 
null : nestedRow) .addValue(num) + .addValue(new DateTime(num).withZone(DateTimeZone.forOffsetHoursMinutes(3, 25))) + .addValue(DateTimeUtil.timestampFromMicros(num)) + .addValue(DateTimeUtil.dateFromDays(Integer.parseInt(strNum))) + .addValue(DateTimeUtil.timeFromMicros(num)) .build(); } }; @@ -162,27 +185,45 @@ public Record apply(Row input) { @Rule public TestName testName = new TestName(); - private String warehouseLocation; + private static String warehouseLocation; private String tableId; - private Catalog catalog; + private static Catalog catalog; @BeforeClass public static void beforeClass() { options = TestPipeline.testingPipelineOptions().as(GcpOptions.class); - + warehouseLocation = + String.format("%s/IcebergIOIT/%s", options.getTempLocation(), UUID.randomUUID()); catalogHadoopConf = new Configuration(); catalogHadoopConf.set("fs.gs.project.id", options.getProject()); catalogHadoopConf.set("fs.gs.auth.type", "APPLICATION_DEFAULT"); + catalog = new HadoopCatalog(catalogHadoopConf, warehouseLocation); } @Before public void setUp() { - warehouseLocation = - String.format("%s/IcebergIOIT/%s", options.getTempLocation(), UUID.randomUUID()); - tableId = testName.getMethodName() + ".test_table"; - catalog = new HadoopCatalog(catalogHadoopConf, warehouseLocation); + } + + @AfterClass + public static void afterClass() { + try { + GcsUtil gcsUtil = options.as(GcsOptions.class).getGcsUtil(); + GcsPath path = GcsPath.fromUri(warehouseLocation); + + Objects objects = + gcsUtil.listObjects( + path.getBucket(), "IcebergIOIT/" + path.getFileName().toString(), null); + List filesToDelete = + objects.getItems().stream() + .map(obj -> "gs://" + path.getBucket() + "/" + obj.getName()) + .collect(Collectors.toList()); + + gcsUtil.remove(filesToDelete); + } catch (Exception e) { + LOG.warn("Failed to clean up files.", e); + } } /** Populates the Iceberg table and Returns a {@link List} of expected elements. 
*/ @@ -290,14 +331,16 @@ public void testRead() throws Exception { */ @Test public void testWrite() { - Table table = catalog.createTable(TableIdentifier.parse(tableId), ICEBERG_SCHEMA); - // Write with Beam + // Expect the sink to create the table Map config = managedIcebergConfig(tableId); PCollection input = pipeline.apply(Create.of(INPUT_ROWS)).setRowSchema(BEAM_SCHEMA); input.apply(Managed.write(Managed.ICEBERG).withConfig(config)); pipeline.run().waitUntilFinish(); + Table table = catalog.loadTable(TableIdentifier.parse(tableId)); + assertTrue(table.schema().sameSchema(ICEBERG_SCHEMA)); + // Read back and check records are correct List returnedRecords = readRecords(table); assertThat( @@ -434,22 +477,23 @@ private void writeToDynamicDestinations( Schema tableSchema = IcebergUtils.beamSchemaToIcebergSchema(rowFilter.outputSchema()); - PartitionSpec partitionSpec = null; + TableIdentifier tableIdentifier0 = TableIdentifier.parse(tableId + "_0_a"); + TableIdentifier tableIdentifier1 = TableIdentifier.parse(tableId + "_1_b"); + TableIdentifier tableIdentifier2 = TableIdentifier.parse(tableId + "_2_c"); + TableIdentifier tableIdentifier3 = TableIdentifier.parse(tableId + "_3_d"); + TableIdentifier tableIdentifier4 = TableIdentifier.parse(tableId + "_4_e"); + // the sink doesn't support creating partitioned tables yet, + // so we need to create it manually for this test case if (partitioning) { Preconditions.checkState(filterOp == null || !filterOp.equals("only")); - partitionSpec = + PartitionSpec partitionSpec = PartitionSpec.builderFor(tableSchema).identity("bool").identity("modulo_5").build(); + catalog.createTable(tableIdentifier0, tableSchema, partitionSpec); + catalog.createTable(tableIdentifier1, tableSchema, partitionSpec); + catalog.createTable(tableIdentifier2, tableSchema, partitionSpec); + catalog.createTable(tableIdentifier3, tableSchema, partitionSpec); + catalog.createTable(tableIdentifier4, tableSchema, partitionSpec); } - Table table0 = - 
catalog.createTable(TableIdentifier.parse(tableId + "_0_a"), tableSchema, partitionSpec); - Table table1 = - catalog.createTable(TableIdentifier.parse(tableId + "_1_b"), tableSchema, partitionSpec); - Table table2 = - catalog.createTable(TableIdentifier.parse(tableId + "_2_c"), tableSchema, partitionSpec); - Table table3 = - catalog.createTable(TableIdentifier.parse(tableId + "_3_d"), tableSchema, partitionSpec); - Table table4 = - catalog.createTable(TableIdentifier.parse(tableId + "_4_e"), tableSchema, partitionSpec); // Write with Beam PCollection input; @@ -467,6 +511,16 @@ private void writeToDynamicDestinations( input.setRowSchema(BEAM_SCHEMA).apply(Managed.write(Managed.ICEBERG).withConfig(writeConfig)); pipeline.run().waitUntilFinish(); + Table table0 = catalog.loadTable(tableIdentifier0); + Table table1 = catalog.loadTable(tableIdentifier1); + Table table2 = catalog.loadTable(tableIdentifier2); + Table table3 = catalog.loadTable(tableIdentifier3); + Table table4 = catalog.loadTable(tableIdentifier4); + + for (Table t : Arrays.asList(table0, table1, table2, table3, table4)) { + assertTrue(t.schema().sameSchema(tableSchema)); + } + // Read back and check records are correct List> returnedRecords = Arrays.asList( diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java index e62c22be7968..87a543a439ec 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOWriteTest.java @@ -80,9 +80,6 @@ public void testSimpleAppend() throws Exception { TableIdentifier tableId = TableIdentifier.of("default", "table" + Long.toString(UUID.randomUUID().hashCode(), 16)); - // Create a table and add records to it. 
- Table table = warehouse.createTable(tableId, TestFixtures.SCHEMA); - Map catalogProps = ImmutableMap.builder() .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) @@ -104,6 +101,7 @@ public void testSimpleAppend() throws Exception { testPipeline.run().waitUntilFinish(); LOG.info("Done running pipeline"); + Table table = warehouse.loadTable(tableId); List writtenRecords = ImmutableList.copyOf(IcebergGenerics.read(table).build()); assertThat(writtenRecords, Matchers.containsInAnyOrder(TestFixtures.FILE1SNAPSHOT1.toArray())); @@ -117,11 +115,6 @@ public void testDynamicDestinationsWithoutSpillover() throws Exception { final TableIdentifier table2Id = TableIdentifier.of("default", "table2-" + salt); final TableIdentifier table3Id = TableIdentifier.of("default", "table3-" + salt); - // Create a table and add records to it. - Table table1 = warehouse.createTable(table1Id, TestFixtures.SCHEMA); - Table table2 = warehouse.createTable(table2Id, TestFixtures.SCHEMA); - Table table3 = warehouse.createTable(table3Id, TestFixtures.SCHEMA); - Map catalogProps = ImmutableMap.builder() .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) @@ -177,6 +170,10 @@ public IcebergDestination instantiateDestination(String dest) { testPipeline.run().waitUntilFinish(); LOG.info("Done running pipeline"); + Table table1 = warehouse.loadTable(table1Id); + Table table2 = warehouse.loadTable(table2Id); + Table table3 = warehouse.loadTable(table3Id); + List writtenRecords1 = ImmutableList.copyOf(IcebergGenerics.read(table1).build()); List writtenRecords2 = ImmutableList.copyOf(IcebergGenerics.read(table2).build()); List writtenRecords3 = ImmutableList.copyOf(IcebergGenerics.read(table3).build()); @@ -320,9 +317,6 @@ public void testStreamingWrite() { TableIdentifier.of( "default", "streaming_" + Long.toString(UUID.randomUUID().hashCode(), 16)); - // Create a table and add records to it. 
- Table table = warehouse.createTable(tableId, TestFixtures.SCHEMA); - Map catalogProps = ImmutableMap.builder() .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) @@ -365,6 +359,8 @@ public void testStreamingWrite() { PAssert.that(snapshots).containsInAnyOrder(2L); testPipeline.run().waitUntilFinish(); + Table table = warehouse.loadTable(tableId); + List writtenRecords = ImmutableList.copyOf(IcebergGenerics.read(table).build()); assertThat(writtenRecords, Matchers.containsInAnyOrder(TestFixtures.FILE1SNAPSHOT1.toArray())); } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java index a20d5b7c8f59..134f05c34bfb 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergUtilsTest.java @@ -28,16 +28,21 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; import java.util.Arrays; import java.util.List; import java.util.Map; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.Record; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DateTimeUtil; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.junit.Test; @@ -45,6 +50,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +/** Test class for {@link IcebergUtils}. 
*/ @RunWith(Enclosed.class) public class IcebergUtilsTest { @@ -102,21 +108,87 @@ public void testDouble() { } @Test - public void testDate() {} + public void testDate() { + checkRowValueToRecordValue( + Schema.FieldType.logicalType(SqlTypes.DATE), + Types.DateType.get(), + DateTimeUtil.dateFromDays(12345)); + } @Test - public void testTime() {} + public void testTime() { + checkRowValueToRecordValue( + Schema.FieldType.logicalType(SqlTypes.TIME), + Types.TimeType.get(), + DateTimeUtil.timeFromMicros(12345678L)); + } @Test public void testTimestamp() { + // SqlTypes.DATETIME + checkRowValueToRecordValue( + Schema.FieldType.logicalType(SqlTypes.DATETIME), + Types.TimestampType.withoutZone(), + DateTimeUtil.timestampFromMicros(123456789L)); + + // Schema.FieldType.DATETIME DateTime dateTime = new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC); - checkRowValueToRecordValue( Schema.FieldType.DATETIME, - dateTime.toInstant(), + dateTime, + Types.TimestampType.withoutZone(), + DateTimeUtil.timestampFromMicros(dateTime.getMillis() * 1000L)); + + // Schema.FieldType.INT64 + long micros = 1234567890L; + checkRowValueToRecordValue( + Schema.FieldType.INT64, + micros, Types.TimestampType.withoutZone(), - dateTime.getMillis()); + DateTimeUtil.timestampFromMicros(micros)); + + // Schema.FieldType.STRING + String val = "2024-10-08T13:18:20.053"; + LocalDateTime localDateTime = LocalDateTime.of(2024, 10, 8, 13, 18, 20, 53_000_000); + checkRowValueToRecordValue( + Schema.FieldType.STRING, val, Types.TimestampType.withoutZone(), localDateTime); + } + + @Test + public void testTimestampWithZone() { + String val = "2024-10-08T13:18:20.053+03:27"; + DateTime dateTime = DateTime.parse(val); + OffsetDateTime offsetDateTime = OffsetDateTime.parse(val); + LocalDateTime localDateTime = + offsetDateTime.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime(); + // SqlTypes.DATETIME + checkRowValueToRecordValue( + 
Schema.FieldType.logicalType(SqlTypes.DATETIME), + localDateTime, + Types.TimestampType.withZone(), + offsetDateTime.withOffsetSameInstant(ZoneOffset.UTC)); + + // Schema.FieldType.DATETIME + checkRowValueToRecordValue( + Schema.FieldType.DATETIME, + dateTime, + Types.TimestampType.withZone(), + offsetDateTime.withOffsetSameInstant(ZoneOffset.UTC)); + + // Schema.FieldType.INT64 + checkRowValueToRecordValue( + Schema.FieldType.INT64, + DateTimeUtil.microsFromTimestamptz(offsetDateTime), + Types.TimestampType.withZone(), + offsetDateTime.withOffsetSameInstant(ZoneOffset.UTC)); + + // Schema.FieldType.STRING + checkRowValueToRecordValue( + Schema.FieldType.STRING, + val, + Types.TimestampType.withZone(), + offsetDateTime.withOffsetSameInstant(ZoneOffset.UTC)); } @Test @@ -190,7 +262,7 @@ private void checkRecordValueToRowValue( Row row = IcebergUtils.icebergRecordToBeamRow(beamSchema, record); - assertThat(row.getBaseValue("v"), equalTo(destValue)); + assertThat(row.getValue("v"), equalTo(destValue)); } @Test @@ -224,21 +296,75 @@ public void testDouble() { } @Test - public void testDate() {} + public void testDate() { + checkRecordValueToRowValue( + Types.DateType.get(), + Schema.FieldType.logicalType(SqlTypes.DATE), + DateTimeUtil.dateFromDays(12345)); + } @Test - public void testTime() {} + public void testTime() { + checkRecordValueToRowValue( + Types.TimeType.get(), + Schema.FieldType.logicalType(SqlTypes.TIME), + DateTimeUtil.timeFromMicros(1234567L)); + } @Test public void testTimestamp() { + // SqlTypes.DATETIME + checkRecordValueToRowValue( + Types.TimestampType.withoutZone(), + Schema.FieldType.logicalType(SqlTypes.DATETIME), + DateTimeUtil.timestampFromMicros(123456789L)); + + // Schema.FieldType.DATETIME DateTime dateTime = new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC); - checkRecordValueToRowValue( Types.TimestampType.withoutZone(), - dateTime.getMillis(), + dateTime.getMillis() * 1000L, + 
Schema.FieldType.DATETIME, + dateTime); + } + + @Test + public void testTimestampWithZone() { + String timestamp = "2024-10-08T13:18:20.053+03:27"; + OffsetDateTime offsetDateTime = OffsetDateTime.parse(timestamp); + LocalDateTime localDateTime = + offsetDateTime.withOffsetSameInstant(ZoneOffset.UTC).toLocalDateTime(); + // SqlTypes.DATETIME + checkRecordValueToRowValue( + Types.TimestampType.withZone(), + offsetDateTime, + Schema.FieldType.logicalType(SqlTypes.DATETIME), + localDateTime); + checkRecordValueToRowValue( + Types.TimestampType.withZone(), + localDateTime, + Schema.FieldType.logicalType(SqlTypes.DATETIME), + localDateTime); + checkRecordValueToRowValue( + Types.TimestampType.withZone(), + DateTimeUtil.microsFromTimestamptz(offsetDateTime), + Schema.FieldType.logicalType(SqlTypes.DATETIME), + localDateTime); + + // Schema.FieldType.DATETIME + DateTime dateTime = DateTime.parse(timestamp).withZone(DateTimeZone.UTC); + checkRecordValueToRowValue( + Types.TimestampType.withZone(), offsetDateTime, Schema.FieldType.DATETIME, dateTime); + checkRecordValueToRowValue( + Types.TimestampType.withZone(), localDateTime, Schema.FieldType.DATETIME, dateTime); + checkRecordValueToRowValue( + Types.TimestampType.withZone(), + DateTimeUtil.microsFromTimestamptz(offsetDateTime), Schema.FieldType.DATETIME, - dateTime.toInstant()); + dateTime); + checkRecordValueToRowValue( + Types.TimestampType.withZone(), timestamp, Schema.FieldType.DATETIME, dateTime); } @Test @@ -425,7 +551,7 @@ public void testStructBeamFieldTypeToIcebergFieldType() { new BeamFieldTypeTestCase( 1, Schema.FieldType.row(BEAM_SCHEMA_PRIMITIVE), - 7, + 11, Types.StructType.of(ICEBERG_SCHEMA_PRIMITIVE.columns())), new BeamFieldTypeTestCase( 15, @@ -537,6 +663,10 @@ public void testMapBeamFieldTypeToIcebergFieldType() { .addNullableStringField("str") .addNullableBooleanField("bool") .addByteArrayField("bytes") + .addDateTimeField("datetime_tz") + .addLogicalTypeField("datetime", SqlTypes.DATETIME) + 
.addLogicalTypeField("time", SqlTypes.TIME) + .addLogicalTypeField("date", SqlTypes.DATE) .build(); static final org.apache.iceberg.Schema ICEBERG_SCHEMA_PRIMITIVE = @@ -547,16 +677,17 @@ public void testMapBeamFieldTypeToIcebergFieldType() { required(4, "long", Types.LongType.get()), optional(5, "str", Types.StringType.get()), optional(6, "bool", Types.BooleanType.get()), - required(7, "bytes", Types.BinaryType.get())); + required(7, "bytes", Types.BinaryType.get()), + required(8, "datetime_tz", Types.TimestampType.withZone()), + required(9, "datetime", Types.TimestampType.withoutZone()), + required(10, "time", Types.TimeType.get()), + required(11, "date", Types.DateType.get())); @Test public void testPrimitiveBeamSchemaToIcebergSchema() { org.apache.iceberg.Schema convertedIcebergSchema = IcebergUtils.beamSchemaToIcebergSchema(BEAM_SCHEMA_PRIMITIVE); - System.out.println(convertedIcebergSchema); - System.out.println(ICEBERG_SCHEMA_PRIMITIVE); - assertTrue(convertedIcebergSchema.sameSchema(ICEBERG_SCHEMA_PRIMITIVE)); } @@ -591,9 +722,6 @@ public void testArrayBeamSchemaToIcebergSchema() { public void testArrayIcebergSchemaToBeamSchema() { Schema convertedBeamSchema = IcebergUtils.icebergSchemaToBeamSchema(ICEBERG_SCHEMA_LIST); - System.out.println(convertedBeamSchema); - System.out.println(BEAM_SCHEMA_LIST); - assertEquals(BEAM_SCHEMA_LIST, convertedBeamSchema); } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java index 3196d303239f..47dc9aa425dd 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java @@ -95,11 +95,6 @@ public void testBuildTransformWithRow() { public void testSimpleAppend() { String 
identifier = "default.table_" + Long.toString(UUID.randomUUID().hashCode(), 16); - TableIdentifier tableId = TableIdentifier.parse(identifier); - - // Create a table and add records to it. - Table table = warehouse.createTable(tableId, TestFixtures.SCHEMA); - Map properties = new HashMap<>(); properties.put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP); properties.put("warehouse", warehouse.location); @@ -129,6 +124,9 @@ public void testSimpleAppend() { testPipeline.run().waitUntilFinish(); + TableIdentifier tableId = TableIdentifier.parse(identifier); + Table table = warehouse.loadTable(tableId); + List writtenRecords = ImmutableList.copyOf(IcebergGenerics.read(table).build()); assertThat(writtenRecords, Matchers.containsInAnyOrder(TestFixtures.FILE1SNAPSHOT1.toArray())); @@ -137,7 +135,6 @@ public void testSimpleAppend() { @Test public void testWriteUsingManagedTransform() { String identifier = "default.table_" + Long.toString(UUID.randomUUID().hashCode(), 16); - Table table = warehouse.createTable(TableIdentifier.parse(identifier), TestFixtures.SCHEMA); String yamlConfig = String.format( @@ -161,6 +158,7 @@ public void testWriteUsingManagedTransform() { testPipeline.run().waitUntilFinish(); + Table table = warehouse.loadTable(TableIdentifier.parse(identifier)); List writtenRecords = ImmutableList.copyOf(IcebergGenerics.read(table).build()); assertThat(writtenRecords, Matchers.containsInAnyOrder(TestFixtures.FILE1SNAPSHOT1.toArray())); } @@ -261,9 +259,6 @@ private void writeToDynamicDestinationsAndFilter(@Nullable String operation, boo org.apache.iceberg.Schema icebergSchema = IcebergUtils.beamSchemaToIcebergSchema(filter.outputSchema()); - Table table0 = warehouse.createTable(TableIdentifier.parse(identifier0), icebergSchema); - Table table1 = warehouse.createTable(TableIdentifier.parse(identifier1), icebergSchema); - Table table2 = warehouse.createTable(TableIdentifier.parse(identifier2), icebergSchema); TestStream stream = TestStream.create(beamSchema) @@ 
-301,6 +296,9 @@ private void writeToDynamicDestinationsAndFilter(@Nullable String operation, boo testPipeline.run().waitUntilFinish(); + Table table0 = warehouse.loadTable(TableIdentifier.parse(identifier0)); + Table table1 = warehouse.loadTable(TableIdentifier.parse(identifier1)); + Table table2 = warehouse.loadTable(TableIdentifier.parse(identifier2)); List table0Records = ImmutableList.copyOf(IcebergGenerics.read(table0).build()); List table1Records = ImmutableList.copyOf(IcebergGenerics.read(table1).build()); List table2Records = ImmutableList.copyOf(IcebergGenerics.read(table2).build()); diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java index 1c2e8bc2c451..2bce390e0992 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java @@ -19,25 +19,30 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.either; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.util.List; import java.util.Map; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.Row; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import 
org.apache.commons.lang3.RandomStringUtils; import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.DataFile; import org.apache.iceberg.FileFormat; -import org.apache.iceberg.ManifestFile; import org.apache.iceberg.PartitionKey; import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.hadoop.HadoopCatalog; import org.checkerframework.checker.nullness.qual.Nullable; @@ -45,6 +50,7 @@ import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.rules.TemporaryFolder; import org.junit.rules.TestName; import org.junit.runner.RunWith; @@ -74,6 +80,7 @@ public void setUp() { windowedDestination = getWindowedDestination("table_" + testName.getMethodName(), PARTITION_SPEC); catalog = new HadoopCatalog(new Configuration(), warehouse.location); + RecordWriterManager.TABLE_CACHE.invalidateAll(); } private WindowedValue getWindowedDestination( @@ -147,9 +154,10 @@ public void testCreateNewWriterForEachDestination() throws IOException { writerManager.close(); assertEquals(0, writerManager.openWriters); - // We should only have 3 manifest files (one for each destination we wrote to) - assertEquals(3, writerManager.getManifestFiles().keySet().size()); - assertThat(writerManager.getManifestFiles().keySet(), containsInAnyOrder(dest1, dest2, dest3)); + // We should only have 3 data files (one for each destination we wrote to) + assertEquals(3, writerManager.getSerializableDataFiles().keySet().size()); + assertThat( + writerManager.getSerializableDataFiles().keySet(), containsInAnyOrder(dest1, dest2, dest3)); } @Test @@ -195,16 +203,21 @@ public void testCreateNewWriterForEachPartition() throws IOException { assertFalse(writeSuccess); assertEquals(3, writerManager.openWriters); - // Closing PartitionRecordWriter will close all writers. 
+ // Closing RecordWriterManager will close all writers. writerManager.close(); assertEquals(0, writerManager.openWriters); - assertEquals(1, writerManager.getManifestFiles().size()); - ManifestFile manifestFile = - Iterables.getOnlyElement(writerManager.getManifestFiles().get(windowedDestination)); - - assertEquals(3, manifestFile.addedFilesCount().intValue()); - assertEquals(4, manifestFile.addedRowsCount().intValue()); + // We should have only one destination + assertEquals(1, writerManager.getSerializableDataFiles().size()); + assertTrue(writerManager.getSerializableDataFiles().containsKey(windowedDestination)); + // We should have 3 data files (one for each partition we wrote to) + assertEquals(3, writerManager.getSerializableDataFiles().get(windowedDestination).size()); + long totalRows = 0; + for (SerializableDataFile dataFile : + writerManager.getSerializableDataFiles().get(windowedDestination)) { + totalRows += dataFile.getRecordCount(); + } + assertEquals(4L, totalRows); } @Test @@ -255,12 +268,208 @@ public void testRespectMaxFileSize() throws IOException { } @Test - public void testRequireClosingBeforeFetchingManifestFiles() { + public void testRequireClosingBeforeFetchingDataFiles() { RecordWriterManager writerManager = new RecordWriterManager(catalog, "test_file_name", 100, 2); Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "aaa", true).build(); writerManager.write(windowedDestination, row); assertEquals(1, writerManager.openWriters); - assertThrows(IllegalStateException.class, writerManager::getManifestFiles); + assertThrows(IllegalStateException.class, writerManager::getSerializableDataFiles); + } + + /** DataFile doesn't implement a .equals() method. Check equality manually. 
*/ + private static void checkDataFileEquality(DataFile d1, DataFile d2) { + assertEquals(d1.path(), d2.path()); + assertEquals(d1.format(), d2.format()); + assertEquals(d1.recordCount(), d2.recordCount()); + assertEquals(d1.partition(), d2.partition()); + assertEquals(d1.specId(), d2.specId()); + assertEquals(d1.keyMetadata(), d2.keyMetadata()); + assertEquals(d1.splitOffsets(), d2.splitOffsets()); + assertEquals(d1.columnSizes(), d2.columnSizes()); + assertEquals(d1.valueCounts(), d2.valueCounts()); + assertEquals(d1.nullValueCounts(), d2.nullValueCounts()); + assertEquals(d1.nanValueCounts(), d2.nanValueCounts()); + assertEquals(d1.equalityFieldIds(), d2.equalityFieldIds()); + assertEquals(d1.fileSequenceNumber(), d2.fileSequenceNumber()); + assertEquals(d1.dataSequenceNumber(), d2.dataSequenceNumber()); + assertEquals(d1.pos(), d2.pos()); + } + + @Test + public void testSerializableDataFileRoundTripEquality() throws IOException { + PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); + + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build(); + // same partition for both records (name_trunc=abc, bool=true) + partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + + RecordWriter writer = + new RecordWriter(catalog, windowedDestination.getValue(), "test_file_name", partitionKey); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row2)); + + writer.close(); + DataFile datafile = writer.getDataFile(); + assertEquals(2L, datafile.recordCount()); + + DataFile roundTripDataFile = + SerializableDataFile.from(datafile, partitionKey) + .createDataFile(ImmutableMap.of(PARTITION_SPEC.specId(), PARTITION_SPEC)); + + checkDataFileEquality(datafile, roundTripDataFile); + } + + /** + * Users may update the table's spec while a write 
pipeline is running. Sometimes, this can happen + * after converting {@link DataFile} to {@link SerializableDataFile}s. When converting back to + * {@link DataFile} to commit in the {@link AppendFilesToTables} step, we need to make sure to use + * the same {@link PartitionSpec} it was originally created with. + * + *

<p>This test checks that we're preserving the right {@link PartitionSpec} when such an update + * happens. + */ + @Test + public void testRecreateSerializableDataAfterUpdatingPartitionSpec() throws IOException { + PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); + + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build(); + // same partition for both records (name_trunc=abc, bool=true) + partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + + // write some rows + RecordWriter writer = + new RecordWriter(catalog, windowedDestination.getValue(), "test_file_name", partitionKey); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row2)); + writer.close(); + + // fetch data file and its serializable version + DataFile datafile = writer.getDataFile(); + SerializableDataFile serializableDataFile = SerializableDataFile.from(datafile, partitionKey); + + assertEquals(2L, datafile.recordCount()); + assertEquals(serializableDataFile.getPartitionSpecId(), datafile.specId()); + + // update spec + Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier()); + table.updateSpec().addField("id").removeField("bool").commit(); + + Map updatedSpecs = table.specs(); + DataFile roundTripDataFile = serializableDataFile.createDataFile(updatedSpecs); + + checkDataFileEquality(datafile, roundTripDataFile); + } + + @Test + public void testWriterKeepsUpWithUpdatingPartitionSpec() throws IOException { + Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier()); + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build(); + + // write some rows + RecordWriterManager writer = + new 
RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, Integer.MAX_VALUE); + writer.write(windowedDestination, row); + writer.write(windowedDestination, row2); + writer.close(); + DataFile dataFile = + writer + .getSerializableDataFiles() + .get(windowedDestination) + .get(0) + .createDataFile(table.specs()); + + // check data file path contains the correct partition components + assertEquals(2L, dataFile.recordCount()); + assertEquals(dataFile.specId(), PARTITION_SPEC.specId()); + assertThat(dataFile.path().toString(), containsString("name_trunc=abc")); + assertThat(dataFile.path().toString(), containsString("bool=true")); + + // table is cached + assertEquals(1, RecordWriterManager.TABLE_CACHE.size()); + + // update spec + table.updateSpec().addField("id").removeField("bool").commit(); + + // write a second data file + // should refresh the table and use the new partition spec + RecordWriterManager writer2 = + new RecordWriterManager(catalog, "test_prefix_2", Long.MAX_VALUE, Integer.MAX_VALUE); + writer2.write(windowedDestination, row); + writer2.write(windowedDestination, row2); + writer2.close(); + + List serializableDataFiles = + writer2.getSerializableDataFiles().get(windowedDestination); + assertEquals(2, serializableDataFiles.size()); + for (SerializableDataFile serializableDataFile : serializableDataFiles) { + assertEquals(table.spec().specId(), serializableDataFile.getPartitionSpecId()); + dataFile = serializableDataFile.createDataFile(table.specs()); + assertEquals(1L, dataFile.recordCount()); + assertThat(dataFile.path().toString(), containsString("name_trunc=abc")); + assertThat( + dataFile.path().toString(), either(containsString("id=1")).or(containsString("id=2"))); + } + } + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testWriterExceptionGetsCaught() throws IOException { + RecordWriterManager writerManager = new RecordWriterManager(catalog, "test_file_name", 100, 2); + Row row = 
Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); + partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + + writerManager.write(windowedDestination, row); + + RecordWriterManager.DestinationState state = + writerManager.destinations.get(windowedDestination); + // replace with a failing record writer + FailingRecordWriter failingWriter = + new FailingRecordWriter( + catalog, windowedDestination.getValue(), "test_failing_writer", partitionKey); + state.writers.put(partitionKey, failingWriter); + writerManager.write(windowedDestination, row); + + // this tests that we indeed enter the catch block + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Encountered 1 failed writer(s)"); + try { + writerManager.close(); + } catch (IllegalStateException e) { + // fetch underlying exceptions and validate + Throwable[] underlyingExceptions = e.getSuppressed(); + assertEquals(1, underlyingExceptions.length); + for (Throwable t : underlyingExceptions) { + assertThat( + t.getMessage(), + containsString("Encountered an error when closing data writer for table")); + assertThat( + t.getMessage(), + containsString(windowedDestination.getValue().getTableIdentifier().toString())); + assertThat(t.getMessage(), containsString(failingWriter.path())); + Throwable realCause = t.getCause(); + assertEquals("I am failing!", realCause.getMessage()); + } + + throw e; + } + } + + static class FailingRecordWriter extends RecordWriter { + FailingRecordWriter( + Catalog catalog, IcebergDestination destination, String filename, PartitionKey partitionKey) + throws IOException { + super(catalog, destination, filename, partitionKey); + } + + @Override + public void close() throws IOException { + throw new IOException("I am failing!"); + } } } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java 
b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java index ad4fc6b382d4..1e1c84d31de9 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java @@ -149,4 +149,8 @@ public Table createTable( someTableHasBeenCreated = true; return catalog.createTable(tableId, schema, partitionSpec); } + + public Table loadTable(TableIdentifier tableId) { + return catalog.loadTable(tableId); + } } diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle index 2015bf173978..8c5fa685fdad 100644 --- a/sdks/java/io/jdbc/build.gradle +++ b/sdks/java/io/jdbc/build.gradle @@ -48,6 +48,8 @@ dependencies { testImplementation library.java.testcontainers_mysql testImplementation library.java.testcontainers_postgresql testImplementation 'mysql:mysql-connector-java:8.0.22' + // TODO(https://github.com/apache/beam/issues/31678) HikariCP 5.x requires Java11+ + testImplementation 'com.zaxxer:HikariCP:4.0.3' testRuntimeOnly library.java.slf4j_jdk14 testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 2f164fa3bb78..ab2e3e07e817 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -39,6 +39,7 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -55,6 +56,7 @@ import org.apache.beam.sdk.io.jdbc.JdbcUtil.PartitioningFn; import org.apache.beam.sdk.io.jdbc.SchemaUtil.FieldWithIndex; import org.apache.beam.sdk.metrics.Distribution; +import 
org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -93,6 +95,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.TypeDescriptors.TypeVariableExtractor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.commons.dbcp2.BasicDataSource; import org.apache.commons.dbcp2.DataSourceConnectionFactory; import org.apache.commons.dbcp2.PoolableConnectionFactory; @@ -333,6 +336,7 @@ public static Read read() { return new AutoValue_JdbcIO_Read.Builder() .setFetchSize(DEFAULT_FETCH_SIZE) .setOutputParallelization(true) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .build(); } @@ -341,6 +345,7 @@ public static ReadRows readRows() { return new AutoValue_JdbcIO_ReadRows.Builder() .setFetchSize(DEFAULT_FETCH_SIZE) .setOutputParallelization(true) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .setStatementPreparator(ignored -> {}) .build(); } @@ -356,6 +361,7 @@ public static ReadAll readAll() { return new AutoValue_JdbcIO_ReadAll.Builder() .setFetchSize(DEFAULT_FETCH_SIZE) .setOutputParallelization(true) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .build(); } @@ -372,6 +378,7 @@ public static ReadWithPartitions read .setPartitionColumnType(partitioningColumnType) .setNumPartitions(DEFAULT_NUM_PARTITIONS) .setFetchSize(DEFAULT_FETCH_SIZE) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .setUseBeamSchema(false) .build(); } @@ -389,6 +396,7 @@ public static ReadWithPartitions read .setPartitionsHelper(partitionsHelper) .setNumPartitions(DEFAULT_NUM_PARTITIONS) .setFetchSize(DEFAULT_FETCH_SIZE) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .setUseBeamSchema(false) .build(); } @@ -400,6 +408,7 @@ public static ReadWithPartitions readWithPartitions() { private static 
final long DEFAULT_BATCH_SIZE = 1000L; private static final long DEFAULT_MAX_BATCH_BUFFERING_DURATION = 200L; private static final int DEFAULT_FETCH_SIZE = 50_000; + private static final boolean DEFAULT_DISABLE_AUTO_COMMIT = true; // Default values used from fluent backoff. private static final Duration DEFAULT_INITIAL_BACKOFF = Duration.standardSeconds(1); private static final Duration DEFAULT_MAX_CUMULATIVE_BACKOFF = Duration.standardDays(1000); @@ -733,6 +742,9 @@ public abstract static class ReadRows extends PTransform expand(PBegin input) { ValueProvider query = checkStateNotNull(getQuery(), "withQuery() is required"); @@ -816,6 +839,7 @@ public PCollection expand(PBegin input) { .withCoder(RowCoder.of(schema)) .withRowMapper(SchemaUtil.BeamRowMapper.of(schema)) .withFetchSize(getFetchSize()) + .withDisableAutoCommit(getDisableAutoCommit()) .withOutputParallelization(getOutputParallelization()) .withStatementPreparator(checkStateNotNull(getStatementPreparator()))); rows.setRowSchema(schema); @@ -872,6 +896,9 @@ public abstract static class Read extends PTransform> @Pure abstract boolean getOutputParallelization(); + @Pure + abstract boolean getDisableAutoCommit(); + @Pure abstract Builder toBuilder(); @@ -892,6 +919,8 @@ abstract Builder setDataSourceProviderFn( abstract Builder setOutputParallelization(boolean outputParallelization); + abstract Builder setDisableAutoCommit(boolean disableAutoCommit); + abstract Read build(); } @@ -958,6 +987,15 @@ public Read withOutputParallelization(boolean outputParallelization) { return toBuilder().setOutputParallelization(outputParallelization).build(); } + /** + * Whether to disable auto commit on read. Defaults to true if not provided. The need for this + * config varies depending on the database platform. Informix requires this to be set to false + * while Postgres requires this to be set to true. 
+ */ + public Read withDisableAutoCommit(boolean disableAutoCommit) { + return toBuilder().setDisableAutoCommit(disableAutoCommit).build(); + } + @Override public PCollection expand(PBegin input) { ValueProvider query = checkArgumentNotNull(getQuery(), "withQuery() is required"); @@ -974,6 +1012,7 @@ public PCollection expand(PBegin input) { .withRowMapper(rowMapper) .withFetchSize(getFetchSize()) .withOutputParallelization(getOutputParallelization()) + .withDisableAutoCommit(getDisableAutoCommit()) .withParameterSetter( (element, preparedStatement) -> { if (getStatementPreparator() != null) { @@ -1029,6 +1068,8 @@ public abstract static class ReadAll abstract boolean getOutputParallelization(); + abstract boolean getDisableAutoCommit(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -1049,6 +1090,8 @@ abstract Builder setParameterSetter( abstract Builder setOutputParallelization(boolean outputParallelization); + abstract Builder setDisableAutoCommit(boolean disableAutoCommit); + abstract ReadAll build(); } @@ -1127,6 +1170,15 @@ public ReadAll withOutputParallelization(boolean outputPara return toBuilder().setOutputParallelization(outputParallelization).build(); } + /** + * Whether to disable auto commit on read. Defaults to true if not provided. The need for this + * config varies depending on the database platform. Informix requires this to be set to false + * while Postgres requires this to be set to true. 
+ */ + public ReadAll withDisableAutoCommit(boolean disableAutoCommit) { + return toBuilder().setDisableAutoCommit(disableAutoCommit).build(); + } + private @Nullable Coder inferCoder( CoderRegistry registry, SchemaRegistry schemaRegistry) { if (getCoder() != null) { @@ -1173,7 +1225,8 @@ public PCollection expand(PCollection input) { checkStateNotNull(getQuery()), checkStateNotNull(getParameterSetter()), checkStateNotNull(getRowMapper()), - getFetchSize()))) + getFetchSize(), + getDisableAutoCommit()))) .setCoder(coder); if (getOutputParallelization()) { @@ -1254,6 +1307,9 @@ public abstract static class ReadWithPartitions @Pure abstract @Nullable JdbcReadWithPartitionsHelper getPartitionsHelper(); + @Pure + abstract boolean getDisableAutoCommit(); + @Pure abstract Builder toBuilder(); @@ -1287,6 +1343,8 @@ abstract Builder setPartitionColumnType( abstract Builder setPartitionsHelper( JdbcReadWithPartitionsHelper partitionsHelper); + abstract Builder setDisableAutoCommit(boolean disableAutoCommit); + abstract ReadWithPartitions build(); } @@ -1337,6 +1395,16 @@ public ReadWithPartitions withFetchSize(int fetchSize) { return toBuilder().setFetchSize(fetchSize).build(); } + /** + * Whether to disable auto commit on read. Defaults to true if not provided. The need for this + * config varies depending on the database platform. Informix requires this to be set to false + * while Postgres requires this to be set to true. + */ + public ReadWithPartitions withDisableAutoCommit( + boolean disableAutoCommit) { + return toBuilder().setDisableAutoCommit(disableAutoCommit).build(); + } + /** Data output type is {@link Row}, and schema is auto-inferred from the database. 
*/ public ReadWithPartitions withRowOutput() { return toBuilder().setUseBeamSchema(true).build(); @@ -1419,7 +1487,8 @@ && getLowerBound() instanceof Comparable) { .withQuery(query) .withDataSourceProviderFn(dataSourceProviderFn) .withRowMapper(checkStateNotNull(partitionsHelper)) - .withFetchSize(getFetchSize())) + .withFetchSize(getFetchSize()) + .withDisableAutoCommit(getDisableAutoCommit())) .apply( MapElements.via( new SimpleFunction< @@ -1487,7 +1556,8 @@ public KV> apply( .withRowMapper(rowMapper) .withFetchSize(getFetchSize()) .withParameterSetter(checkStateNotNull(partitionsHelper)) - .withOutputParallelization(false); + .withOutputParallelization(false) + .withDisableAutoCommit(getDisableAutoCommit()); if (getUseBeamSchema()) { checkStateNotNull(schema); @@ -1537,21 +1607,25 @@ private static class ReadFn extends DoFn parameterSetter; private final RowMapper rowMapper; private final int fetchSize; + private final boolean disableAutoCommit; private @Nullable DataSource dataSource; private @Nullable Connection connection; + private @Nullable String reportedLineage; private ReadFn( SerializableFunction dataSourceProviderFn, ValueProvider query, PreparedStatementSetter parameterSetter, RowMapper rowMapper, - int fetchSize) { + int fetchSize, + boolean disableAutoCommit) { this.dataSourceProviderFn = dataSourceProviderFn; this.query = query; this.parameterSetter = parameterSetter; this.rowMapper = rowMapper; this.fetchSize = fetchSize; + this.disableAutoCommit = disableAutoCommit; } @Setup @@ -1560,10 +1634,26 @@ public void setup() throws Exception { } private Connection getConnection() throws SQLException { - if (this.connection == null) { - this.connection = checkStateNotNull(this.dataSource).getConnection(); + Connection connection = this.connection; + if (connection == null) { + DataSource validSource = checkStateNotNull(this.dataSource); + connection = checkStateNotNull(validSource).getConnection(); + this.connection = connection; + + // report Lineage 
if haven't done so + String table = JdbcUtil.extractTableFromReadQuery(query.get()); + if (!table.equals(reportedLineage)) { + JdbcUtil.FQNComponents fqn = JdbcUtil.FQNComponents.of(validSource); + if (fqn == null) { + fqn = JdbcUtil.FQNComponents.of(connection); + } + if (fqn != null) { + fqn.reportLineage(Lineage.getSources(), table); + reportedLineage = table; + } + } } - return this.connection; + return connection; } @ProcessElement @@ -1577,8 +1667,12 @@ public void processElement(ProcessContext context) throws Exception { Connection connection = getConnection(); // PostgreSQL requires autocommit to be disabled to enable cursor streaming // see https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor - LOG.info("Autocommit has been disabled"); - connection.setAutoCommit(false); + // This option is configurable as Informix will error + // if calling setAutoCommit on a non-logged database + if (disableAutoCommit) { + LOG.info("Autocommit has been disabled"); + connection.setAutoCommit(false); + } try (PreparedStatement statement = connection.prepareStatement( query.get(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)) { @@ -2571,6 +2665,7 @@ abstract Builder setMaxBatchBufferingDuration( private @Nullable DataSource dataSource; private @Nullable Connection connection; private @Nullable PreparedStatement preparedStatement; + private @Nullable String reportedLineage; private static @Nullable FluentBackoff retryBackOff; public WriteFn(WriteFnSpec spec) { @@ -2603,11 +2698,28 @@ public void setup() { private Connection getConnection() throws SQLException { Connection connection = this.connection; if (connection == null) { - connection = checkStateNotNull(dataSource).getConnection(); + DataSource validSource = checkStateNotNull(dataSource); + connection = validSource.getConnection(); connection.setAutoCommit(false); preparedStatement = connection.prepareStatement(checkStateNotNull(spec.getStatement()).get()); this.connection = 
connection; + + // report Lineage if haven't done so + String table = spec.getTable(); + if (Strings.isNullOrEmpty(table) && spec.getStatement() != null) { + table = JdbcUtil.extractTableFromWriteQuery(spec.getStatement().get()); + } + if (!Objects.equals(table, reportedLineage)) { + JdbcUtil.FQNComponents fqn = JdbcUtil.FQNComponents.of(validSource); + if (fqn == null) { + fqn = JdbcUtil.FQNComponents.of(connection); + } + if (fqn != null) { + fqn.reportLineage(Lineage.getSinks(), table); + reportedLineage = table; + } + } } return connection; } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java index 0139207235a0..435bfc138b5b 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java @@ -117,6 +117,10 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { if (outputParallelization != null) { readRows = readRows.withOutputParallelization(outputParallelization); } + Boolean disableAutoCommit = config.getDisableAutoCommit(); + if (disableAutoCommit != null) { + readRows = readRows.withDisableAutoCommit(disableAutoCommit); + } return PCollectionRowTuple.of("output", input.getPipeline().apply(readRows)); } } @@ -174,6 +178,9 @@ public abstract static class JdbcReadSchemaTransformConfiguration implements Ser @Nullable public abstract Boolean getOutputParallelization(); + @Nullable + public abstract Boolean getDisableAutoCommit(); + @Nullable public abstract String getDriverJars(); @@ -238,6 +245,8 @@ public abstract static class Builder { public abstract Builder setOutputParallelization(Boolean value); + public abstract Builder setDisableAutoCommit(Boolean value); + public abstract Builder setDriverJars(String value); public abstract 
JdbcReadSchemaTransformConfiguration build(); diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java index 4b5dc0d7e24a..23221042938b 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java @@ -65,6 +65,7 @@ public Schema configurationSchema() { .addNullableField("readQuery", FieldType.STRING) .addNullableField("writeStatement", FieldType.STRING) .addNullableField("fetchSize", FieldType.INT16) + .addNullableField("disableAutoCommit", FieldType.BOOLEAN) .addNullableField("outputParallelization", FieldType.BOOLEAN) .addNullableField("autosharding", FieldType.BOOLEAN) // Partitioning support. If you specify a partition column we will use that instead of @@ -73,6 +74,7 @@ public Schema configurationSchema() { .addNullableField("partitions", FieldType.INT16) .addNullableField("maxConnections", FieldType.INT16) .addNullableField("driverJars", FieldType.STRING) + .addNullableField("writeBatchSize", FieldType.INT64) .build(); } @@ -140,6 +142,11 @@ public PCollection expand(PBegin input) { readRows = readRows.withFetchSize(fetchSize); } + @Nullable Boolean disableAutoCommit = config.getBoolean("disableAutoCommit"); + if (disableAutoCommit != null) { + readRows = readRows.withDisableAutoCommit(disableAutoCommit); + } + return input.apply(readRows); } else { @@ -163,6 +170,11 @@ public PCollection expand(PBegin input) { readRows = readRows.withOutputParallelization(outputParallelization); } + @Nullable Boolean disableAutoCommit = config.getBoolean("disableAutoCommit"); + if (disableAutoCommit != null) { + readRows = readRows.withDisableAutoCommit(disableAutoCommit); + } + return input.apply(readRows); } } @@ -183,6 +195,10 @@ public PDone expand(PCollection input) { if (autosharding != null && autosharding) { 
writeRows = writeRows.withAutoSharding(); } + @Nullable Long writeBatchSize = config.getInt64("writeBatchSize"); + if (writeBatchSize != null) { + writeRows = writeRows.withBatchSize(writeBatchSize); + } return input.apply(writeRows); } }; diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java index b3f46492f745..c0f7d68899b3 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java @@ -19,12 +19,18 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import com.google.auto.value.AutoValue; import java.io.File; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URI; import java.net.URL; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; @@ -33,6 +39,7 @@ import java.sql.Time; import java.sql.Timestamp; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.Collection; import java.util.EnumMap; @@ -40,12 +47,17 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Properties; import java.util.TimeZone; import java.util.UUID; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.IntStream; +import javax.sql.DataSource; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.logicaltypes.FixedPrecisionNumeric; import 
org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant; @@ -57,6 +69,8 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; @@ -563,4 +577,251 @@ public KV> mapRow(ResultSet resultSet) throws Excep } } }); + + @AutoValue + abstract static class JdbcUrl { + abstract String getScheme(); + + abstract @Nullable String getHostAndPort(); + + abstract String getDatabase(); + + /** + * Parse Jdbc Url String and return an {@link JdbcUrl} object, or return null for unsupported + * formats. + * + *

    Example of supported format: + * + *

      + *
    • "jdbc:postgresql://localhost:5432/postgres" + *
    • "jdbc:mysql://127.0.0.1:3306/db" + *
    • "jdbc:oracle:thin:HR/hr@localhost:5221:orcl" + *
    • "jdbc:derby:memory:testDB;create=true" + *
    • "jdbc:oracle:thin:@//myhost.example.com:1521/my_service" + *
    • "jdbc:mysql:///cloud_sql" (GCP CloudSQL, supported if Connection name setup via + * HikariDataSource) + *
    + */ + static @Nullable JdbcUrl of(String url) { + if (Strings.isNullOrEmpty(url) || !url.startsWith("jdbc:")) { + return null; + } + String cleanUri = url.substring(5); + + // 1. Resolve the scheme + // handle sub-schemes e.g. oracle:thin (RAC) + int start = cleanUri.indexOf("//"); + if (start != -1) { + List subschemes = Splitter.on(':').splitToList(cleanUri.substring(0, start)); + cleanUri = subschemes.get(0) + ":" + cleanUri.substring(start); + } else { + // not a URI format e.g. oracle:thin (non-RAC); derby in memory + if (cleanUri.startsWith("derby:")) { + String scheme = "derby"; + int endUrl = cleanUri.indexOf(";"); + if (endUrl == -1) { + endUrl = cleanUri.length(); + } + List components = + Splitter.on(':').splitToList(cleanUri.substring("derby:".length(), endUrl)); + if (components.size() < 2) { + return null; + } + return new AutoValue_JdbcUtil_JdbcUrl(scheme, components.get(0), components.get(1)); + } else if (cleanUri.startsWith("oracle:thin:")) { + String scheme = "oracle"; + + int startHost = cleanUri.indexOf("@"); + if (startHost == -1) { + return null; + } + List components = Splitter.on(':').splitToList(cleanUri.substring(startHost + 1)); + if (components.size() < 3) { + return null; + } + return new AutoValue_JdbcUtil_JdbcUrl( + scheme, components.get(0) + ":" + components.get(1), components.get(2)); + } else { + return null; + } + } + + URI uri = URI.create(cleanUri); + String scheme = uri.getScheme(); + + // 2. resolve database + @Nullable String path = uri.getPath(); + if (path != null && path.startsWith("/")) { + path = path.substring(1); + } + if (path == null) { + return null; + } + + // 3. resolve host and port + // treat as self-managed SQL instance + @Nullable String hostAndPort = null; + @Nullable String host = uri.getHost(); + if (host != null) { + int port = uri.getPort(); + hostAndPort = port != -1 ? 
host + ":" + port : null; + } + return new AutoValue_JdbcUtil_JdbcUrl(scheme, hostAndPort, path); + } + } + + /** Jdbc fully qualified name components. */ + @AutoValue + abstract static class FQNComponents { + abstract String getScheme(); + + abstract Iterable getSegments(); + + void reportLineage(Lineage lineage, @Nullable String table) { + ImmutableList.Builder builder = ImmutableList.builder().addAll(getSegments()); + if (table != null && !table.isEmpty()) { + builder.add(table); + } + lineage.add(getScheme(), builder.build()); + } + + /** Fail-safely extract FQN from supported DataSource. Return null if failed. */ + static @Nullable FQNComponents of(DataSource dataSource) { + // Supported case CloudSql using HikariDataSource + // Had to retrieve properties via Reflection to avoid introduce mandatory Hikari dependencies + String maybeSqlInstance; + String url; + try { + Class hikariClass = Class.forName("com.zaxxer.hikari.HikariDataSource"); + if (!hikariClass.isInstance(dataSource)) { + return null; + } + Method getProperties = hikariClass.getMethod("getDataSourceProperties"); + Properties properties = (Properties) getProperties.invoke(dataSource); + if (properties == null) { + return null; + } + maybeSqlInstance = properties.getProperty("cloudSqlInstance"); + if (maybeSqlInstance == null) { + // not a cloudSqlInstance + return null; + } + Method getUrl = hikariClass.getMethod("getJdbcUrl"); + url = (String) getUrl.invoke(dataSource); + if (url == null) { + return null; + } + } catch (ClassNotFoundException + | InvocationTargetException + | IllegalAccessException + | NoSuchMethodException e) { + return null; + } + + JdbcUrl jdbcUrl = JdbcUrl.of(url); + if (jdbcUrl == null) { + LOG.info("Failed to parse JdbcUrl {}. 
Lineage will not be reported.", url); + return null; + } + + String scheme = "cloudsql_" + jdbcUrl.getScheme(); + ImmutableList.Builder segments = ImmutableList.builder(); + List sqlInstance = Arrays.asList(maybeSqlInstance.split(":")); + if (sqlInstance.size() > 3) { + // project name contains ":" + segments + .add(String.join(":", sqlInstance.subList(0, sqlInstance.size() - 2))) + .add(sqlInstance.get(sqlInstance.size() - 2)) + .add(sqlInstance.get(sqlInstance.size() - 1)); + } else { + segments.addAll(Arrays.asList(maybeSqlInstance.split(":"))); + } + segments.add(jdbcUrl.getDatabase()); + return new AutoValue_JdbcUtil_FQNComponents(scheme, segments.build()); + } + + /** Fail-safely extract FQN from an active connection. Return null if failed. */ + static @Nullable FQNComponents of(Connection connection) { + try { + DatabaseMetaData metadata = connection.getMetaData(); + if (metadata == null) { + // usually not-null, but can be null when running a mock + return null; + } + String url = metadata.getURL(); + if (url == null) { + // usually not-null, but can be null when running a mock + return null; + } + return of(url); + } catch (Exception e) { + // suppressed + return null; + } + } + + /** + * Fail-safely parse FQN from a Jdbc URL. Return null if failed. + * + *

    e.g. + * + *

    jdbc:postgresql://localhost:5432/postgres -> (postgresql, [localhost:5432, postgres]) + * + *

    jdbc:mysql://127.0.0.1:3306/db -> (mysql, [127.0.0.1:3306, db]) + */ + @VisibleForTesting + static @Nullable FQNComponents of(String url) { + JdbcUrl jdbcUrl = JdbcUrl.of(url); + if (jdbcUrl == null || jdbcUrl.getHostAndPort() == null) { + LOG.info("Failed to parse JdbcUrl {}. Lineage will not be reported.", url); + return null; + } + String hostAndPort = jdbcUrl.getHostAndPort(); + if (hostAndPort == null) { + LOG.info("Failed to parse host/port from JdbcUrl {}. Lineage will not be reported.", url); + return null; + } + + return new AutoValue_JdbcUtil_FQNComponents( + jdbcUrl.getScheme(), ImmutableList.of(hostAndPort, jdbcUrl.getDatabase())); + } + } + + private static final Pattern READ_STATEMENT_PATTERN = + Pattern.compile( + "SELECT\\s+.+?\\s+FROM\\s+\\[?(?[^\\s\\[\\]]+)\\]?", Pattern.CASE_INSENSITIVE); + + private static final Pattern WRITE_STATEMENT_PATTERN = + Pattern.compile( + "INSERT\\s+INTO\\s+\\[?(?[^\\s\\[\\]]+)\\]?", Pattern.CASE_INSENSITIVE); + + /** Extract table name a SELECT statement. Return empty string if fail to extract. */ + static String extractTableFromReadQuery(@Nullable String query) { + if (query == null) { + return ""; + } + Matcher matchRead = READ_STATEMENT_PATTERN.matcher(query); + if (matchRead.find()) { + String matched = matchRead.group("tableName"); + if (matched != null) { + return matched; + } + } + return ""; + } + + /** Extract table name from an INSERT statement. Return empty string if fail to extract. 
*/ + static String extractTableFromWriteQuery(@Nullable String query) { + if (query == null) { + return ""; + } + Matcher matchRead = WRITE_STATEMENT_PATTERN.matcher(query); + if (matchRead.find()) { + String matched = matchRead.group("tableName"); + if (matched != null) { + return matched; + } + } + return ""; + } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java index a409b604b11f..1f970ba0624f 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java @@ -141,6 +141,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { if (autosharding != null && autosharding) { writeRows = writeRows.withAutoSharding(); } + + Long writeBatchSize = config.getBatchSize(); + if (writeBatchSize != null) { + writeRows = writeRows.withBatchSize(writeBatchSize); + } + PCollection postWrite = input .get("input") @@ -205,6 +211,9 @@ public abstract static class JdbcWriteSchemaTransformConfiguration implements Se @Nullable public abstract String getDriverJars(); + @Nullable + public abstract Long getBatchSize(); + public void validate() throws IllegalArgumentException { if (Strings.isNullOrEmpty(getJdbcUrl())) { throw new IllegalArgumentException("JDBC URL cannot be blank"); @@ -268,6 +277,8 @@ public abstract Builder setConnectionInitSql( public abstract Builder setDriverJars(String value); + public abstract Builder setBatchSize(Long value); + public abstract JdbcWriteSchemaTransformConfiguration build(); } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index 013fc7996a95..8725ef4b3f78 100644 --- 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; @@ -71,6 +72,7 @@ import org.apache.beam.sdk.io.jdbc.JdbcIO.DataSourceConfiguration; import org.apache.beam.sdk.io.jdbc.JdbcIO.PoolableDataSourceProvider; import org.apache.beam.sdk.io.jdbc.JdbcUtil.PartitioningFn; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.FixedPrecisionNumeric; @@ -243,7 +245,10 @@ public void testRead() { Iterable expectedValues = TestRow.getExpectedValues(0, EXPECTED_ROW_COUNT); PAssert.that(rows).containsInAnyOrder(expectedValues); - pipeline.run(); + PipelineResult result = pipeline.run(); + assertThat( + Lineage.query(result.metrics(), Lineage.Type.SOURCE), + hasItem(Lineage.getFqName("derby", ImmutableList.of("memory", "testDB", READ_TABLE_NAME)))); } @Test @@ -263,7 +268,10 @@ public void testReadWithSingleStringParameter() { Iterable expectedValues = Collections.singletonList(TestRow.fromSeed(1)); PAssert.that(rows).containsInAnyOrder(expectedValues); - pipeline.run(); + PipelineResult result = pipeline.run(); + assertThat( + Lineage.query(result.metrics(), Lineage.Type.SOURCE), + hasItem(Lineage.getFqName("derby", ImmutableList.of("memory", "testDB", READ_TABLE_NAME)))); } @Test @@ -531,6 +539,24 @@ public void testWrite() throws Exception { ArrayList> data = getDataToWrite(EXPECTED_ROW_COUNT); pipeline.apply(Create.of(data)).apply(getJdbcWrite(tableName)); + PipelineResult result = 
pipeline.run(); + assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT); + assertThat( + Lineage.query(result.metrics(), Lineage.Type.SINK), + hasItem(Lineage.getFqName("derby", ImmutableList.of("memory", "testDB", tableName)))); + } finally { + DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName); + } + } + + @Test + public void testWriteWithBatchSize() throws Exception { + String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); + DatabaseTestHelper.createTable(DATA_SOURCE, tableName); + try { + ArrayList> data = getDataToWrite(EXPECTED_ROW_COUNT); + pipeline.apply(Create.of(data)).apply(getJdbcWrite(tableName).withBatchSize(10L)); + pipeline.run(); assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT); diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java index ed380d813625..193a1f0c3477 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java @@ -133,6 +133,7 @@ public void testAbleToReadDataSourceConfiguration() { .withFieldValue("connectionInitSqls", new ArrayList<>(Collections.singleton("initSql"))) .withFieldValue("maxConnections", (short) 3) .withFieldValue("driverJars", "test.jar") + .withFieldValue("writeBatchSize", 10L) .build(); JdbcSchemaIOProvider.JdbcSchemaIO schemaIO = provider.from(READ_TABLE_NAME, config, Schema.builder().build()); @@ -148,6 +149,7 @@ public void testAbleToReadDataSourceConfiguration() { Objects.requireNonNull(dataSourceConf.getConnectionInitSqls()).get()); assertEquals(3, (int) dataSourceConf.getMaxConnections().get()); assertEquals("test.jar", Objects.requireNonNull(dataSourceConf.getDriverJars()).get()); + assertEquals(10L, schemaIO.config.getInt64("writeBatchSize").longValue()); } /** Create test data that is consistent with 
that generated by TestRow. */ diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java index 5b2e9f27f0a8..356d6c7f8de7 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java @@ -22,7 +22,10 @@ import static org.hamcrest.number.IsCloseTo.closeTo; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; import java.io.File; import java.io.IOException; import java.net.URL; @@ -34,12 +37,17 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.List; +import java.util.Map.Entry; import java.util.Random; +import javax.sql.DataSource; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.junit.Rule; import org.junit.Test; @@ -264,4 +272,97 @@ public void testSavesFilesAsExpected() throws IOException { expectedContent2, new String(Files.readAllBytes(Paths.get(urls[1].getFile())), StandardCharsets.UTF_8)); } + + @Test + public void testJdbcUrl() { + ImmutableMap> testCases = + ImmutableMap.>builder() + .put( + "jdbc:postgresql://localhost:5432/postgres", + ImmutableList.of("postgresql", "localhost:5432", "postgres")) + .put( + "jdbc:mysql://127.0.0.1:3306/db", 
ImmutableList.of("mysql", "127.0.0.1:3306", "db")) + .put( + "jdbc:oracle:thin:HR/hr@localhost:5221:orcl", + ImmutableList.of("oracle", "localhost:5221", "orcl")) + .put( + "jdbc:derby:memory:testDB;create=true", + ImmutableList.of("derby", "memory", "testDB")) + .put( + "jdbc:oracle:thin:@//myhost.example.com:1521/my_service", + ImmutableList.of("oracle", "myhost.example.com:1521", "my_service")) + .put("jdbc:mysql:///cloud_sql", ImmutableList.of("mysql", "", "cloud_sql")) + .put("invalid", ImmutableList.of()) + .build(); + for (Entry> entry : testCases.entrySet()) { + JdbcUtil.JdbcUrl jdbcUrl = JdbcUtil.JdbcUrl.of(entry.getKey()); + + System.out.println(entry.getKey()); + if (entry.getValue().equals(ImmutableList.of())) { + assertNull(jdbcUrl); + } else { + assertEquals(entry.getValue().get(0), jdbcUrl.getScheme()); + assertEquals( + entry.getValue().get(1), + jdbcUrl.getHostAndPort() == null ? "" : jdbcUrl.getHostAndPort()); + assertEquals(entry.getValue().get(2), jdbcUrl.getDatabase()); + } + } + } + + @Test + public void testFqnFromHikariDataSourcePostgreSql() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl("jdbc:postgresql:///postgres"); + config.setUsername("postgres"); + config.addDataSourceProperty( + "cloudSqlInstance", "example.com:project:some-region:instance-name"); + // instead of `new HikariDataSource(config)`, initialize an empty source to avoid creation + // of actual connection pool + DataSource dataSource = new HikariDataSource(); + config.validate(); + config.copyStateTo((HikariConfig) dataSource); + JdbcUtil.FQNComponents components = JdbcUtil.FQNComponents.of(dataSource); + assertEquals("cloudsql_postgresql", components.getScheme()); + assertEquals( + ImmutableList.of("example.com:project", "some-region", "instance-name", "postgres"), + components.getSegments()); + } + + @Test + public void testFqnFromHikariDataSourceMySql() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl("jdbc:mysql:///db"); + 
config.setUsername("root"); + config.addDataSourceProperty("cloudSqlInstance", "some-project:US:instance-name"); + // instead of `new HikariDataSource(config)`, initialize an empty source to avoid creation + // of actual connection pool + DataSource dataSource = new HikariDataSource(); + config.validate(); + config.copyStateTo((HikariConfig) dataSource); + JdbcUtil.FQNComponents components = JdbcUtil.FQNComponents.of(dataSource); + assertEquals("cloudsql_mysql", components.getScheme()); + assertEquals( + ImmutableList.of("some-project", "US", "instance-name", "db"), components.getSegments()); + } + + @Test + public void testExtractTableFromQuery() { + ImmutableList> readCases = + ImmutableList.of( + KV.of("select * from table_1", "table_1"), + KV.of("SELECT a, b FROM [table-2]", "table-2"), + KV.of("drop table not-select", "")); + for (KV testCase : readCases) { + assertEquals(testCase.getValue(), JdbcUtil.extractTableFromReadQuery(testCase.getKey())); + } + ImmutableList> writeCases = + ImmutableList.of( + KV.of("insert into table_1 values ...", "table_1"), + KV.of("INSERT INTO [table-2] values ...", "table-2"), + KV.of("drop table not-select", "")); + for (KV testCase : writeCases) { + assertEquals(testCase.getValue(), JdbcUtil.extractTableFromWriteQuery(testCase.getKey())); + } + } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java index f66a143323e5..d6be4d9f89c8 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProviderTest.java @@ -175,6 +175,7 @@ public void testWriteToTable() throws SQLException { .setDriverClassName(DATA_SOURCE_CONFIGURATION.getDriverClassName().get()) .setJdbcUrl(DATA_SOURCE_CONFIGURATION.getUrl().get()) 
.setLocation(writeTableName) + .setBatchSize(1L) .build())); pipeline.run(); DatabaseTestHelper.assertRowCount(DATA_SOURCE, writeTableName, 2); diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java index eddcb0de5561..266d04342d1f 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java @@ -215,7 +215,7 @@ public void testPublishingThenReadingAll() throws IOException, JMSException { int unackRecords = countRemain(QUEUE); assertTrue( String.format("Too many unacknowledged messages: %d", unackRecords), - unackRecords < OPTIONS.getNumberOfRecords() * 0.002); + unackRecords < OPTIONS.getNumberOfRecords() * 0.003); // acknowledged records int ackRecords = OPTIONS.getNumberOfRecords() - unackRecords; diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle index bbb212e76c92..04563c478d6d 100644 --- a/sdks/java/io/kafka/build.gradle +++ b/sdks/java/io/kafka/build.gradle @@ -36,9 +36,6 @@ ext { } def kafkaVersions = [ - '01103': "0.11.0.3", - '100': "1.0.0", - '111': "1.1.1", '201': "2.0.1", '211': "2.1.1", '222': "2.2.2", @@ -55,6 +52,7 @@ dependencies { provided library.java.jackson_dataformat_csv permitUnusedDeclared library.java.jackson_dataformat_csv implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation project(path: ":model:pipeline", configuration: "shadow") implementation project(":sdks:java:extensions:avro") implementation project(":sdks:java:extensions:protobuf") implementation project(":sdks:java:expansion-service") @@ -73,6 +71,7 @@ dependencies { implementation library.java.jackson_annotations implementation library.java.jackson_databind implementation "org.springframework:spring-expression:5.3.27" + implementation group: 'com.google.cloud.hosted.kafka', name: 'managed-kafka-auth-login-handler', 
version: '1.0.2' implementation ("io.confluent:kafka-avro-serializer:${confluentVersion}") { // zookeeper depends on "spotbugs-annotations:3.1.9" which clashes with current // "spotbugs-annotations:3.1.12" used in Beam. Not required. @@ -123,6 +122,8 @@ kafkaVersions.each {kv -> outputs.upToDateWhen { false } testClassesDirs = sourceSets.test.output.classesDirs classpath = configurations."kafkaVersion${kv.key}" + sourceSets.test.runtimeClasspath + systemProperty "beam.target.kafka.version", kv.value + include '**/KafkaIOTest.class' } } @@ -136,15 +137,13 @@ task kafkaVersionsCompatibilityTest { description = 'Runs KafkaIO with different Kafka client APIs' def testNames = createTestList(kafkaVersions, "Test") dependsOn testNames - dependsOn (":sdks:java:io:kafka:kafka-01103:kafkaVersion01103BatchIT") - dependsOn (":sdks:java:io:kafka:kafka-100:kafkaVersion100BatchIT") - dependsOn (":sdks:java:io:kafka:kafka-111:kafkaVersion111BatchIT") dependsOn (":sdks:java:io:kafka:kafka-201:kafkaVersion201BatchIT") dependsOn (":sdks:java:io:kafka:kafka-211:kafkaVersion211BatchIT") dependsOn (":sdks:java:io:kafka:kafka-222:kafkaVersion222BatchIT") dependsOn (":sdks:java:io:kafka:kafka-231:kafkaVersion231BatchIT") dependsOn (":sdks:java:io:kafka:kafka-241:kafkaVersion241BatchIT") dependsOn (":sdks:java:io:kafka:kafka-251:kafkaVersion251BatchIT") + dependsOn (":sdks:java:io:kafka:kafka-312:kafkaVersion312BatchIT") } static def createTestList(Map prefixMap, String suffix) { diff --git a/sdks/java/io/kafka/kafka-100/build.gradle b/sdks/java/io/kafka/kafka-100/build.gradle deleted file mode 100644 index bd5fa67b1cfc..000000000000 --- a/sdks/java/io/kafka/kafka-100/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -project.ext { - delimited="1.0.0" - undelimited="100" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" diff --git a/sdks/java/io/kafka/kafka-integration-test.gradle b/sdks/java/io/kafka/kafka-integration-test.gradle index 1aeb0c97f93b..3bbab72ff77c 100644 --- a/sdks/java/io/kafka/kafka-integration-test.gradle +++ b/sdks/java/io/kafka/kafka-integration-test.gradle @@ -29,10 +29,8 @@ provideIntegrationTestingDependencies() enableJavaPerformanceTesting() dependencies { - implementation "org.apache.kafka:kafka-clients:$delimited" - permitUnusedDeclared "org.apache.kafka:kafka-clients:$delimited" - implementation project(":sdks:java:io:kafka") - permitUnusedDeclared project(":sdks:java:io:kafka") + // Do not set kafka-client dependency here otherwise the version will be overwritten by BeamModulePlugin + // instead, rely on io/kafka/build.gradle's custom configurations with forced kafka-client resolutionStrategy testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1' } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 0f28edf19dd8..cb7b3020c66a 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ 
b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -109,6 +109,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Comparators; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.kafka.clients.CommonClientConfigs; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.KafkaConsumer; @@ -118,6 +119,7 @@ import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.SaslConfigs; import org.apache.kafka.common.serialization.ByteArrayDeserializer; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.Serializer; @@ -1453,6 +1455,24 @@ public Read withConsumerPollingTimeout(long duration) { return toBuilder().setConsumerPollingTimeout(duration).build(); } + /** + * Creates and sets the Application Default Credentials for a Kafka consumer. This allows the + * consumer to be authenticated with a Google Kafka Server using OAuth. + */ + public Read withGCPApplicationDefaultCredentials() { + + return withConsumerConfigUpdates( + ImmutableMap.of( + CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + "SASL_SSL", + SaslConfigs.SASL_MECHANISM, + "OAUTHBEARER", + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, + "com.google.cloud.hosted.kafka.auth.GcpLoginCallbackHandler", + SaslConfigs.SASL_JAAS_CONFIG, + "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;")); + } + /** Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metatdata. 
*/ public PTransform>> withoutMetadata() { return new TypedWithoutMetadata<>(this); @@ -3362,6 +3382,23 @@ public Write withBadRecordErrorHandler(ErrorHandler badRecor getWriteRecordsTransform().withBadRecordErrorHandler(badRecordErrorHandler)); } + /** + * Creates and sets the Application Default Credentials for a Kafka producer. This allows the + * consumer to be authenticated with a Google Kafka Server using OAuth. + */ + public Write withGCPApplicationDefaultCredentials() { + return withProducerConfigUpdates( + ImmutableMap.of( + CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + "SASL_SSL", + SaslConfigs.SASL_MECHANISM, + "OAUTHBEARER", + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, + "com.google.cloud.hosted.kafka.auth.GcpLoginCallbackHandler", + SaslConfigs.SASL_JAAS_CONFIG, + "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;")); + } + @Override public PDone expand(PCollection> input) { final String topic = Preconditions.checkStateNotNull(getTopic(), "withTopic() is required"); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOInitializer.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOInitializer.java new file mode 100644 index 000000000000..3dfb31715ced --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOInitializer.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.harness.JvmInitializer; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; + +/** Initialize KafkaIO feature flags on worker. */ +@AutoService(JvmInitializer.class) +public class KafkaIOInitializer implements JvmInitializer { + @Override + public void beforeProcessing(PipelineOptions options) { + if (ExperimentalOptions.hasExperiment(options, "enable_kafka_metrics")) { + KafkaSinkMetrics.setSupportKafkaMetrics(true); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java new file mode 100644 index 000000000000..147a30dcdd1a --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import com.google.auto.value.AutoValue; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Stores and exports metrics for a batch of Kafka Client RPCs. */ +public interface KafkaMetrics { + + void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime); + + void updateKafkaMetrics(); + + /** No-op implementation of {@code KafkaMetrics}. */ + class NoOpKafkaMetrics implements KafkaMetrics { + private NoOpKafkaMetrics() {} + + @Override + public void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime) {} + + @Override + public void updateKafkaMetrics() {} + + private static NoOpKafkaMetrics singleton = new NoOpKafkaMetrics(); + + static NoOpKafkaMetrics getInstance() { + return singleton; + } + } + + /** + * Metrics of a batch of RPCs. Member variables are thread safe; however, this class does not have + * atomicity across member variables. + * + *

    Expected usage: A number of threads record metrics in an instance of this class with the + * member methods. Afterwards, a single thread should call {@code updateKafkaMetrics}, + * which will export the recorded RPC latency distribution metrics to the underlying + * {@code perWorkerMetrics} container. Afterwards, metrics should not be written/read from this + * object. + */ + @AutoValue + abstract class KafkaMetricsImpl implements KafkaMetrics { + + private static final Logger LOG = LoggerFactory.getLogger(KafkaMetricsImpl.class); + + static HashMap latencyHistograms = new HashMap(); + + abstract HashMap> perTopicRpcLatencies(); + + abstract AtomicBoolean isWritable(); + + public static KafkaMetricsImpl create() { + return new AutoValue_KafkaMetrics_KafkaMetricsImpl( + new HashMap>(), new AtomicBoolean(true)); + } + + /** Record the rpc status and latency of a successful Kafka poll RPC call. */ + @Override + public void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime) { + if (isWritable().get()) { + ConcurrentLinkedQueue latencies = perTopicRpcLatencies().get(topic); + if (latencies == null) { + latencies = new ConcurrentLinkedQueue(); + latencies.add(elapsedTime); + perTopicRpcLatencies().put(topic, latencies); + } else { + latencies.add(elapsedTime); + } + } + } + + /** Record rpc latency histogram metrics for all recorded topics. 
*/ + private void recordRpcLatencyMetrics() { + for (Map.Entry> topicLatencies : + perTopicRpcLatencies().entrySet()) { + Histogram topicHistogram; + if (latencyHistograms.containsKey(topicLatencies.getKey())) { + topicHistogram = latencyHistograms.get(topicLatencies.getKey()); + } else { + topicHistogram = + KafkaSinkMetrics.createRPCLatencyHistogram( + KafkaSinkMetrics.RpcMethod.POLL, topicLatencies.getKey()); + latencyHistograms.put(topicLatencies.getKey(), topicHistogram); + } + // update all the latencies + for (Duration d : topicLatencies.getValue()) { + Preconditions.checkArgumentNotNull(topicHistogram); + topicHistogram.update(d.toMillis()); + } + } + } + + /** + * Export all metrics recorded in this instance to the underlying {@code perWorkerMetrics} + * containers. This function will only report metrics once per instance. Subsequent calls to + * this function will no-op. + */ + @Override + public void updateKafkaMetrics() { + if (!isWritable().compareAndSet(true, false)) { + LOG.warn("Updating stale Kafka metrics container"); + return; + } + recordRpcLatencyMetrics(); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java index e87669ab2b0a..a3fd1d8c3fd7 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.kafka; import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; import java.io.FileOutputStream; @@ -34,6 +35,7 @@ import java.util.Map; import java.util.stream.Collectors; import org.apache.avro.generic.GenericRecord; +import 
org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils; @@ -103,7 +105,7 @@ public Row apply(byte[] input) { @Override public String identifier() { - return "beam:schematransform:org.apache.beam:kafka_read:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ); } @Override diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java new file mode 100644 index 000000000000..f71926f97d27 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.kafka; + +import org.apache.beam.sdk.metrics.DelegatingHistogram; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.LabeledMetricNameUtils; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.util.HistogramData; + +/** + * Helper class to create per worker metrics for Kafka Sink stages. + * + *

    Metrics will be in the namespace 'KafkaSink' and have their name formatted as: + * + *

    '{baseName}-{metricLabelKey1}:{metricLabelVal1};...{metricLabelKeyN}:{metricLabelValN};' + */ + +// TODO: refactor out common parts with the BQ sink so they can be reused by other sinks, e.g. GCS. +// @SuppressWarnings("unused") +public class KafkaSinkMetrics { + private static boolean supportKafkaMetrics = false; + + public static final String METRICS_NAMESPACE = "KafkaSink"; + + // Base Metric names + private static final String RPC_LATENCY = "RpcLatency"; + + // Kafka Consumer Method names + enum RpcMethod { + POLL, + } + + // Metric labels + private static final String TOPIC_LABEL = "topic_name"; + private static final String RPC_METHOD = "rpc_method"; + + /** + * Creates a Histogram metric to record RPC latency. The metric will have the name: + * + *

    'RpcLatency*rpc_method:{method};topic_name:{topic};' + * + * @param method Kafka method associated with this metric. + * @param topic Kafka topic associated with this metric. + * @return Histogram with exponential buckets with a sqrt(2) growth factor. + */ + public static Histogram createRPCLatencyHistogram(RpcMethod method, String topic) { + LabeledMetricNameUtils.MetricNameBuilder nameBuilder = + LabeledMetricNameUtils.MetricNameBuilder.baseNameBuilder(RPC_LATENCY); + nameBuilder.addLabel(RPC_METHOD, method.toString()); + nameBuilder.addLabel(TOPIC_LABEL, topic); + + MetricName metricName = nameBuilder.build(METRICS_NAMESPACE); + HistogramData.BucketType buckets = HistogramData.ExponentialBuckets.of(1, 17); + + return new DelegatingHistogram(metricName, buckets, false, true); + } + + /** + * Returns a container to store metrics for Kafka metrics in Unbounded Readed. If these metrics + * are disabled, then we return a no-op container. + */ + static KafkaMetrics kafkaMetrics() { + if (supportKafkaMetrics) { + return KafkaMetrics.KafkaMetricsImpl.create(); + } else { + return KafkaMetrics.NoOpKafkaMetrics.getInstance(); + } + } + + public static void setSupportKafkaMetrics(boolean supportKafkaMetrics) { + KafkaSinkMetrics.supportKafkaMetrics = supportKafkaMetrics; + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index fed03047cf16..069607955c6d 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -23,12 +23,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Optional; +import 
java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -53,6 +53,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; @@ -137,14 +138,12 @@ public boolean start() throws IOException { name, spec.getOffsetConsumerConfig(), spec.getConsumerConfig()); offsetConsumer = spec.getConsumerFactoryFn().apply(offsetConsumerConfig); - ConsumerSpEL.evaluateAssign(offsetConsumer, topicPartitions); // Fetch offsets once before running periodically. updateLatestOffsets(); offsetFetcherThread.scheduleAtFixedRate( this::updateLatestOffsets, 0, OFFSET_UPDATE_INTERVAL_SECONDS, TimeUnit.SECONDS); - return advance(); } @@ -158,6 +157,9 @@ public boolean advance() throws IOException { */ while (true) { if (curBatch.hasNext()) { + // Initalize metrics container. + kafkaResults = KafkaSinkMetrics.kafkaMetrics(); + PartitionState pState = curBatch.next(); if (!pState.recordIter.hasNext()) { // -- (c) @@ -225,11 +227,9 @@ public boolean advance() throws IOException { METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + pState.topicPartition.toString()); rawSizes.update(recordSize); - for (Map.Entry backlogSplit : perPartitionBacklogMetrics.entrySet()) { - backlogBytesOfSplit.set(backlogSplit.getValue()); - } + // Pass metrics to container. 
+ kafkaResults.updateKafkaMetrics(); return true; - } else { // -- (b) nextBatch(); @@ -345,7 +345,6 @@ public long getSplitBacklogBytes() { private final Counter bytesReadBySplit; private final Gauge backlogBytesOfSplit; private final Gauge backlogElementsOfSplit; - private HashMap perPartitionBacklogMetrics = new HashMap();; private final Counter checkpointMarkCommitsEnqueued = Metrics.counter(METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC); // Checkpoint marks skipped in favor of newer mark (only the latest needs to be committed). @@ -377,6 +376,7 @@ public long getSplitBacklogBytes() { .setDaemon(true) .setNameFormat("KafkaConsumerPoll-thread") .build()); + private AtomicReference consumerPollException = new AtomicReference<>(); private final SynchronousQueue> availableRecordsQueue = new SynchronousQueue<>(); @@ -399,6 +399,11 @@ public long getSplitBacklogBytes() { /** watermark before any records have been read. */ private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + public KafkaMetrics kafkaResults = KafkaSinkMetrics.kafkaMetrics(); + private Stopwatch stopwatch = Stopwatch.createUnstarted(); + + private Set kafkaTopics; + @Override public String toString() { return name; @@ -496,10 +501,6 @@ Instant updateAndGetWatermark() { lastWatermark = timestampPolicy.getWatermark(mkTimestampPolicyContext()); return lastWatermark; } - - String name() { - return this.topicPartition.toString(); - } } KafkaUnboundedReader( @@ -509,6 +510,10 @@ String name() { List partitions = Preconditions.checkArgumentNotNull(source.getSpec().getTopicPartitions()); + + this.kafkaTopics = partitions.stream().map(TopicPartition::topic).collect(Collectors.toSet()); + + LOG.info("{} is reading from topics {}", this.name, kafkaTopics); List> states = new ArrayList<>(partitions.size()); if (checkpointMark != null) { @@ -537,16 +542,14 @@ String name() { prevWatermark = Optional.of(new Instant(ckptMark.getWatermarkMillis())); } - PartitionState state = - 
new PartitionState( + states.add( + new PartitionState<>( tp, nextOffset, source .getSpec() .getTimestampPolicyFactory() - .createTimestampPolicy(tp, prevWatermark)); - states.add(state); - perPartitionBacklogMetrics.put(state.name(), 0L); + .createTimestampPolicy(tp, prevWatermark))); } partitionStates = ImmutableList.copyOf(states); @@ -568,7 +571,14 @@ private void consumerPollLoop() { while (!closed.get()) { try { if (records.isEmpty()) { + stopwatch.start(); records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); + stopwatch.stop(); + for (String kafkaTopic : kafkaTopics) { + kafkaResults.updateSuccessfulRpcMetrics( + kafkaTopic, + java.time.Duration.ofMillis(stopwatch.elapsed(TimeUnit.MILLISECONDS))); + } } else if (availableRecordsQueue.offer( records, RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS)) { records = ConsumerRecords.empty(); @@ -592,7 +602,6 @@ private void consumerPollLoop() { private void commitCheckpointMark() { KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); - if (checkpointMark != null) { LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); Consumer consumer = Preconditions.checkStateNotNull(this.consumer); @@ -655,6 +664,8 @@ private void nextBatch() throws IOException { partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator()); + reportBacklog(); + // cycle through the partitions in order to interleave records from each. curBatch = Iterators.cycle(new ArrayList<>(partitionStates)); } @@ -685,23 +696,28 @@ private void setupInitialOffset(PartitionState pState) { // Called from setupInitialOffset() at the start and then periodically from offsetFetcher thread. 
private void updateLatestOffsets() { Consumer offsetConsumer = Preconditions.checkStateNotNull(this.offsetConsumer); - for (PartitionState p : partitionStates) { - try { - Instant fetchTime = Instant.now(); - ConsumerSpEL.evaluateSeek2End(offsetConsumer, p.topicPartition); - long offset = offsetConsumer.position(p.topicPartition); - p.setLatestOffset(offset, fetchTime); - } catch (Exception e) { - if (closed.get()) { // Ignore the exception if the reader is closed. - break; - } + List topicPartitions = + Preconditions.checkStateNotNull(source.getSpec().getTopicPartitions()); + Instant fetchTime = Instant.now(); + try { + Map endOffsets = offsetConsumer.endOffsets(topicPartitions); + for (PartitionState p : partitionStates) { + p.setLatestOffset( + Preconditions.checkStateNotNull( + endOffsets.get(p.topicPartition), + "No end offset found for partition %s.", + p.topicPartition), + fetchTime); + } + } catch (Exception e) { + if (!closed.get()) { // Ignore the exception if the reader is closed. LOG.warn( - "{}: exception while fetching latest offset for partition {}. will be retried.", + "{}: exception while fetching latest offset for partitions {}. will be retried.", this, - p.topicPartition, + topicPartitions, e); - // Don't update the latest offset. } + // Don't update the latest offset. 
} LOG.debug("{}: backlog {}", this, getSplitBacklogBytes()); @@ -728,7 +744,6 @@ private long getSplitBacklogMessageCount() { if (pBacklog == UnboundedReader.BACKLOG_UNKNOWN) { return UnboundedReader.BACKLOG_UNKNOWN; } - perPartitionBacklogMetrics.put(p.name(), pBacklog); backlogCount += pBacklog; } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java index 09b338492b47..d6f46b11cb7d 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.kafka; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; + import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import java.io.Serializable; @@ -26,6 +28,7 @@ import java.util.Map; import java.util.Set; import javax.annotation.Nullable; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils; import org.apache.beam.sdk.metrics.Counter; @@ -249,7 +252,7 @@ public byte[] apply(Row input) { @Override public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:kafka_write:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE); } @Override diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 952e29f75104..1cf4aad34e4e 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ 
b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -19,13 +19,19 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import java.math.BigDecimal; +import java.math.MathContext; +import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg; @@ -58,6 +64,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; @@ -203,6 +210,23 @@ private ReadFromKafkaDoFn( private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class); + /** + * A holder class for all construction time unique instances of {@link ReadFromKafkaDoFn}. Caches + * must run clean up tasks when {@link #teardown()} is called. 
+ */ + private static final class SharedStateHolder { + + private static final Map> + OFFSET_ESTIMATOR_CACHE = new ConcurrentHashMap<>(); + private static final Map> + AVG_RECORD_SIZE_CACHE = new ConcurrentHashMap<>(); + } + + private static final AtomicLong FN_ID = new AtomicLong(); + + // A unique identifier for the instance. Generally unique unless the ID generator overflows. + private final long fnId = FN_ID.getAndIncrement(); + private final @Nullable Map offsetConsumerConfig; private final @Nullable CheckStopReadingFn checkStopReadingFn; @@ -221,13 +245,12 @@ private ReadFromKafkaDoFn( // Valid between bundle start and bundle finish. private transient @Nullable Deserializer keyDeserializerInstance = null; private transient @Nullable Deserializer valueDeserializerInstance = null; - private transient @Nullable Map offsetEstimatorCache; + private transient @Nullable LoadingCache + offsetEstimatorCache; - private transient @Nullable LoadingCache avgRecordSize; + private transient @Nullable LoadingCache + avgRecordSizeCache; private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L; - - private HashMap perPartitionBacklogMetrics = new HashMap();; - @VisibleForTesting final long consumerPollingTimeout; @VisibleForTesting final DeserializerProvider keyDeserializerProvider; @VisibleForTesting final DeserializerProvider valueDeserializerProvider; @@ -247,19 +270,21 @@ private static class KafkaLatestOffsetEstimator private final Consumer offsetConsumer; private final TopicPartition topicPartition; private final Supplier memoizedBacklog; - private boolean closed; KafkaLatestOffsetEstimator( Consumer offsetConsumer, TopicPartition topicPartition) { this.offsetConsumer = offsetConsumer; this.topicPartition = topicPartition; - ConsumerSpEL.evaluateAssign(this.offsetConsumer, ImmutableList.of(this.topicPartition)); memoizedBacklog = Suppliers.memoizeWithExpiration( () -> { synchronized (offsetConsumer) { - ConsumerSpEL.evaluateSeek2End(offsetConsumer, topicPartition); - 
return offsetConsumer.position(topicPartition); + return Preconditions.checkStateNotNull( + offsetConsumer + .endOffsets(Collections.singleton(topicPartition)) + .get(topicPartition), + "No end offset found for partition %s.", + topicPartition); } }, 1, @@ -270,7 +295,6 @@ private static class KafkaLatestOffsetEstimator protected void finalize() { try { Closeables.close(offsetConsumer, true); - closed = true; LOG.info("Offset Estimator consumer was closed for {}", topicPartition); } catch (Exception anyException) { LOG.warn("Failed to close offset consumer for {}", topicPartition); @@ -281,10 +305,6 @@ protected void finalize() { public long estimate() { return memoizedBacklog.get(); } - - public boolean isClosed() { - return closed; - } } @GetInitialRestriction @@ -292,7 +312,7 @@ public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSource Map updatedConsumerConfig = overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); TopicPartition partition = kafkaSourceDescriptor.getTopicPartition(); - LOG.info("Creating Kafka consumer for initial restriction for {}", partition); + LOG.info("Creating Kafka consumer for initial restriction for {}", kafkaSourceDescriptor); try (Consumer offsetConsumer = consumerFactoryFn.apply(updatedConsumerConfig)) { ConsumerSpEL.evaluateAssign(offsetConsumer, ImmutableList.of(partition)); long startOffset; @@ -339,28 +359,31 @@ public WatermarkEstimator newWatermarkEstimator( @GetSize public double getSize( @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange) - throws Exception { - final LoadingCache avgRecordSize = - Preconditions.checkStateNotNull(this.avgRecordSize); - double numRecords = + throws ExecutionException { + // If present, estimates the record size to offset gap ratio. Compacted topics may hold less + // records than the estimated offset range due to record deletion within a partition. 
+ final LoadingCache avgRecordSizeCache = + Preconditions.checkStateNotNull(this.avgRecordSizeCache); + final @Nullable AverageRecordSize avgRecordSize = + avgRecordSizeCache.getIfPresent(kafkaSourceDescriptor); + // The tracker estimates the offset range by subtracting the last claimed position from the + // currently observed end offset for the partition belonging to this split. + double estimatedOffsetRange = restrictionTracker(kafkaSourceDescriptor, offsetRange).getProgress().getWorkRemaining(); // Before processing elements, we don't have a good estimated size of records and offset gap. - if (!avgRecordSize.asMap().containsKey(kafkaSourceDescriptor.getTopicPartition())) { - return numRecords; + // Return the estimated offset range without scaling by a size to gap ratio. + if (avgRecordSize == null) { + return estimatedOffsetRange; } - if (offsetEstimatorCache != null) { - for (Map.Entry tp : - offsetEstimatorCache.entrySet()) { - perPartitionBacklogMetrics.put(tp.getKey().toString(), tp.getValue().estimate()); - } - } - - return avgRecordSize.get(kafkaSourceDescriptor.getTopicPartition()).getTotalSize(numRecords); + // When processing elements, a moving average estimates the size of records and offset gap. + // Return the estimated offset range scaled by the estimated size to gap ratio. 
+ return estimatedOffsetRange * avgRecordSize.estimateRecordByteSizeToOffsetCountRatio(); } @NewTracker public OffsetRangeTracker restrictionTracker( - @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) { + @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) + throws ExecutionException { if (restriction.getTo() < Long.MAX_VALUE) { return new OffsetRangeTracker(restriction); } @@ -368,24 +391,10 @@ public OffsetRangeTracker restrictionTracker( // OffsetEstimators are cached for each topic-partition because they hold a stateful connection, // so we want to minimize the amount of connections that we start and track with Kafka. Another // point is that it has a memoized backlog, and this should make that more reusable estimations. - final Map offsetEstimatorCacheInstance = + final LoadingCache offsetEstimatorCache = Preconditions.checkStateNotNull(this.offsetEstimatorCache); - - TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); - KafkaLatestOffsetEstimator offsetEstimator = offsetEstimatorCacheInstance.get(topicPartition); - if (offsetEstimator == null || offsetEstimator.isClosed()) { - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); - - LOG.info("Creating Kafka consumer for offset estimation for {}", topicPartition); - - Consumer offsetConsumer = - consumerFactoryFn.apply( - KafkaIOUtils.getOffsetConsumerConfig( - "tracker-" + topicPartition, offsetConsumerConfig, updatedConsumerConfig)); - offsetEstimator = new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); - offsetEstimatorCacheInstance.put(topicPartition, offsetEstimator); - } + final KafkaLatestOffsetEstimator offsetEstimator = + offsetEstimatorCache.get(kafkaSourceDescriptor); return new GrowableOffsetRangeTracker(restriction.getFrom(), offsetEstimator); } @@ -397,22 +406,22 @@ public ProcessContinuation processElement( WatermarkEstimator 
watermarkEstimator, MultiOutputReceiver receiver) throws Exception { - final LoadingCache avgRecordSize = - Preconditions.checkStateNotNull(this.avgRecordSize); + final LoadingCache avgRecordSizeCache = + Preconditions.checkStateNotNull(this.avgRecordSizeCache); + final LoadingCache offsetEstimatorCache = + Preconditions.checkStateNotNull(this.offsetEstimatorCache); final Deserializer keyDeserializerInstance = Preconditions.checkStateNotNull(this.keyDeserializerInstance); final Deserializer valueDeserializerInstance = Preconditions.checkStateNotNull(this.valueDeserializerInstance); + final TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); + final AverageRecordSize avgRecordSize = avgRecordSizeCache.get(kafkaSourceDescriptor); + // TODO: Metrics should be reported per split instead of partition, add bootstrap server hash? final Distribution rawSizes = - Metrics.distribution( - METRIC_NAMESPACE, - RAW_SIZE_METRIC_PREFIX + kafkaSourceDescriptor.getTopicPartition().toString()); - for (Map.Entry backlogSplit : perPartitionBacklogMetrics.entrySet()) { - Gauge backlog = - Metrics.gauge( - METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + "backlogBytes_" + backlogSplit.getKey()); - backlog.set(backlogSplit.getValue()); - } + Metrics.distribution(METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + topicPartition.toString()); + final Gauge backlogBytes = + Metrics.gauge( + METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + "backlogBytes_" + topicPartition.toString()); // Stop processing current TopicPartition when it's time to stop. 
if (checkStopReadingFn != null @@ -430,13 +439,10 @@ public ProcessContinuation processElement( if (timestampPolicyFactory != null) { timestampPolicy = timestampPolicyFactory.createTimestampPolicy( - kafkaSourceDescriptor.getTopicPartition(), - Optional.ofNullable(watermarkEstimator.currentWatermark())); + topicPartition, Optional.ofNullable(watermarkEstimator.currentWatermark())); } - LOG.info( - "Creating Kafka consumer for process continuation for {}", - kafkaSourceDescriptor.getTopicPartition()); + LOG.info("Creating Kafka consumer for process continuation for {}", kafkaSourceDescriptor); try (Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig)) { ConsumerSpEL.evaluateAssign( consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); @@ -453,7 +459,8 @@ public ProcessContinuation processElement( // and move to process the next element. if (rawRecords.isEmpty()) { if (!topicPartitionExists( - kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) { + kafkaSourceDescriptor.getTopicPartition(), + consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) { return ProcessContinuation.stop(); } if (timestampPolicy != null) { @@ -509,8 +516,8 @@ public ProcessContinuation processElement( int recordSize = (rawRecord.key() == null ? 0 : rawRecord.key().length) + (rawRecord.value() == null ? 
0 : rawRecord.value().length); - avgRecordSize - .getUnchecked(kafkaSourceDescriptor.getTopicPartition()) + avgRecordSizeCache + .getUnchecked(kafkaSourceDescriptor) .update(recordSize, rawRecord.offset() - expectedOffset); rawSizes.update(recordSize); expectedOffset = rawRecord.offset() + 1; @@ -542,25 +549,24 @@ public ProcessContinuation processElement( } } } + + backlogBytes.set( + (long) + (BigDecimal.valueOf( + Preconditions.checkStateNotNull( + offsetEstimatorCache.get(kafkaSourceDescriptor).estimate())) + .subtract(BigDecimal.valueOf(expectedOffset), MathContext.DECIMAL128) + .doubleValue() + * avgRecordSize.estimateRecordByteSizeToOffsetCountRatio())); } } } private boolean topicPartitionExists( - TopicPartition topicPartition, Map> topicListMap) { + TopicPartition topicPartition, List partitionInfos) { // Check if the current TopicPartition still exists. - Set existingTopicPartitions = new HashSet<>(); - for (List topicPartitionList : topicListMap.values()) { - topicPartitionList.forEach( - partitionInfo -> { - existingTopicPartitions.add( - new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); - }); - } - if (!existingTopicPartitions.contains(topicPartition)) { - return false; - } - return true; + return partitionInfos.stream() + .anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition())); } // see https://github.com/apache/beam/issues/25962 @@ -612,19 +618,57 @@ public Coder restrictionCoder() { @Setup public void setup() throws Exception { // Start to track record size and offset gap per bundle. 
- avgRecordSize = - CacheBuilder.newBuilder() - .maximumSize(1000L) - .build( - new CacheLoader() { - @Override - public AverageRecordSize load(TopicPartition topicPartition) throws Exception { - return new AverageRecordSize(); - } - }); + avgRecordSizeCache = + SharedStateHolder.AVG_RECORD_SIZE_CACHE.computeIfAbsent( + fnId, + k -> { + return CacheBuilder.newBuilder() + .maximumSize(1000L) + .build( + new CacheLoader() { + @Override + public AverageRecordSize load(KafkaSourceDescriptor kafkaSourceDescriptor) + throws Exception { + return new AverageRecordSize(); + } + }); + }); keyDeserializerInstance = keyDeserializerProvider.getDeserializer(consumerConfig, true); valueDeserializerInstance = valueDeserializerProvider.getDeserializer(consumerConfig, false); - offsetEstimatorCache = new HashMap<>(); + offsetEstimatorCache = + SharedStateHolder.OFFSET_ESTIMATOR_CACHE.computeIfAbsent( + fnId, + k -> { + final Map consumerConfig = ImmutableMap.copyOf(this.consumerConfig); + final @Nullable Map offsetConsumerConfig = + this.offsetConsumerConfig == null + ? 
null + : ImmutableMap.copyOf(this.offsetConsumerConfig); + return CacheBuilder.newBuilder() + .weakValues() + .expireAfterAccess(1, TimeUnit.MINUTES) + .build( + new CacheLoader() { + @Override + public KafkaLatestOffsetEstimator load( + KafkaSourceDescriptor kafkaSourceDescriptor) throws Exception { + LOG.info( + "Creating Kafka consumer for offset estimation for {}", + kafkaSourceDescriptor); + + TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); + Map updatedConsumerConfig = + overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); + Consumer offsetConsumer = + consumerFactoryFn.apply( + KafkaIOUtils.getOffsetConsumerConfig( + "tracker-" + topicPartition, + offsetConsumerConfig, + updatedConsumerConfig)); + return new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); + } + }); + }); if (checkStopReadingFn != null) { checkStopReadingFn.setup(); } @@ -632,23 +676,29 @@ public AverageRecordSize load(TopicPartition topicPartition) throws Exception { @Teardown public void teardown() throws Exception { - final Deserializer keyDeserializerInstance = - Preconditions.checkStateNotNull(this.keyDeserializerInstance); - final Deserializer valueDeserializerInstance = - Preconditions.checkStateNotNull(this.valueDeserializerInstance); + final LoadingCache avgRecordSizeCache = + Preconditions.checkStateNotNull(this.avgRecordSizeCache); + final LoadingCache offsetEstimatorCache = + Preconditions.checkStateNotNull(this.offsetEstimatorCache); try { - Closeables.close(keyDeserializerInstance, true); - Closeables.close(valueDeserializerInstance, true); + if (valueDeserializerInstance != null) { + Closeables.close(valueDeserializerInstance, true); + valueDeserializerInstance = null; + } + if (keyDeserializerInstance != null) { + Closeables.close(keyDeserializerInstance, true); + keyDeserializerInstance = null; + } } catch (Exception anyException) { LOG.warn("Fail to close resource during finishing bundle.", anyException); } - - if 
(offsetEstimatorCache != null) { - offsetEstimatorCache.clear(); - } if (checkStopReadingFn != null) { checkStopReadingFn.teardown(); } + + // Allow the cache to perform clean up tasks when this instance is about to be deleted. + avgRecordSizeCache.cleanUp(); + offsetEstimatorCache.cleanUp(); } private Map overrideBootstrapServersConfig( @@ -665,8 +715,15 @@ private Map overrideBootstrapServersConfig( return config; } + // TODO: Collapse the two moving average trackers into a single accumulator using a single Guava + // AtomicDouble. Note that this requires that a single thread will call update and that while get + // may be called by multiple threads the method must only load the accumulator itself. + @ThreadSafe private static class AverageRecordSize { + @GuardedBy("this") private MovingAvg avgRecordSize; + + @GuardedBy("this") private MovingAvg avgRecordGap; public AverageRecordSize() { @@ -674,13 +731,26 @@ public AverageRecordSize() { this.avgRecordGap = new MovingAvg(); } - public void update(int recordSize, long gap) { + public synchronized void update(int recordSize, long gap) { avgRecordSize.update(recordSize); avgRecordGap.update(gap); } - public double getTotalSize(double numRecords) { - return avgRecordSize.get() * numRecords / (1 + avgRecordGap.get()); + public double estimateRecordByteSizeToOffsetCountRatio() { + double avgRecordSize; + double avgRecordGap; + + synchronized (this) { + avgRecordSize = this.avgRecordSize.get(); + avgRecordGap = this.avgRecordGap.get(); + } + + // The offset increases between records in a batch fetched from a compacted topic may be + // greater than 1. Compacted topics only store records with the greatest offset per key per + // partition, the records in between are deleted and will not be observed by a consumer. + // The observed gap between offsets is used to estimate the number of records that are likely + // to be observed for the provided number of records. 
+ return avgRecordSize / (1 + avgRecordGap); } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java index fba81c51130d..cef3bc80d613 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assume.assumeFalse; import java.io.IOException; import java.time.Instant; @@ -86,6 +87,7 @@ import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -99,10 +101,12 @@ import org.apache.kafka.common.serialization.IntegerSerializer; import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.AppInfoParser; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -168,6 +172,13 @@ public class KafkaIOIT { @BeforeClass public static void setup() throws IOException { + // check kafka version first + @Nullable String targetVer = System.getProperty("beam.target.kafka.version"); + if (!Strings.isNullOrEmpty(targetVer)) { + String actualVer = 
AppInfoParser.getVersion(); + assertEquals(targetVer, actualVer); + } + options = IOITHelper.readIOTestPipelineOptions(Options.class); sourceOptions = fromJsonString(options.getSourceOptions(), SyntheticSourceOptions.class); if (options.isWithTestcontainers()) { @@ -359,6 +370,10 @@ public void processElement(@Element String element, OutputReceiver outpu // This test verifies that bad data from Kafka is properly sent to the error handler @Test public void testKafkaIOSDFReadWithErrorHandler() throws IOException { + // TODO(https://github.com/apache/beam/issues/32704) re-enable when fixed, or remove the support + // for these old kafka-client versions + String actualVer = AppInfoParser.getVersion(); + assumeFalse(actualVer.compareTo("2.0.0") >= 0 && actualVer.compareTo("2.3.0") < 0); writePipeline .apply(Create.of(KV.of("key", "val"))) .apply( @@ -815,6 +830,62 @@ public void testWatermarkUpdateWithSparseMessages() throws IOException, Interrup } } + @Ignore( + "Test is ignored until GMK is utilized as part of this test suite (https://github.com/apache/beam/issues/32721).") + @Test + public void testReadAndWriteFromKafkaIOWithGCPApplicationDefaultCredentials() throws IOException { + AdminClient client = + AdminClient.create( + ImmutableMap.of("bootstrap.servers", options.getKafkaBootstrapServerAddresses())); + + String topicName = "TestApplicationDefaultCreds-" + UUID.randomUUID(); + Map records = new HashMap<>(); + for (int i = 0; i < 5; i++) { + records.put(i, String.valueOf(i)); + } + + try { + client.createTopics(ImmutableSet.of(new NewTopic(topicName, 1, (short) 1))); + + writePipeline + .apply("Generate Write Elements", Create.of(records)) + .apply( + "Write to Kafka", + KafkaIO.write() + .withBootstrapServers(options.getKafkaBootstrapServerAddresses()) + .withTopic(topicName) + .withKeySerializer(IntegerSerializer.class) + .withValueSerializer(StringSerializer.class) + .withGCPApplicationDefaultCredentials()); + + 
writePipeline.run().waitUntilFinish(Duration.standardSeconds(15)); + + client.createPartitions(ImmutableMap.of(topicName, NewPartitions.increaseTo(3))); + + sdfReadPipeline.apply( + "Read from Kafka", + KafkaIO.read() + .withBootstrapServers(options.getKafkaBootstrapServerAddresses()) + .withConsumerConfigUpdates(ImmutableMap.of("auto.offset.reset", "earliest")) + .withTopic(topicName) + .withKeyDeserializer(IntegerDeserializer.class) + .withValueDeserializer(StringDeserializer.class) + .withGCPApplicationDefaultCredentials() + .withoutMetadata()); + + PipelineResult readResult = sdfReadPipeline.run(); + + // Only waiting 5 seconds here because we don't expect any processing at this point + PipelineResult.State readState = readResult.waitUntilFinish(Duration.standardSeconds(5)); + + cancelIfTimeouted(readResult, readState); + // Fail the test if pipeline failed. + assertNotEquals(readState, PipelineResult.State.FAILED); + } finally { + client.deleteTopics(ImmutableSet.of(topicName)); + } + } + private static class KeyByPartition extends DoFn, KV>> { diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index fb8b29fe7280..e614320db150 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -77,7 +77,7 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.io.kafka.KafkaIO.Read.FakeFlinkPipelineOptions; -import org.apache.beam.sdk.io.kafka.KafkaMocks.PositionErrorConsumerFactory; +import org.apache.beam.sdk.io.kafka.KafkaMocks.EndOffsetErrorConsumerFactory; import org.apache.beam.sdk.io.kafka.KafkaMocks.SendErrorProducerFactory; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.Lineage; @@ -115,6 +115,7 
@@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -146,12 +147,14 @@ import org.apache.kafka.common.serialization.LongDeserializer; import org.apache.kafka.common.serialization.LongSerializer; import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.AppInfoParser; import org.apache.kafka.common.utils.Utils; import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.collection.IsIterableContainingInAnyOrder; import org.hamcrest.collection.IsIterableWithSize; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.Assume; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -264,10 +267,6 @@ private static MockConsumer mkMockConsumer( public synchronized void assign(final Collection assigned) { super.assign(assigned); assignedPartitions.set(ImmutableList.copyOf(assigned)); - for (TopicPartition tp : assigned) { - updateBeginningOffsets(ImmutableMap.of(tp, 0L)); - updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size())); - } } // Override offsetsForTimes() in order to look up the offsets by timestamp. 
@Override @@ -287,9 +286,12 @@ public synchronized Map offsetsForTimes( } }; - for (String topic : topics) { - consumer.updatePartitions(topic, partitionMap.get(topic)); - } + partitionMap.forEach(consumer::updatePartitions); + consumer.updateBeginningOffsets( + records.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L))); + consumer.updateEndOffsets( + records.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()))); // MockConsumer does not maintain any relationship between partition seek position and the // records added. e.g. if we add 10 records to a partition and then seek to end of the @@ -515,6 +517,15 @@ public void testReadAvroGenericRecordsWithConfluentSchemaRegistry() { p.run(); } + @Test + public void testKafkaVersion() { + // KafkaIO compatibility tests run unit tests in KafkaIOTest + @Nullable String targetVer = System.getProperty("beam.target.kafka.version"); + Assume.assumeTrue(!Strings.isNullOrEmpty(targetVer)); + String actualVer = AppInfoParser.getVersion(); + assertEquals(targetVer, actualVer); + } + @Test public void testReadAvroSpecificRecordsWithConfluentSchemaRegistry() { int numElements = 100; @@ -1513,13 +1524,14 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { List topics = ImmutableList.of("topic_a"); - PositionErrorConsumerFactory positionErrorConsumerFactory = new PositionErrorConsumerFactory(); + EndOffsetErrorConsumerFactory endOffsetErrorConsumerFactory = + new EndOffsetErrorConsumerFactory(); UnboundedSource, KafkaCheckpointMark> source = KafkaIO.read() .withBootstrapServers("myServer1:9092,myServer2:9092") .withTopics(topics) - .withConsumerFactoryFn(positionErrorConsumerFactory) + .withConsumerFactoryFn(endOffsetErrorConsumerFactory) .withKeyDeserializer(IntegerDeserializer.class) .withValueDeserializer(LongDeserializer.class) .makeSource(); @@ -1528,7 +1540,7 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { 
reader.start(); - unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partition"); + unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partitions"); reader.close(); } @@ -1582,6 +1594,11 @@ public byte[] serialize(String topic, Long data) { public void configure(Map configs, boolean isKey) { // intentionally left blank for compatibility with older kafka versions } + + @Override + public void close() { + // intentionally left blank for compatibility with kafka-client v2.2 or older + } } @Test diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java new file mode 100644 index 000000000000..b84e143be773 --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.kafka; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.runners.core.metrics.MetricsContainerImpl; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.util.HistogramData; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link KafkaSinkMetrics}. */ +// TODO:Naireen - Refactor to remove duplicate code between the two sinks +@RunWith(JUnit4.class) +public class KafkaMetricsTest { + public static class TestHistogram implements Histogram { + public List values = Lists.newArrayList(); + private MetricName metricName = MetricName.named("KafkaSink", "name"); + + @Override + public void update(double value) { + values.add(value); + } + + @Override + public MetricName getName() { + return metricName; + } + } + + public static class TestMetricsContainer extends MetricsContainerImpl { + public ConcurrentHashMap, TestHistogram> + perWorkerHistograms = + new ConcurrentHashMap, TestHistogram>(); + + public TestMetricsContainer() { + super("TestStep"); + } + + @Override + public Histogram getPerWorkerHistogram( + MetricName metricName, HistogramData.BucketType bucketType) { + perWorkerHistograms.computeIfAbsent(KV.of(metricName, bucketType), kv -> new TestHistogram()); + return perWorkerHistograms.get(KV.of(metricName, bucketType)); + } + + @Override + public void reset() { + perWorkerHistograms.clear(); + } + } + + @Test + public void testNoOpKafkaMetrics() throws Exception { + TestMetricsContainer 
testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaMetrics results = KafkaMetrics.NoOpKafkaMetrics.getInstance(); + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + + assertThat(testContainer.perWorkerHistograms.size(), equalTo(0)); + } + + @Test + public void testKafkaRPCLatencyMetrics() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaSinkMetrics.setSupportKafkaMetrics(true); + + KafkaMetrics results = KafkaSinkMetrics.kafkaMetrics(); + + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + // RpcLatency*rpc_method:POLL;topic_name:test-topic + MetricName histogramName = + MetricName.named("KafkaSink", "RpcLatency*rpc_method:POLL;topic_name:test-topic;"); + HistogramData.BucketType bucketType = HistogramData.ExponentialBuckets.of(1, 17); + + assertThat(testContainer.perWorkerHistograms.size(), equalTo(1)); + assertThat( + testContainer.perWorkerHistograms.get(KV.of(histogramName, bucketType)).values, + containsInAnyOrder(Double.valueOf(10.0))); + } + + @Test + public void testKafkaRPCLatencyMetricsAreNotRecorded() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaSinkMetrics.setSupportKafkaMetrics(false); + + KafkaMetrics results = KafkaSinkMetrics.kafkaMetrics(); + + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + assertThat(testContainer.perWorkerHistograms.size(), equalTo(0)); + } +} diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java index 0844d71e7105..1303f1da3bcd 100644 --- 
a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.kafka; import java.io.Serializable; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -27,8 +28,8 @@ import org.apache.beam.sdk.values.KV; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; -import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; import org.apache.kafka.clients.producer.Callback; import org.apache.kafka.clients.producer.MockProducer; import org.apache.kafka.clients.producer.Producer; @@ -66,51 +67,33 @@ public Producer apply(Map input) { } } - public static final class PositionErrorConsumer extends MockConsumer { - - public PositionErrorConsumer() { - super(null); - } - - @Override - public synchronized long position(TopicPartition partition) { - throw new KafkaException("fakeException"); - } - - @Override - public synchronized List partitionsFor(String topic) { - return Collections.singletonList( - new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null)); - } - } - - public static final class PositionErrorConsumerFactory + public static final class EndOffsetErrorConsumerFactory implements SerializableFunction, Consumer> { - public PositionErrorConsumerFactory() {} + public EndOffsetErrorConsumerFactory() {} @Override public MockConsumer apply(Map input) { + final MockConsumer consumer; if (input.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - return new PositionErrorConsumer(); - } else { - MockConsumer consumer = - new MockConsumer(null) { + consumer = + new MockConsumer(OffsetResetStrategy.EARLIEST) { @Override - public synchronized long position(TopicPartition partition) { - return 
1L; - } - - @Override - public synchronized ConsumerRecords poll(long timeout) { - return ConsumerRecords.empty(); + public synchronized Map endOffsets( + Collection partitions) { + throw new KafkaException("fakeException"); } }; - consumer.updatePartitions( - "topic_a", - Collections.singletonList( - new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null))); - return consumer; + } else { + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST); } + consumer.updatePartitions( + "topic_a", + Collections.singletonList( + new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null))); + consumer.updateBeginningOffsets( + Collections.singletonMap(new TopicPartition("topic_a", 1), 0L)); + consumer.updateEndOffsets(Collections.singletonMap(new TopicPartition("topic_a", 1), 0L)); + return consumer; } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java new file mode 100644 index 000000000000..625a75c5316b --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link KafkaSinkMetrics}. */ +// TODO:Naireen - Refactor to remove duplicate code between the Kafka and BigQuery sinks +@RunWith(JUnit4.class) +public class KafkaSinkMetricsTest { + @Test + public void testCreatingHistogram() throws Exception { + + Histogram histogram = + KafkaSinkMetrics.createRPCLatencyHistogram(KafkaSinkMetrics.RpcMethod.POLL, "topic1"); + + MetricName histogramName = + MetricName.named("KafkaSink", "RpcLatency*rpc_method:POLL;topic_name:topic1;"); + assertThat(histogram.getName(), equalTo(histogramName)); + } +} diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index a9e4a4eddb61..cbff0f896619 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -150,6 +150,11 @@ public static class FailingDeserializer implements Deserializer { public FailingDeserializer() {} + @Override + public void configure(Map configs, boolean isKey) { + // intentionally left blank for compatibility with older kafka versions + } + @Override public String deserialize(String topic, byte[] data) { throw new SerializationException("Intentional serialization exception"); @@ -200,6 +205,8 @@ public SimpleMockKafkaConsumer( OffsetResetStrategy offsetResetStrategy, TopicPartition topicPartition) { super(offsetResetStrategy); this.topicPartition = 
topicPartition; + updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L)); + updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE)); } public void reset() { @@ -209,6 +216,8 @@ public void reset() { this.startOffsetForTime = KV.of(0L, Instant.now()); this.stopOffsetForTime = KV.of(Long.MAX_VALUE, null); this.numOfRecordsPerPoll = 0L; + updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L)); + updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE)); } public void setRemoved() { @@ -243,6 +252,17 @@ public synchronized Map> listTopics() { topicPartition.topic(), topicPartition.partition(), null, null, null))); } + @Override + public synchronized List partitionsFor(String partition) { + if (this.isRemoved) { + return ImmutableList.of(); + } else { + return ImmutableList.of( + new PartitionInfo( + topicPartition.topic(), topicPartition.partition(), null, null, null)); + } + } + @Override public synchronized void assign(Collection partitions) { assertTrue(Iterables.getOnlyElement(partitions).equals(this.topicPartition)); diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java index e1868e2c8461..efc51362d06a 100644 --- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.mqtt; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.google.auto.value.AutoValue; import java.io.IOException; @@ -36,6 +37,7 @@ import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; import 
org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -80,6 +82,48 @@ * * } * + *

    Reading with Metadata from a MQTT broker

    + * + *

    The {@code readWithMetadata} method extends the functionality of the basic {@code read} method + * by returning a {@link PCollection} of metadata that includes both the topic name and the payload. + * The metadata is encapsulated in a container class {@link MqttRecord} that includes the topic name + * and payload. This allows you to implement business logic that can differ depending on the topic + * from which the message was received. + * + *

    {@code
    + * PCollection records = pipeline.apply(
    + *   MqttIO.readWithMetadata()
    + *    .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
    + *      "tcp://host:11883",
    + *      "my_topic_pattern"))
    + *
    + * }
    + * + *

    By using the topic information, you can apply different processing logic depending on the + * source topic, enhancing the flexibility of message processing. + * + *

    Example

    + * + *
    {@code
    + * pipeline
    + *   .apply(MqttIO.readWithMetadata()
    + *     .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create(
    + *       "tcp://host:1883", "my_topic_pattern")))
    + *   .apply(ParDo.of(new DoFn() {
    + *     @ProcessElement
    + *     public void processElement(ProcessContext c) {
    + *       MqttRecord record = c.element();
    + *       String topic = record.getTopic();
    + *       byte[] payload = record.getPayload();
    + *       // Apply business logic based on the topic
    + *       if (topic.equals("important_topic")) {
    + *         // Special processing for important_topic
    + *       }
    + *     }
    + *   }));
    + *
    + * }
    + * *

    Writing to a MQTT broker

    * *

    MqttIO sink supports writing {@code byte[]} to a topic on a MQTT broker. @@ -130,9 +174,18 @@ public class MqttIO { private static final Logger LOG = LoggerFactory.getLogger(MqttIO.class); private static final int MQTT_3_1_MAX_CLIENT_ID_LENGTH = 23; - public static Read read() { - return new AutoValue_MqttIO_Read.Builder() + public static Read read() { + return new AutoValue_MqttIO_Read.Builder() .setMaxReadTime(null) + .setWithMetadata(false) + .setMaxNumRecords(Long.MAX_VALUE) + .build(); + } + + public static Read readWithMetadata() { + return new AutoValue_MqttIO_Read.Builder() + .setMaxReadTime(null) + .setWithMetadata(true) .setMaxNumRecords(Long.MAX_VALUE) .build(); } @@ -267,7 +320,7 @@ private MQTT createClient() throws Exception { /** A {@link PTransform} to read from a MQTT broker. */ @AutoValue - public abstract static class Read extends PTransform> { + public abstract static class Read extends PTransform> { abstract @Nullable ConnectionConfiguration connectionConfiguration(); @@ -275,21 +328,29 @@ public abstract static class Read extends PTransform abstract @Nullable Duration maxReadTime(); - abstract Builder builder(); + abstract Builder builder(); + + abstract boolean withMetadata(); + + abstract @Nullable Coder coder(); @AutoValue.Builder - abstract static class Builder { - abstract Builder setConnectionConfiguration(ConnectionConfiguration config); + abstract static class Builder { + abstract Builder setConnectionConfiguration(ConnectionConfiguration config); + + abstract Builder setMaxNumRecords(long maxNumRecords); - abstract Builder setMaxNumRecords(long maxNumRecords); + abstract Builder setMaxReadTime(Duration maxReadTime); - abstract Builder setMaxReadTime(Duration maxReadTime); + abstract Builder setWithMetadata(boolean withMetadata); - abstract Read build(); + abstract Builder setCoder(Coder coder); + + abstract Read build(); } /** Define the MQTT connection configuration used to connect to the MQTT broker. 
*/ - public Read withConnectionConfiguration(ConnectionConfiguration configuration) { + public Read withConnectionConfiguration(ConnectionConfiguration configuration) { checkArgument(configuration != null, "configuration can not be null"); return builder().setConnectionConfiguration(configuration).build(); } @@ -299,7 +360,7 @@ public Read withConnectionConfiguration(ConnectionConfiguration configuration) { * records is lower than {@code Long.MAX_VALUE}, the {@link Read} will provide a bounded {@link * PCollection}. */ - public Read withMaxNumRecords(long maxNumRecords) { + public Read withMaxNumRecords(long maxNumRecords) { return builder().setMaxNumRecords(maxNumRecords).build(); } @@ -307,19 +368,33 @@ public Read withMaxNumRecords(long maxNumRecords) { * Define the max read time (duration) while the {@link Read} will receive messages. When this * max read time is not null, the {@link Read} will provide a bounded {@link PCollection}. */ - public Read withMaxReadTime(Duration maxReadTime) { + public Read withMaxReadTime(Duration maxReadTime) { return builder().setMaxReadTime(maxReadTime).build(); } @Override - public PCollection expand(PBegin input) { + @SuppressWarnings("unchecked") + public PCollection expand(PBegin input) { checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null"); checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null"); - org.apache.beam.sdk.io.Read.Unbounded unbounded = - org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this)); + Coder coder; + if (withMetadata()) { + try { + coder = + (Coder) input.getPipeline().getSchemaRegistry().getSchemaCoder(MqttRecord.class); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e.getMessage()); + } + } else { + coder = (Coder) ByteArrayCoder.of(); + } + + org.apache.beam.sdk.io.Read.Unbounded unbounded = + org.apache.beam.sdk.io.Read.from( + new UnboundedMqttSource<>(this.builder().setCoder(coder).build())); - 
PTransform> transform = unbounded; + PTransform> transform = unbounded; if (maxNumRecords() < Long.MAX_VALUE || maxReadTime() != null) { transform = unbounded.withMaxReadTime(maxReadTime()).withMaxNumRecords(maxNumRecords()); @@ -403,27 +478,39 @@ public int hashCode() { } @VisibleForTesting - static class UnboundedMqttSource extends UnboundedSource { + static class UnboundedMqttSource extends UnboundedSource { - private final Read spec; + private final Read spec; - public UnboundedMqttSource(Read spec) { + public UnboundedMqttSource(Read spec) { this.spec = spec; } @Override - public UnboundedReader createReader( + @SuppressWarnings("unchecked") + public UnboundedReader createReader( PipelineOptions options, MqttCheckpointMark checkpointMark) { - return new UnboundedMqttReader(this, checkpointMark); + final UnboundedMqttReader unboundedMqttReader; + if (spec.withMetadata()) { + unboundedMqttReader = + new UnboundedMqttReader<>( + this, + checkpointMark, + message -> (T) MqttRecord.of(message.getTopic(), message.getPayload())); + } else { + unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark); + } + + return unboundedMqttReader; } @Override - public List split(int desiredNumSplits, PipelineOptions options) { + public List> split(int desiredNumSplits, PipelineOptions options) { // MQTT is based on a pub/sub pattern // so, if we create several subscribers on the same topic, they all will receive the same // message, resulting to duplicate messages in the PCollection. // So, for MQTT, we limit to number of split ot 1 (unique source). 
- return Collections.singletonList(new UnboundedMqttSource(spec)); + return Collections.singletonList(new UnboundedMqttSource<>(spec)); } @Override @@ -437,23 +524,24 @@ public Coder getCheckpointMarkCoder() { } @Override - public Coder getOutputCoder() { - return ByteArrayCoder.of(); + public Coder getOutputCoder() { + return checkNotNull(this.spec.coder(), "coder can not be null"); } } @VisibleForTesting - static class UnboundedMqttReader extends UnboundedSource.UnboundedReader { + static class UnboundedMqttReader extends UnboundedSource.UnboundedReader { - private final UnboundedMqttSource source; + private final UnboundedMqttSource source; private MQTT client; private BlockingConnection connection; - private byte[] current; + private T current; private Instant currentTimestamp; private MqttCheckpointMark checkpointMark; + private SerializableFunction extractFn; - public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) { + public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) { this.source = source; this.current = null; if (checkpointMark != null) { @@ -461,12 +549,21 @@ public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkp } else { this.checkpointMark = new MqttCheckpointMark(); } + this.extractFn = message -> (T) message.getPayload(); + } + + public UnboundedMqttReader( + UnboundedMqttSource source, + MqttCheckpointMark checkpointMark, + SerializableFunction extractFn) { + this(source, checkpointMark); + this.extractFn = extractFn; } @Override public boolean start() throws IOException { LOG.debug("Starting MQTT reader ..."); - Read spec = source.spec; + Read spec = source.spec; try { client = spec.connectionConfiguration().createClient(); LOG.debug("Reader client ID is {}", client.getClientId()); @@ -488,7 +585,7 @@ public boolean advance() throws IOException { if (message == null) { return false; } - current = message.getPayload(); + current = 
this.extractFn.apply(message); currentTimestamp = Instant.now(); checkpointMark.add(message, currentTimestamp); } catch (Exception e) { @@ -520,7 +617,7 @@ public UnboundedSource.CheckpointMark getCheckpointMark() { } @Override - public byte[] getCurrent() { + public T getCurrent() { if (current == null) { throw new NoSuchElementException(); } @@ -536,7 +633,7 @@ public Instant getCurrentTimestamp() { } @Override - public UnboundedMqttSource getCurrentSource() { + public UnboundedMqttSource getCurrentSource() { return source; } } diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java new file mode 100644 index 000000000000..bbf27f5c73e7 --- /dev/null +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttRecord.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.mqtt; + +import com.google.auto.value.AutoValue; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; + +/** A container class for MQTT message metadata, including the topic name and payload. 
*/ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class MqttRecord { + public abstract String getTopic(); + + @SuppressWarnings("mutable") + public abstract byte[] getPayload(); + + static Builder builder() { + return new AutoValue_MqttRecord.Builder(); + } + + static MqttRecord of(String topic, byte[] payload) { + return builder().setTopic(topic).setPayload(payload).build(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setTopic(String topic); + + abstract Builder setPayload(byte[] payload); + + abstract MqttRecord build(); + } +} diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java index 8dfa7838d66a..3ee6ed577a07 100644 --- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java +++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java @@ -44,6 +44,7 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -68,6 +69,18 @@ @RunWith(JUnit4.class) public class MqttIOTest { + /** Functional interface used to verify the connection status of an MQTT client. */ + @FunctionalInterface + interface ConnectionCondition { + /** + * Evaluates whether the given {@link Connection} satisfies the condition. 
+ * + * @param connection the MQTT connection to check + * @return {@code true} if the condition is met, {@code false} otherwise + */ + boolean check(Connection connection); + } + private static final Logger LOG = LoggerFactory.getLogger(MqttIOTest.class); private BrokerService brokerService; @@ -93,7 +106,7 @@ public void startBroker() throws Exception { @Ignore("https://github.com/apache/beam/issues/18723 Test timeout failure.") public void testReadNoClientId() throws Exception { final String topicName = "READ_TOPIC_NO_CLIENT_ID"; - Read mqttReader = + Read mqttReader = MqttIO.read() .withConnectionConfiguration( MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port, topicName)) @@ -122,18 +135,7 @@ public void testReadNoClientId() throws Exception { new Thread( () -> { try { - LOG.info( - "Waiting pipeline connected to the MQTT broker before sending " - + "messages ..."); - boolean pipelineConnected = false; - while (!pipelineConnected) { - Thread.sleep(1000); - for (Connection connection : brokerService.getBroker().getClients()) { - if (!connection.getConnectionId().isEmpty()) { - pipelineConnected = true; - } - } - } + doConnect(connection -> !connection.getConnectionId().isEmpty()); for (int i = 0; i < 10; i++) { publishConnection.publish( topicName, @@ -184,18 +186,7 @@ public void testRead() throws Exception { new Thread( () -> { try { - LOG.info( - "Waiting pipeline connected to the MQTT broker before sending " - + "messages ..."); - boolean pipelineConnected = false; - while (!pipelineConnected) { - for (Connection connection : brokerService.getBroker().getClients()) { - if (connection.getConnectionId().startsWith("READ_PIPELINE")) { - pipelineConnected = true; - } - } - Thread.sleep(1000); - } + doConnect(connection -> connection.getConnectionId().startsWith("READ_PIPELINE")); for (int i = 0; i < 10; i++) { publishConnection.publish( "READ_TOPIC", @@ -214,6 +205,71 @@ public void testRead() throws Exception { publishConnection.disconnect(); } + 
@Test(timeout = 60 * 1000) + public void testReadWithMetadata() throws Exception { + final String wildcardTopic = "topic/#"; + final String topic1 = "topic/1"; + final String topic2 = "topic/2"; + + final PTransform> mqttReaderWithMetadata = + MqttIO.readWithMetadata() + .withConnectionConfiguration( + MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port, wildcardTopic)) + .withMaxNumRecords(10) + .withMaxReadTime(Duration.standardSeconds(5)); + + final PCollection output = pipeline.apply(mqttReaderWithMetadata); + PAssert.that(output) + .containsInAnyOrder( + MqttRecord.of(topic1, "This is test 0".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 1".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 2".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 3".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic1, "This is test 4".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 5".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 6".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 7".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 8".getBytes(StandardCharsets.UTF_8)), + MqttRecord.of(topic2, "This is test 9".getBytes(StandardCharsets.UTF_8))); + + // produce messages on the brokerService in another thread + // This thread prevents to block the pipeline waiting for new messages + MQTT client = new MQTT(); + client.setHost("tcp://localhost:" + port); + final BlockingConnection publishConnection = client.blockingConnection(); + publishConnection.connect(); + Thread publisherThread = + new Thread( + () -> { + try { + doConnect(connection -> !connection.getConnectionId().isEmpty()); + for (int i = 0; i < 5; i++) { + publishConnection.publish( + topic1, + ("This is test " + i).getBytes(StandardCharsets.UTF_8), + QoS.EXACTLY_ONCE, + false); + } + for (int i = 5; i < 10; i++) { + 
publishConnection.publish( + topic2, + ("This is test " + i).getBytes(StandardCharsets.UTF_8), + QoS.EXACTLY_ONCE, + false); + } + + } catch (Exception e) { + // nothing to do + } + }); + + publisherThread.start(); + pipeline.run(); + + publishConnection.disconnect(); + publisherThread.join(); + } + /** Test for BEAM-3282: this test should not timeout. */ @Test(timeout = 30 * 1000) public void testReceiveWithTimeoutAndNoData() throws Exception { @@ -505,6 +561,30 @@ public void testReadObject() throws Exception { assertEquals(cp1.oldestMessageTimestamp, cp2.oldestMessageTimestamp); } + /** + * Attempts to establish a connection to the MQTT broker by checking each available client + * connection until the specified condition is met. + * + *

    This method repeatedly checks the connection status of each MQTT client using the provided + * {@link ConnectionCondition}. It blocks execution within a loop, sleeping for 1 second between + * each check, until the condition is satisfied. + * + * @param condition the condition used to verify the connection status + * @throws Exception if any error occurs during the connection process + */ + private void doConnect(ConnectionCondition condition) throws Exception { + LOG.info("Waiting pipeline connected to the MQTT broker before sending messages ..."); + boolean pipelineConnected = false; + while (!pipelineConnected) { + for (Connection connection : brokerService.getBroker().getClients()) { + if (condition.check(connection)) { + pipelineConnected = true; + } + } + Thread.sleep(1000); + } + } + @After public void stopBroker() throws Exception { if (brokerService != null) { diff --git a/sdks/java/io/parquet/build.gradle b/sdks/java/io/parquet/build.gradle index e8f1603f0b58..d5f22b31cc56 100644 --- a/sdks/java/io/parquet/build.gradle +++ b/sdks/java/io/parquet/build.gradle @@ -27,10 +27,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Parquet" ext.summary = "IO to read and write on Parquet storage format." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/snowflake/build.gradle b/sdks/java/io/snowflake/build.gradle index 2bdb9a867a34..9be257033edb 100644 --- a/sdks/java/io/snowflake/build.gradle +++ b/sdks/java/io/snowflake/build.gradle @@ -30,7 +30,7 @@ dependencies { implementation project(path: ":sdks:java:extensions:google-cloud-platform-core") permitUnusedDeclared project(path: ":sdks:java:extensions:google-cloud-platform-core") implementation library.java.slf4j_api - implementation group: 'net.snowflake', name: 'snowflake-jdbc', version: '3.12.11' + implementation group: 'net.snowflake', name: 'snowflake-jdbc', version: '3.20.0' implementation group: 'com.opencsv', name: 'opencsv', version: '5.0' implementation 'net.snowflake:snowflake-ingest-sdk:0.9.9' implementation "org.bouncycastle:bcprov-jdk15on:1.70" diff --git a/sdks/java/io/solace/build.gradle b/sdks/java/io/solace/build.gradle index 741db51a5772..ef0d49891f08 100644 --- a/sdks/java/io/solace/build.gradle +++ b/sdks/java/io/solace/build.gradle @@ -53,6 +53,7 @@ dependencies { testImplementation library.java.junit testImplementation project(path: ":sdks:java:io:common") testImplementation project(path: ":sdks:java:testing:test-utils") + testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testRuntimeOnly library.java.slf4j_jdk14 testImplementation library.java.testcontainers_solace testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java index dcfdcc4fabb9..a55d8a0a4217 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java +++ 
b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.solace; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; @@ -38,16 +39,29 @@ import org.apache.beam.sdk.io.solace.broker.SessionService; import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; import org.apache.beam.sdk.io.solace.data.Solace.SolaceRecordMapper; import org.apache.beam.sdk.io.solace.read.UnboundedSolaceSource; +import org.apache.beam.sdk.io.solace.write.AddShardKeyDoFn; import org.apache.beam.sdk.io.solace.write.SolaceOutput; +import org.apache.beam.sdk.io.solace.write.UnboundedBatchedSolaceWriter; +import org.apache.beam.sdk.io.solace.write.UnboundedStreamingSolaceWriter; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import 
org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; @@ -147,7 +161,7 @@ * function. * *

    {@code
    - * @DefaultSchema(JavaBeanSchema.class)
    + * {@literal @}DefaultSchema(JavaBeanSchema.class)
      * public static class SimpleRecord {
      *    public String payload;
      *    public String messageId;
    @@ -238,7 +252,7 @@
      * default VPN name by setting the required JCSMP property in the session factory (in this case,
      * with {@link BasicAuthJcsmpSessionServiceFactory#vpnName()}), the number of clients per worker
      * with {@link Write#withNumberOfClientsPerWorker(int)} and the number of parallel write clients
    - * using {@link Write#withMaxNumOfUsedWorkers(int)}.
    + * using {@link Write#withNumShards(int)}.
      *
      * 

    Writing to dynamic destinations

    * @@ -345,13 +359,17 @@ * *

    The streaming connector publishes each message individually, without holding up or batching * before the message is sent to Solace. This will ensure the lowest possible latency, but it will - * offer a much lower throughput. The streaming connector does not use state & timers. + * offer a much lower throughput. The streaming connector does not use state and timers. * - *

    Both connectors uses state & timers to control the level of parallelism. If you are using + *

    Both connectors use state and timers to control the level of parallelism. If you are using * Cloud Dataflow, it is recommended that you enable Streaming Engine to use this * connector. * + *

    For full control over all the properties, use {@link SubmissionMode#CUSTOM}. The connector + * will not override any property that you set, and you will have full control over all the JCSMP + * properties. + * *

    Authentication

    * *

    When writing to Solace, the user must use {@link @@ -396,7 +414,7 @@ public class SolaceIO { private static final boolean DEFAULT_DEDUPLICATE_RECORDS = false; private static final Duration DEFAULT_WATERMARK_IDLE_DURATION_THRESHOLD = Duration.standardSeconds(30); - public static final int DEFAULT_WRITER_MAX_NUMBER_OF_WORKERS = 20; + public static final int DEFAULT_WRITER_NUM_SHARDS = 20; public static final int DEFAULT_WRITER_CLIENTS_PER_WORKER = 4; public static final Boolean DEFAULT_WRITER_PUBLISH_LATENCY_METRICS = false; public static final SubmissionMode DEFAULT_WRITER_SUBMISSION_MODE = @@ -445,6 +463,7 @@ public static Read read() { .setDeduplicateRecords(DEFAULT_DEDUPLICATE_RECORDS) .setWatermarkIdleDurationThreshold(DEFAULT_WATERMARK_IDLE_DURATION_THRESHOLD)); } + /** * Create a {@link Read} transform, to read from Solace. Specify a {@link SerializableFunction} to * map incoming {@link BytesXMLMessage} records, to the object of your choice. You also need to @@ -805,7 +824,9 @@ private Queue initializeQueueForTopicIfNeeded( public enum SubmissionMode { HIGHER_THROUGHPUT, - LOWER_LATENCY + LOWER_LATENCY, + CUSTOM, // Don't override any property set by the user + TESTING // Send acks 1 by 1, this will be very slow, never use this in an actual pipeline! } public enum WriterType { @@ -816,8 +837,9 @@ public enum WriterType { @AutoValue public abstract static class Write extends PTransform, SolaceOutput> { - public static final TupleTag FAILED_PUBLISH_TAG = - new TupleTag() {}; + private static final Logger LOG = LoggerFactory.getLogger(Write.class); + + public static final TupleTag FAILED_PUBLISH_TAG = new TupleTag() {}; public static final TupleTag SUCCESSFUL_PUBLISH_TAG = new TupleTag() {}; @@ -863,8 +885,8 @@ public Write to(Solace.Queue queue) { * cluster, and the need for performance when writing to Solace (more workers will achieve * higher throughput). 
*/ - public Write withMaxNumOfUsedWorkers(int maxNumOfUsedWorkers) { - return toBuilder().setMaxNumOfUsedWorkers(maxNumOfUsedWorkers).build(); + public Write withNumShards(int numShards) { + return toBuilder().setNumShards(numShards).build(); } /** @@ -877,8 +899,8 @@ public Write withMaxNumOfUsedWorkers(int maxNumOfUsedWorkers) { * the number of clients created per VM. The clients will be re-used across different threads in * the same worker. * - *

    Set this number in combination with {@link #withMaxNumOfUsedWorkers}, to ensure that the - * limit for number of clients in your Solace cluster is not exceeded. + *

    Set this number in combination with {@link #withNumShards}, to ensure that the limit for + * number of clients in your Solace cluster is not exceeded. * *

    Normally, using a higher number of clients with fewer workers will achieve better * throughput at a lower cost, since the workers are better utilized. A good rule of thumb to @@ -921,15 +943,19 @@ public Write publishLatencyMetrics() { *

    For full details, please check https://docs.solace.com/API/API-Developer-Guide/Java-API-Best-Practices.htm. * - *

    The Solace JCSMP client libraries can dispatch messages using two different modes: + *

    The Solace JCSMP client libraries can dispatch messages using three different modes: * *

    One of the modes dispatches messages directly from the same thread that is doing the rest * of I/O work. This mode favors lower latency but lower throughput. Set this to LOWER_LATENCY * to use that mode (MESSAGE_CALLBACK_ON_REACTOR set to True). * - *

    The other mode uses a parallel thread to accumulate and dispatch messages. This mode - * favors higher throughput but also has higher latency. Set this to HIGHER_THROUGHPUT to use - * that mode. This is the default mode (MESSAGE_CALLBACK_ON_REACTOR set to False). + *

    Another mode uses a parallel thread to accumulate and dispatch messages. This mode favors + * higher throughput but also has higher latency. Set this to HIGHER_THROUGHPUT to use that + * mode. This is the default mode (MESSAGE_CALLBACK_ON_REACTOR set to False). + * + *

    If you prefer to have full control over all the JCSMP properties, set this to CUSTOM, and + * override the classes {@link SessionServiceFactory} and {@link SessionService} to have full + * control on how to create the JCSMP sessions and producers used by the connector. * *

    This is optional, the default value is HIGHER_THROUGHPUT. */ @@ -945,10 +971,12 @@ public Write withSubmissionMode(SubmissionMode submissionMode) { *

    In streaming mode, the publishing latency will be lower, but the throughput will also be * lower. * - *

    With the batched mode, messages are accumulated until a batch size of 50 is reached, or 5 - * seconds have elapsed since the first message in the batch was received. The 50 messages are - * sent to Solace in a single batch. This writer offers higher throughput but higher publishing - * latency, as messages can be held up for up to 5 seconds until they are published. + *

    With the batched mode, messages are accumulated until a batch size of 50 is reached, or + * {@link UnboundedBatchedSolaceWriter#ACKS_FLUSHING_INTERVAL_SECS} seconds have elapsed since + * the first message in the batch was received. The 50 messages are sent to Solace in a single + * batch. This writer offers higher throughput but higher publishing latency, as messages can be + * held up for up to {@link UnboundedBatchedSolaceWriter#ACKS_FLUSHING_INTERVAL_SECS} seconds + * until they are published. * *

    Notice that this is the message publishing latency, not the end-to-end latency. For very * large scale pipelines, you will probably prefer to use the HIGHER_THROUGHPUT mode, as with @@ -971,7 +999,20 @@ public Write withSessionServiceFactory(SessionServiceFactory factory) { return toBuilder().setSessionServiceFactory(factory).build(); } - abstract int getMaxNumOfUsedWorkers(); + /** + * An optional error handler for handling records that failed to publish to Solace. + * + *

    If provided, this error handler will be invoked for each record that could not be + * successfully published. The error handler can implement custom logic for dealing with failed + * records, such as writing them to a dead-letter queue or logging them. + * + *

    If no error handler is provided, failed records will be ignored. + */ + public Write withErrorHandler(ErrorHandler errorHandler) { + return toBuilder().setErrorHandler(errorHandler).build(); + } + + abstract int getNumShards(); abstract int getNumberOfClientsPerWorker(); @@ -989,10 +1030,12 @@ public Write withSessionServiceFactory(SessionServiceFactory factory) { abstract @Nullable SessionServiceFactory getSessionServiceFactory(); + abstract @Nullable ErrorHandler getErrorHandler(); + static Builder builder() { return new AutoValue_SolaceIO_Write.Builder() .setDeliveryMode(DEFAULT_WRITER_DELIVERY_MODE) - .setMaxNumOfUsedWorkers(DEFAULT_WRITER_MAX_NUMBER_OF_WORKERS) + .setNumShards(DEFAULT_WRITER_NUM_SHARDS) .setNumberOfClientsPerWorker(DEFAULT_WRITER_CLIENTS_PER_WORKER) .setPublishLatencyMetrics(DEFAULT_WRITER_PUBLISH_LATENCY_METRICS) .setDispatchMode(DEFAULT_WRITER_SUBMISSION_MODE) @@ -1003,7 +1046,7 @@ static Builder builder() { @AutoValue.Builder abstract static class Builder { - abstract Builder setMaxNumOfUsedWorkers(int maxNumOfUsedWorkers); + abstract Builder setNumShards(int numShards); abstract Builder setNumberOfClientsPerWorker(int numberOfClientsPerWorker); @@ -1021,13 +1064,121 @@ abstract static class Builder { abstract Builder setSessionServiceFactory(SessionServiceFactory factory); + abstract Builder setErrorHandler(ErrorHandler errorHandler); + abstract Write build(); } @Override public SolaceOutput expand(PCollection input) { - // TODO: will be sent in upcoming PR - return SolaceOutput.in(input.getPipeline(), null, null); + boolean usingSolaceRecord = + TypeDescriptor.of(Solace.Record.class) + .isSupertypeOf(checkNotNull(input.getTypeDescriptor())); + + validateWriteTransform(usingSolaceRecord); + + boolean usingDynamicDestinations = getDestination() == null; + SerializableFunction destinationFn; + if (usingDynamicDestinations) { + destinationFn = x -> SolaceIO.convertToJcsmpDestination(checkNotNull(x.getDestination())); + } else { + // 
Constant destination for all messages (same topic or queue) + // This should not be non-null, as nulls would have been flagged by the + // validateWriteTransform method + destinationFn = x -> checkNotNull(getDestination()); + } + + @SuppressWarnings("unchecked") + PCollection records = + usingSolaceRecord + ? (PCollection) input + : input.apply( + "Format records", + MapElements.into(TypeDescriptor.of(Solace.Record.class)) + .via(checkNotNull(getFormatFunction()))); + + PCollection withGlobalWindow = + records.apply("Global window", Window.into(new GlobalWindows())); + + PCollection> withShardKeys = + withGlobalWindow.apply("Add shard key", ParDo.of(new AddShardKeyDoFn(getNumShards()))); + + String label = + getWriterType() == WriterType.STREAMING ? "Publish (streaming)" : "Publish (batched)"; + + PCollectionTuple solaceOutput = withShardKeys.apply(label, getWriterTransform(destinationFn)); + + SolaceOutput output; + if (getDeliveryMode() == DeliveryMode.PERSISTENT) { + if (getErrorHandler() != null) { + checkNotNull(getErrorHandler()).addErrorCollection(solaceOutput.get(FAILED_PUBLISH_TAG)); + } + output = SolaceOutput.in(input.getPipeline(), solaceOutput.get(SUCCESSFUL_PUBLISH_TAG)); + } else { + LOG.info( + "Solace.Write: omitting writer output because delivery mode is {}", getDeliveryMode()); + output = SolaceOutput.in(input.getPipeline(), null); + } + + return output; + } + + private ParDo.MultiOutput, Solace.PublishResult> getWriterTransform( + SerializableFunction destinationFn) { + + ParDo.SingleOutput, Solace.PublishResult> writer = + ParDo.of( + getWriterType() == WriterType.STREAMING + ? 
new UnboundedStreamingSolaceWriter( + destinationFn, + checkNotNull(getSessionServiceFactory()), + getDeliveryMode(), + getDispatchMode(), + getNumberOfClientsPerWorker(), + getPublishLatencyMetrics()) + : new UnboundedBatchedSolaceWriter( + destinationFn, + checkNotNull(getSessionServiceFactory()), + getDeliveryMode(), + getDispatchMode(), + getNumberOfClientsPerWorker(), + getPublishLatencyMetrics())); + + return writer.withOutputTags(SUCCESSFUL_PUBLISH_TAG, TupleTagList.of(FAILED_PUBLISH_TAG)); + } + + /** + * Called before running the Pipeline to verify this transform is fully and correctly specified. + */ + private void validateWriteTransform(boolean usingSolaceRecords) { + if (!usingSolaceRecords) { + checkNotNull( + getFormatFunction(), + "SolaceIO.Write: If you are not using Solace.Record as the input type, you" + + " must set a format function using withFormatFunction()."); + } + + checkArgument( + getNumShards() > 0, "SolaceIO.Write: The number of used workers must be positive."); + checkArgument( + getNumberOfClientsPerWorker() > 0, + "SolaceIO.Write: The number of clients per worker must be positive."); + checkArgument( + getDeliveryMode() == DeliveryMode.DIRECT || getDeliveryMode() == DeliveryMode.PERSISTENT, + String.format( + "SolaceIO.Write: Delivery mode must be either DIRECT or PERSISTENT. %s" + + " not supported", + getDeliveryMode())); + if (getPublishLatencyMetrics()) { + checkArgument( + getDeliveryMode() == DeliveryMode.PERSISTENT, + "SolaceIO.Write: Publish latency metrics can only be enabled for PERSISTENT" + + " delivery mode."); + } + checkNotNull( + getSessionServiceFactory(), + "SolaceIO: You need to pass a session service factory. 
For basic" + + " authentication, you can use BasicAuthJcsmpSessionServiceFactory."); } } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java index 2137d574b09a..b2196dbf1067 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java @@ -19,6 +19,7 @@ import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import com.google.auto.value.AutoValue; import com.solacesystems.jcsmp.ConsumerFlowProperties; import com.solacesystems.jcsmp.EndpointProperties; import com.solacesystems.jcsmp.FlowReceiver; @@ -28,9 +29,15 @@ import com.solacesystems.jcsmp.JCSMPProperties; import com.solacesystems.jcsmp.JCSMPSession; import com.solacesystems.jcsmp.Queue; +import com.solacesystems.jcsmp.XMLMessageProducer; import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentLinkedQueue; import javax.annotation.Nullable; import org.apache.beam.sdk.io.solace.RetryCallableManager; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; /** @@ -39,34 +46,50 @@ *

    This class provides a way to connect to a Solace broker and receive messages from a queue. The * connection is established using basic authentication. */ -public class BasicAuthJcsmpSessionService extends SessionService { - private final String queueName; - private final String host; - private final String username; - private final String password; - private final String vpnName; - @Nullable private JCSMPSession jcsmpSession; - @Nullable private MessageReceiver messageReceiver; - private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); +@AutoValue +public abstract class BasicAuthJcsmpSessionService extends SessionService { + + /** The name of the queue to receive messages from. */ + public abstract @Nullable String queueName(); + + /** The host name or IP address of the Solace broker. Format: Host[:Port] */ + public abstract String host(); + + /** The username to use for authentication. */ + public abstract String username(); + + /** The password to use for authentication. */ + public abstract String password(); + + /** The name of the VPN to connect to. */ + public abstract String vpnName(); + + public static Builder builder() { + return new AutoValue_BasicAuthJcsmpSessionService.Builder().vpnName(DEFAULT_VPN_NAME); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder queueName(@Nullable String queueName); + + public abstract Builder host(String host); - /** - * Creates a new {@link BasicAuthJcsmpSessionService} with the given parameters. - * - * @param queueName The name of the queue to receive messages from. - * @param host The host name or IP address of the Solace broker. Format: Host[:Port] - * @param username The username to use for authentication. - * @param password The password to use for authentication. - * @param vpnName The name of the VPN to connect to. 
- */ - public BasicAuthJcsmpSessionService( - String queueName, String host, String username, String password, String vpnName) { - this.queueName = queueName; - this.host = host; - this.username = username; - this.password = password; - this.vpnName = vpnName; + public abstract Builder username(String username); + + public abstract Builder password(String password); + + public abstract Builder vpnName(String vpnName); + + public abstract BasicAuthJcsmpSessionService build(); } + @Nullable private transient JCSMPSession jcsmpSession; + @Nullable private transient MessageReceiver messageReceiver; + @Nullable private transient MessageProducer messageProducer; + private final java.util.Queue publishedResultsQueue = + new ConcurrentLinkedQueue<>(); + private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); + @Override public void connect() { retryCallableManager.retryCallable(this::connectSession, ImmutableSet.of(JCSMPException.class)); @@ -79,6 +102,9 @@ public void close() { if (messageReceiver != null) { messageReceiver.close(); } + if (messageProducer != null) { + messageProducer.close(); + } if (!isClosed()) { checkStateNotNull(jcsmpSession).closeSession(); } @@ -88,24 +114,64 @@ public void close() { } @Override - public MessageReceiver createReceiver() { - this.messageReceiver = - retryCallableManager.retryCallable( - this::createFlowReceiver, ImmutableSet.of(JCSMPException.class)); + public MessageReceiver getReceiver() { + if (this.messageReceiver == null) { + this.messageReceiver = + retryCallableManager.retryCallable( + this::createFlowReceiver, ImmutableSet.of(JCSMPException.class)); + } return this.messageReceiver; } + @Override + public MessageProducer getInitializedProducer(SubmissionMode submissionMode) { + if (this.messageProducer == null || this.messageProducer.isClosed()) { + Callable create = () -> createXMLMessageProducer(submissionMode); + this.messageProducer = + retryCallableManager.retryCallable(create, 
ImmutableSet.of(JCSMPException.class)); + } + return checkStateNotNull(this.messageProducer); + } + + @Override + public java.util.Queue getPublishedResultsQueue() { + return publishedResultsQueue; + } + @Override public boolean isClosed() { return jcsmpSession == null || jcsmpSession.isClosed(); } + private MessageProducer createXMLMessageProducer(SubmissionMode submissionMode) + throws JCSMPException, IOException { + + if (isClosed()) { + connectWriteSession(submissionMode); + } + + @SuppressWarnings("nullness") + Callable initProducer = + () -> + Objects.requireNonNull(jcsmpSession) + .getMessageProducer(new PublishResultHandler(publishedResultsQueue)); + + XMLMessageProducer producer = + retryCallableManager.retryCallable(initProducer, ImmutableSet.of(JCSMPException.class)); + if (producer == null) { + throw new IOException("SolaceIO.Write: Could not create producer, producer object is null"); + } + return new SolaceMessageProducer(producer); + } + private MessageReceiver createFlowReceiver() throws JCSMPException, IOException { if (isClosed()) { connectSession(); } - Queue queue = JCSMPFactory.onlyInstance().createQueue(queueName); + Queue queue = + JCSMPFactory.onlyInstance() + .createQueue(checkStateNotNull(queueName(), "SolaceIO.Read: Queue is not set.")); ConsumerFlowProperties flowProperties = new ConsumerFlowProperties(); flowProperties.setEndpoint(queue); @@ -118,7 +184,8 @@ private MessageReceiver createFlowReceiver() throws JCSMPException, IOException createFlowReceiver(jcsmpSession, flowProperties, endpointProperties)); } throw new IOException( - "SolaceIO.Read: Could not create a receiver from the Jcsmp session: session object is null."); + "SolaceIO.Read: Could not create a receiver from the Jcsmp session: session object is" + + " null."); } // The `@SuppressWarning` is needed here, because the checkerframework reports an error for the @@ -141,20 +208,33 @@ private int connectSession() throws JCSMPException { return 0; } + private int 
connectWriteSession(SubmissionMode mode) throws JCSMPException { + if (jcsmpSession == null) { + jcsmpSession = createWriteSessionObject(mode); + } + jcsmpSession.connect(); + return 0; + } + private JCSMPSession createSessionObject() throws InvalidPropertiesException { JCSMPProperties properties = initializeSessionProperties(new JCSMPProperties()); return JCSMPFactory.onlyInstance().createSession(properties); } + private JCSMPSession createWriteSessionObject(SubmissionMode mode) + throws InvalidPropertiesException { + return JCSMPFactory.onlyInstance().createSession(initializeWriteSessionProperties(mode)); + } + @Override public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProps) { - baseProps.setProperty(JCSMPProperties.VPN_NAME, vpnName); + baseProps.setProperty(JCSMPProperties.VPN_NAME, vpnName()); baseProps.setProperty( JCSMPProperties.AUTHENTICATION_SCHEME, JCSMPProperties.AUTHENTICATION_SCHEME_BASIC); - baseProps.setProperty(JCSMPProperties.USERNAME, username); - baseProps.setProperty(JCSMPProperties.PASSWORD, password); - baseProps.setProperty(JCSMPProperties.HOST, host); + baseProps.setProperty(JCSMPProperties.USERNAME, username()); + baseProps.setProperty(JCSMPProperties.PASSWORD, password()); + baseProps.setProperty(JCSMPProperties.HOST, host()); return baseProps; } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java index 2084e61b7e38..199dcccee854 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io.solace.broker; import static org.apache.beam.sdk.io.solace.broker.SessionService.DEFAULT_VPN_NAME; -import static 
org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import com.google.auto.value.AutoValue; @@ -31,12 +30,16 @@ */ @AutoValue public abstract class BasicAuthJcsmpSessionServiceFactory extends SessionServiceFactory { + /** The host name or IP address of the Solace broker. Format: Host[:Port] */ public abstract String host(); + /** The username to use for authentication. */ public abstract String username(); + /** The password to use for authentication. */ public abstract String password(); + /** The name of the VPN to connect to. */ public abstract String vpnName(); public static Builder builder() { @@ -54,6 +57,7 @@ public abstract static class Builder { /** Set Solace username. */ public abstract Builder username(String username); + /** Set Solace password. */ public abstract Builder password(String password); @@ -65,11 +69,15 @@ public abstract static class Builder { @Override public SessionService create() { - return new BasicAuthJcsmpSessionService( - checkStateNotNull(queue, "SolaceIO.Read: Queue is not set.").getName(), - host(), - username(), - password(), - vpnName()); + BasicAuthJcsmpSessionService.Builder builder = BasicAuthJcsmpSessionService.builder(); + if (queue != null) { + builder = builder.queueName(queue.getName()); + } + return builder + .host(host()) + .username(username()) + .password(password()) + .vpnName(vpnName()) + .build(); } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java index 4884bb61e628..0a9ee4618b1e 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java @@ -17,14 +17,10 @@ */ package org.apache.beam.sdk.io.solace.broker; -import com.fasterxml.jackson.core.JsonProcessingException; -import 
com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.api.client.http.HttpRequestFactory; import com.solacesystems.jcsmp.JCSMPFactory; import java.io.IOException; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.io.solace.data.Semp.Queue; import org.apache.beam.sdk.util.SerializableSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,8 +36,6 @@ @Internal public class BasicAuthSempClient implements SempClient { private static final Logger LOG = LoggerFactory.getLogger(BasicAuthSempClient.class); - private final ObjectMapper objectMapper = - new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); private final SempBasicAuthClientExecutor sempBasicAuthClientExecutor; @@ -58,13 +52,12 @@ public BasicAuthSempClient( @Override public boolean isQueueNonExclusive(String queueName) throws IOException { - LOG.info("SolaceIO.Read: SempOperations: query SEMP if queue {} is nonExclusive", queueName); - BrokerResponse response = sempBasicAuthClientExecutor.getQueueResponse(queueName); - if (response.content == null) { - throw new IOException("SolaceIO: response from SEMP is empty!"); - } - Queue q = mapJsonToClass(response.content, Queue.class); - return q.data().accessType().equals("non-exclusive"); + boolean queueNonExclusive = sempBasicAuthClientExecutor.isQueueNonExclusive(queueName); + LOG.info( + "SolaceIO.Read: SempOperations: queried SEMP if queue {} is non-exclusive: {}", + queueName, + queueNonExclusive); + return queueNonExclusive; } @Override @@ -77,12 +70,7 @@ public com.solacesystems.jcsmp.Queue createQueueForTopic(String queueName, Strin @Override public long getBacklogBytes(String queueName) throws IOException { - BrokerResponse response = sempBasicAuthClientExecutor.getQueueResponse(queueName); - if (response.content == null) { - throw new IOException("SolaceIO: response from SEMP is empty!"); - } - Queue q = 
mapJsonToClass(response.content, Queue.class); - return q.data().msgSpoolUsage(); + return sempBasicAuthClientExecutor.getBacklogBytes(queueName); } private void createQueue(String queueName) throws IOException { @@ -94,9 +82,4 @@ private void createSubscription(String queueName, String topicName) throws IOExc LOG.info("SolaceIO.Read: Creating new subscription {} for topic {}.", queueName, topicName); sempBasicAuthClientExecutor.createSubscriptionResponse(queueName, topicName); } - - private T mapJsonToClass(String content, Class mapSuccessToClass) - throws JsonProcessingException { - return objectMapper.readValue(content, mapSuccessToClass); - } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java index dd87e1d75fa5..7f691b46be31 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java @@ -117,7 +117,7 @@ public abstract static class Builder { @Override public SessionService create() { - String password = null; + String password; try { password = retrieveSecret(); } catch (IOException e) { diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducer.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducer.java new file mode 100644 index 000000000000..8aa254b92cb1 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducer.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.SerializableFunction; + +/** + * Base class for publishing messages to a Solace broker. + * + *

    Implementations of this interface are responsible for managing the connection to the broker + * and for publishing messages to the broker. + */ +@Internal +public interface MessageProducer { + + /** Publishes a message to the broker. */ + void publishSingleMessage( + Solace.Record msg, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + DeliveryMode deliveryMode); + + /** + * Publishes a batch of messages to the broker. + * + *

    The size of the batch cannot exceed 50 messages; this is a limitation of the Solace API. + * + *

    It returns the number of messages written. + */ + int publishBatch( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode); + + /** Returns {@literal true} if the message producer is closed, {@literal false} otherwise. */ + boolean isClosed(); + + /** Closes the message producer. */ + void close(); +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducerUtils.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducerUtils.java new file mode 100644 index 000000000000..dd4610910ff4 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducerUtils.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.broker; + +import com.solacesystems.jcsmp.BytesXMLMessage; +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPFactory; +import com.solacesystems.jcsmp.JCSMPSendMultipleEntry; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.SerializableFunction; + +@Internal +public class MessageProducerUtils { + // This is the batch limit supported by the send multiple JCSMP API method. + static final int SOLACE_BATCH_LIMIT = 50; + + /** + * Create a {@link BytesXMLMessage} to be published in Solace. + * + * @param record The record to be published. + * @param useCorrelationKeyLatency Whether to use a complex key for tracking latency. + * @param deliveryMode The {@link DeliveryMode} used to publish the message. + * @return A {@link BytesXMLMessage} that can be sent to Solace "as is". 
+ */ + public static BytesXMLMessage createBytesXMLMessage( + Solace.Record record, boolean useCorrelationKeyLatency, DeliveryMode deliveryMode) { + JCSMPFactory jcsmpFactory = JCSMPFactory.onlyInstance(); + BytesXMLMessage msg = jcsmpFactory.createBytesXMLMessage(); + byte[] payload = record.getPayload(); + msg.writeBytes(payload); + + Long senderTimestamp = record.getSenderTimestamp(); + if (senderTimestamp == null) { + senderTimestamp = System.currentTimeMillis(); + } + msg.setSenderTimestamp(senderTimestamp); + msg.setDeliveryMode(deliveryMode); + if (useCorrelationKeyLatency) { + Solace.CorrelationKey key = + Solace.CorrelationKey.builder() + .setMessageId(record.getMessageId()) + .setPublishMonotonicNanos(System.nanoTime()) + .build(); + msg.setCorrelationKey(key); + } else { + // Use only a string as correlation key + msg.setCorrelationKey(record.getMessageId()); + } + msg.setApplicationMessageId(record.getMessageId()); + return msg; + } + + /** + * Create a {@link JCSMPSendMultipleEntry} array to be published in Solace. This can be used with + * `sendMultiple` to send all the messages in a single API call. + * + *

    The size of the list cannot be larger than 50 messages. This is a hard limit enforced by the + * Solace API. + * + * @param records A {@link List} of records to be published + * @param useCorrelationKeyLatency Whether to use a complex key for tracking latency. + * @param destinationFn A function that maps every record to its destination. + * @param deliveryMode The {@link DeliveryMode} used to publish the message. + * @return A {@link JCSMPSendMultipleEntry} array that can be sent to Solace "as is". + */ + public static JCSMPSendMultipleEntry[] createJCSMPSendMultipleEntry( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode) { + if (records.size() > SOLACE_BATCH_LIMIT) { + throw new RuntimeException( + String.format( + "SolaceIO.Write: Trying to create a batch of %d, but Solace supports a" + + " maximum of %d. The batch will likely be rejected by Solace.", + records.size(), SOLACE_BATCH_LIMIT)); + } + + JCSMPSendMultipleEntry[] entries = new JCSMPSendMultipleEntry[records.size()]; + for (int i = 0; i < records.size(); i++) { + Solace.Record record = records.get(i); + JCSMPSendMultipleEntry entry = + JCSMPFactory.onlyInstance() + .createSendMultipleEntry( + createBytesXMLMessage(record, useCorrelationKeyLatency, deliveryMode), + destinationFn.apply(record)); + entries[i] = entry; + } + + return entries; + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/PublishResultHandler.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/PublishResultHandler.java new file mode 100644 index 000000000000..1153bfcb7a1c --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/PublishResultHandler.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import com.solacesystems.jcsmp.JCSMPException; +import com.solacesystems.jcsmp.JCSMPStreamingPublishCorrelatingEventHandler; +import java.util.Queue; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; +import org.apache.beam.sdk.io.solace.write.UnboundedSolaceWriter; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class is required to handle callbacks from Solace, to find out if messages were actually + * published or there were any kind of error. + * + *

    This class is also used to calculate the latency of the publication. The correlation key + * contains the original timestamp of when the message was sent from the pipeline to Solace. The + * comparison of that value with the clock now, using a monotonic clock, is understood as the + * latency of the publication + */ +public final class PublishResultHandler implements JCSMPStreamingPublishCorrelatingEventHandler { + + private static final Logger LOG = LoggerFactory.getLogger(PublishResultHandler.class); + private final Queue publishResultsQueue; + private final Counter batchesRejectedByBroker = + Metrics.counter(UnboundedSolaceWriter.class, "batches_rejected"); + + public PublishResultHandler(Queue publishResultsQueue) { + this.publishResultsQueue = publishResultsQueue; + } + + @Override + public void handleErrorEx(Object key, JCSMPException cause, long timestamp) { + processKey(key, false, cause); + } + + @Override + public void responseReceivedEx(Object key) { + processKey(key, true, null); + } + + private void processKey(Object key, boolean isPublished, @Nullable JCSMPException cause) { + PublishResult.Builder resultBuilder = PublishResult.builder(); + String messageId; + if (key == null) { + messageId = ""; + } else if (key instanceof Solace.CorrelationKey) { + messageId = ((Solace.CorrelationKey) key).getMessageId(); + long latencyNanos = calculateLatency((Solace.CorrelationKey) key); + resultBuilder = resultBuilder.setLatencyNanos(latencyNanos); + } else { + messageId = key.toString(); + } + + resultBuilder = resultBuilder.setMessageId(messageId).setPublished(isPublished); + if (!isPublished) { + batchesRejectedByBroker.inc(); + if (cause != null) { + resultBuilder = resultBuilder.setError(cause.getMessage()); + } else { + resultBuilder = resultBuilder.setError("NULL - Not set by Solace"); + } + } else if (cause != null) { + LOG.warn( + "Message with id {} is published but exception is populated. 
Ignoring exception", + messageId); + } + + PublishResult publishResult = resultBuilder.build(); + // The queue is the one passed to the constructor; every callback handled + // by this instance appends its result to it + publishResultsQueue.add(publishResult); + } + + private static long calculateLatency(Solace.CorrelationKey key) { + long currentMillis = System.nanoTime(); + long publishMillis = key.getPublishMonotonicNanos(); + return currentMillis - publishMillis; + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java index 99a81f716435..965fc8741374 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java @@ -19,6 +19,9 @@ import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.api.client.http.GenericUrl; import com.google.api.client.http.HttpContent; import com.google.api.client.http.HttpHeaders; @@ -40,6 +43,7 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import org.apache.beam.sdk.io.solace.data.Semp.Queue; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.checkerframework.checker.nullness.qual.Nullable; @@ -52,7 +56,7 @@ * response is 401 Unauthorized, the client will execute an additional request with Basic Auth * header to refresh the token. 
*/ -class SempBasicAuthClientExecutor implements Serializable { +public class SempBasicAuthClientExecutor implements Serializable { // Every request will be repeated 2 times in case of abnormal connection failures. private static final int REQUEST_NUM_RETRIES = 2; private static final Map COOKIE_MANAGER_MAP = @@ -65,8 +69,10 @@ class SempBasicAuthClientExecutor implements Serializable { private final String password; private final CookieManagerKey cookieManagerKey; private final transient HttpRequestFactory requestFactory; + private final ObjectMapper objectMapper = + new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - SempBasicAuthClientExecutor( + public SempBasicAuthClientExecutor( String host, String username, String password, @@ -78,7 +84,16 @@ class SempBasicAuthClientExecutor implements Serializable { this.password = password; this.requestFactory = httpRequestFactory; this.cookieManagerKey = new CookieManagerKey(this.baseUrl, this.username); - COOKIE_MANAGER_MAP.putIfAbsent(this.cookieManagerKey, new CookieManager()); + COOKIE_MANAGER_MAP.computeIfAbsent(this.cookieManagerKey, key -> new CookieManager()); + } + + public boolean isQueueNonExclusive(String queueName) throws IOException { + BrokerResponse response = getQueueResponse(queueName); + if (response.content == null) { + throw new IOException("SolaceIO: response from SEMP is empty!"); + } + Queue q = mapJsonToClass(response.content, Queue.class); + return q.data().accessType().equals("non-exclusive"); } private static String getQueueEndpoint(String messageVpn, String queueName) @@ -199,6 +214,20 @@ private static String urlEncode(String queueName) throws UnsupportedEncodingExce return URLEncoder.encode(queueName, StandardCharsets.UTF_8.name()); } + private T mapJsonToClass(String content, Class mapSuccessToClass) + throws JsonProcessingException { + return objectMapper.readValue(content, mapSuccessToClass); + } + + public long getBacklogBytes(String queueName) 
throws IOException { + BrokerResponse response = getQueueResponse(queueName); + if (response.content == null) { + throw new IOException("SolaceIO: response from SEMP is empty!"); + } + Queue q = mapJsonToClass(response.content, Queue.class); + return q.data().msgSpoolUsage(); + } + private static class CookieManagerKey implements Serializable { private final String baseUrl; private final String username; diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java index aed700a71ded..84a876a9d0bc 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java @@ -19,7 +19,11 @@ import com.solacesystems.jcsmp.JCSMPProperties; import java.io.Serializable; +import java.util.Queue; import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; +import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,21 +73,23 @@ *

    For basic authentication, use {@link BasicAuthJcsmpSessionService} and {@link * BasicAuthJcsmpSessionServiceFactory}. * - *

    For other situations, you need to extend this class. For instance: + *

    For other situations, you need to extend this class and implement the `equals` method, so two + * instances of your class can be compared by value. We recommend using AutoValue for that. For + * instance: * *

    {@code
    + * {@literal }@AutoValue
      * public class MySessionService extends SessionService {
    - *   private final String authToken;
    + *   abstract String authToken();
      *
    - *   public MySessionService(String token) {
    - *    this.oauthToken = token;
    - *    ...
    + *   public static MySessionService create(String authToken) {
    + *       return new AutoValue_MySessionService(authToken);
      *   }
      *
      *   {@literal }@Override
      *   public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProps) {
      *     baseProps.setProperty(JCSMPProperties.AUTHENTICATION_SCHEME, JCSMPProperties.AUTHENTICATION_SCHEME_OAUTH2);
    - *     baseProps.setProperty(JCSMPProperties.OAUTH2_ACCESS_TOKEN, authToken);
    + *     baseProps.setProperty(JCSMPProperties.OAUTH2_ACCESS_TOKEN, authToken());
      *     return props;
      *   }
      *
    @@ -101,6 +107,7 @@ public abstract class SessionService implements Serializable {
     
       public static final String DEFAULT_VPN_NAME = "default";
     
    +  private static final int TESTING_PUB_ACK_WINDOW = 1;
       private static final int STREAMING_PUB_ACK_WINDOW = 50;
       private static final int BATCHED_PUB_ACK_WINDOW = 255;
     
    @@ -121,10 +128,25 @@ public abstract class SessionService implements Serializable {
       public abstract boolean isClosed();
     
       /**
    -   * Creates a MessageReceiver object for receiving messages from Solace. Typically, this object is
    -   * created from the session instance.
    +   * Returns a MessageReceiver object for receiving messages from Solace. If it is the first time
    +   * this method is used, the receiver is created from the session instance, otherwise it returns
    +   * the receiver created initially.
        */
    -  public abstract MessageReceiver createReceiver();
    +  public abstract MessageReceiver getReceiver();
    +
    +  /**
    +   * Returns a MessageProducer object for publishing messages to Solace. If it is the first time
    +   * this method is used, the producer is created from the session instance, otherwise it returns
    +   * the producer created initially.
    +   */
    +  public abstract MessageProducer getInitializedProducer(SubmissionMode mode);
    +
    +  /**
    +   * Returns the {@link Queue} instance associated with this session, with the
    +   * asynchronously received callbacks from Solace for message publications. The queue
    +   * implementation has to be thread-safe for production use-cases.
    +   */
    +  public abstract Queue getPublishedResultsQueue();
     
       /**
        * Override this method and provide your specific properties, including all those related to
    @@ -147,6 +169,20 @@ public abstract class SessionService implements Serializable {
        */
       public abstract JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties);
     
    +  /**
    +   * You need to override this method to be able to compare these objects by value. We recommend
    +   * using AutoValue for that.
    +   */
    +  @Override
    +  public abstract boolean equals(@Nullable Object other);
    +
    +  /**
    +   * You need to override this method to be able to compare these objects by value. We recommend
    +   * using AutoValue for that.
    +   */
    +  @Override
    +  public abstract int hashCode();
    +
       /**
        * This method will be called by the write connector when a new session is started.
        *
    @@ -186,50 +222,80 @@ private static JCSMPProperties overrideConnectorProperties(
         // received from Solace. A value of 1 will have the lowest latency, but a very low
         // throughput and a monumental backpressure.
     
    -    // This controls how the messages are sent to Solace
    -    if (mode == SolaceIO.SubmissionMode.HIGHER_THROUGHPUT) {
    -      // Create a parallel thread and a queue to send the messages
    +    // Retrieve current values of the properties
    +    Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR);
    +    Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE);
     
    -      Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR);
    -      if (msgCbProp != null && msgCbProp) {
    -        LOG.warn(
    -            "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to false since"
    -                + " HIGHER_THROUGHPUT mode was selected");
    -      }
    +    switch (mode) {
    +      case HIGHER_THROUGHPUT:
    +        // Check if it was set by user, show override warning
    +        if (msgCbProp != null && msgCbProp) {
    +          LOG.warn(
    +              "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to false since"
    +                  + " HIGHER_THROUGHPUT mode was selected");
    +        }
    +        if ((ackWindowSize != null && ackWindowSize != BATCHED_PUB_ACK_WINDOW)) {
    +          LOG.warn(
    +              String.format(
    +                  "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    +                      + " HIGHER_THROUGHPUT mode was selected",
    +                  BATCHED_PUB_ACK_WINDOW));
    +        }
     
    -      props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, false);
    +        // Override the properties
    +        // Use a dedicated thread for callbacks, increase the ack window size
    +        props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, false);
    +        props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, BATCHED_PUB_ACK_WINDOW);
    +        LOG.info(
    +            "SolaceIO.Write: Using HIGHER_THROUGHPUT mode, MESSAGE_CALLBACK_ON_REACTOR is FALSE,"
    +                + " PUB_ACK_WINDOW_SIZE is {}",
    +            BATCHED_PUB_ACK_WINDOW);
    +        break;
    +      case LOWER_LATENCY:
    +        // Check if it was set by user, show override warning
    +        if (msgCbProp != null && !msgCbProp) {
    +          LOG.warn(
    +              "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to true since"
    +                  + " LOWER_LATENCY mode was selected");
    +        }
     
    -      Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE);
    -      if ((ackWindowSize != null && ackWindowSize != BATCHED_PUB_ACK_WINDOW)) {
    -        LOG.warn(
    -            String.format(
    -                "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    -                    + " HIGHER_THROUGHPUT mode was selected",
    -                BATCHED_PUB_ACK_WINDOW));
    -      }
    -      props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, BATCHED_PUB_ACK_WINDOW);
    -    } else {
    -      // Send from the same thread where the produced is being called. This offers the lowest
    -      // latency, but a low throughput too.
    -      Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR);
    -      if (msgCbProp != null && !msgCbProp) {
    -        LOG.warn(
    -            "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to true since"
    -                + " LOWER_LATENCY mode was selected");
    -      }
    +        if ((ackWindowSize != null && ackWindowSize != STREAMING_PUB_ACK_WINDOW)) {
    +          LOG.warn(
    +              String.format(
    +                  "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    +                      + " LOWER_LATENCY mode was selected",
    +                  STREAMING_PUB_ACK_WINDOW));
    +        }
     
    -      props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, true);
    +        // Override the properties
    +        // Send from the same thread where the producer is being called. This offers the lowest
    +        // latency, but a low throughput too.
    +        props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, true);
    +        props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, STREAMING_PUB_ACK_WINDOW);
    +        LOG.info(
    +            "SolaceIO.Write: Using LOWER_LATENCY mode, MESSAGE_CALLBACK_ON_REACTOR is TRUE,"
    +                + " PUB_ACK_WINDOW_SIZE is {}",
    +            STREAMING_PUB_ACK_WINDOW);
     
    -      Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE);
    -      if ((ackWindowSize != null && ackWindowSize != STREAMING_PUB_ACK_WINDOW)) {
    +        break;
    +      case CUSTOM:
    +        LOG.info(
    +            " SolaceIO.Write: Using the custom JCSMP properties set by the user. No property has"
    +                + " been overridden by the connector.");
    +        break;
    +      case TESTING:
             LOG.warn(
    -            String.format(
    -                "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    -                    + " LOWER_LATENCY mode was selected",
    -                STREAMING_PUB_ACK_WINDOW));
    -      }
    -
    -      props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, STREAMING_PUB_ACK_WINDOW);
    +            "SolaceIO.Write: Overriding JCSMP properties for testing. **IF THIS IS AN"
    +                + " ACTUAL PIPELINE, CHANGE THE SUBMISSION MODE TO HIGHER_THROUGHPUT "
    +                + "OR LOWER_LATENCY.**");
    +        // Minimize multi-threading for testing
    +        props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, true);
    +        props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, TESTING_PUB_ACK_WINDOW);
    +        break;
    +      default:
    +        LOG.error(
    +            "SolaceIO.Write: no submission mode is selected. Set the submission mode to"
    +                + " HIGHER_THROUGHPUT or LOWER_LATENCY;");
         }
         return props;
       }
    diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java
    index 027de2cff134..bd1f3c23694d 100644
    --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java
    +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java
    @@ -19,11 +19,40 @@
     
     import com.solacesystems.jcsmp.Queue;
     import java.io.Serializable;
    +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode;
     import org.checkerframework.checker.nullness.qual.Nullable;
     
     /**
    - * This abstract class serves as a blueprint for creating `SessionService` objects. It introduces a
    - * queue property and mandates the implementation of a create() method in concrete subclasses.
    + * This abstract class serves as a blueprint for creating `SessionServiceFactory` objects. It
    + * introduces a queue property and mandates the implementation of a create() method in concrete
    + * subclasses.
    + *
    + * 

    For basic authentication, use {@link BasicAuthJcsmpSessionServiceFactory}. + * + *

    For other situations, you need to extend this class. Classes extending from this abstract + * class must implement the `equals` method so two instances can be compared by value, and not by + * reference. We recommend using AutoValue for that. + * + *

    {@code
    + * {@literal @}AutoValue
    + * public abstract class MyFactory extends SessionServiceFactory {
    + *
    + *   abstract String value1();
    + *
    + *   abstract String value2();
    + *
    + *   public static MyFactory create(String value1, String value2) {
    + *     return new AutoValue_MyFactory(value1, value2);
    + *   }
    + *
    + *   ...
    + *
    + *   {@literal @}Override
    + *   public SessionService create() {
    + *     ...
    + *   }
    + * }
    + * }
    */ public abstract class SessionServiceFactory implements Serializable { /** @@ -34,12 +63,32 @@ public abstract class SessionServiceFactory implements Serializable { */ @Nullable Queue queue; + /** + * The write submission mode. This is set when the writers are created. This property is used only + * by the write connector. + */ + @Nullable SubmissionMode submissionMode; + /** * This is the core method that subclasses must implement. It defines how to construct and return * a SessionService object. */ public abstract SessionService create(); + /** + * You need to override this method to be able to compare these objects by value. We recommend + * using AutoValue for that. + */ + @Override + public abstract boolean equals(@Nullable Object other); + + /** + * You need to override this method to be able to compare these objects by value. We recommend + * using AutoValue for that. + */ + @Override + public abstract int hashCode(); + /** * This method is called in the {@link * org.apache.beam.sdk.io.solace.SolaceIO.Read#expand(org.apache.beam.sdk.values.PBegin)} method @@ -48,4 +97,15 @@ public abstract class SessionServiceFactory implements Serializable { public void setQueue(Queue queue) { this.queue = queue; } + + /** + * Called by the write connector to set the submission mode used to create the message producers. 
+ */ + public void setSubmissionMode(SubmissionMode submissionMode) { + this.submissionMode = submissionMode; + } + + public @Nullable SubmissionMode getSubmissionMode() { + return submissionMode; + } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageProducer.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageProducer.java new file mode 100644 index 000000000000..b3806b5afae9 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageProducer.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.broker; + +import static org.apache.beam.sdk.io.solace.broker.MessageProducerUtils.createBytesXMLMessage; +import static org.apache.beam.sdk.io.solace.broker.MessageProducerUtils.createJCSMPSendMultipleEntry; + +import com.solacesystems.jcsmp.BytesXMLMessage; +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPException; +import com.solacesystems.jcsmp.JCSMPSendMultipleEntry; +import com.solacesystems.jcsmp.XMLMessageProducer; +import java.util.List; +import java.util.concurrent.Callable; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.RetryCallableManager; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; + +@Internal +public class SolaceMessageProducer implements MessageProducer { + + private final XMLMessageProducer producer; + private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); + + public SolaceMessageProducer(XMLMessageProducer producer) { + this.producer = producer; + } + + @Override + public void publishSingleMessage( + Solace.Record record, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + DeliveryMode deliveryMode) { + BytesXMLMessage msg = createBytesXMLMessage(record, useCorrelationKeyLatency, deliveryMode); + Callable publish = + () -> { + producer.send(msg, topicOrQueue); + return 0; + }; + + retryCallableManager.retryCallable(publish, ImmutableSet.of(JCSMPException.class)); + } + + @Override + public int publishBatch( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode) { + JCSMPSendMultipleEntry[] batch = + createJCSMPSendMultipleEntry( + records, useCorrelationKeyLatency, destinationFn, deliveryMode); + Callable publish = () -> 
producer.sendMultiple(batch, 0, batch.length, 0); + return retryCallableManager.retryCallable(publish, ImmutableSet.of(JCSMPException.class)); + } + + @Override + public boolean isClosed() { + return producer == null || producer.isClosed(); + } + + @Override + public void close() { + if (!isClosed()) { + this.producer.close(); + } + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java index 00b94b5b9ea9..21274237f46a 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java @@ -21,7 +21,6 @@ import com.solacesystems.jcsmp.BytesXMLMessage; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.ByteBuffer; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldNumber; @@ -52,6 +51,7 @@ public String getName() { return name; } } + /** Represents a Solace topic. */ public static class Topic { private final String name; @@ -68,6 +68,7 @@ public String getName() { return name; } } + /** Represents a Solace destination type. 
*/ public enum DestinationType { TOPIC, @@ -93,17 +94,17 @@ public abstract static class Destination { */ public abstract DestinationType getType(); - static Builder builder() { + public static Builder builder() { return new AutoValue_Solace_Destination.Builder(); } @AutoValue.Builder - abstract static class Builder { - abstract Builder setName(String name); + public abstract static class Builder { + public abstract Builder setName(String name); - abstract Builder setType(DestinationType type); + public abstract Builder setType(DestinationType type); - abstract Destination build(); + public abstract Destination build(); } } @@ -120,17 +121,19 @@ public abstract static class Record { * @return The message ID, or null if not available. */ @SchemaFieldNumber("0") - public abstract @Nullable String getMessageId(); + public abstract String getMessageId(); /** - * Gets the payload of the message as a ByteString. + * Gets the payload of the message as a byte array. * *

    Mapped from {@link BytesXMLMessage#getBytes()} * * @return The message payload. */ + @SuppressWarnings("mutable") @SchemaFieldNumber("1") - public abstract ByteBuffer getPayload(); + public abstract byte[] getPayload(); + /** * Gets the destination (topic or queue) to which the message was sent. * @@ -192,7 +195,7 @@ public abstract static class Record { * @return The timestamp. */ @SchemaFieldNumber("7") - public abstract long getReceiveTimestamp(); + public abstract @Nullable Long getReceiveTimestamp(); /** * Gets the timestamp (in milliseconds since the Unix epoch) when the message was sent by the @@ -241,55 +244,62 @@ public abstract static class Record { public abstract @Nullable String getReplicationGroupMessageId(); /** - * Gets the attachment data of the message as a ByteString, if any. This might represent files + * Gets the attachment data of the message as a byte array, if any. This might represent files * or other binary content associated with the message. * *

    Mapped from {@link BytesXMLMessage#getAttachmentByteBuffer()} * - * @return The attachment data, or an empty ByteString if no attachment is present. + * @return The attachment data, or an empty byte array if no attachment is present. */ + @SuppressWarnings("mutable") @SchemaFieldNumber("12") - public abstract ByteBuffer getAttachmentBytes(); + public abstract byte[] getAttachmentBytes(); - static Builder builder() { - return new AutoValue_Solace_Record.Builder(); + public static Builder builder() { + return new AutoValue_Solace_Record.Builder() + .setExpiration(0L) + .setPriority(-1) + .setRedelivered(false) + .setTimeToLive(0) + .setAttachmentBytes(new byte[0]); } @AutoValue.Builder - abstract static class Builder { - abstract Builder setMessageId(@Nullable String messageId); + public abstract static class Builder { + public abstract Builder setMessageId(String messageId); - abstract Builder setPayload(ByteBuffer payload); + public abstract Builder setPayload(byte[] payload); - abstract Builder setDestination(@Nullable Destination destination); + public abstract Builder setDestination(@Nullable Destination destination); - abstract Builder setExpiration(long expiration); + public abstract Builder setExpiration(long expiration); - abstract Builder setPriority(int priority); + public abstract Builder setPriority(int priority); - abstract Builder setRedelivered(boolean redelivered); + public abstract Builder setRedelivered(boolean redelivered); - abstract Builder setReplyTo(@Nullable Destination replyTo); + public abstract Builder setReplyTo(@Nullable Destination replyTo); - abstract Builder setReceiveTimestamp(long receiveTimestamp); + public abstract Builder setReceiveTimestamp(@Nullable Long receiveTimestamp); - abstract Builder setSenderTimestamp(@Nullable Long senderTimestamp); + public abstract Builder setSenderTimestamp(@Nullable Long senderTimestamp); - abstract Builder setSequenceNumber(@Nullable Long sequenceNumber); + public abstract Builder 
setSequenceNumber(@Nullable Long sequenceNumber); - abstract Builder setTimeToLive(long timeToLive); + public abstract Builder setTimeToLive(long timeToLive); - abstract Builder setReplicationGroupMessageId(@Nullable String replicationGroupMessageId); + public abstract Builder setReplicationGroupMessageId( + @Nullable String replicationGroupMessageId); - abstract Builder setAttachmentBytes(ByteBuffer attachmentBytes); + public abstract Builder setAttachmentBytes(byte[] attachmentBytes); - abstract Record build(); + public abstract Record build(); } } /** * The result of writing a message to Solace. This will be returned by the {@link - * com.google.cloud.dataflow.dce.io.solace.SolaceIO.Write} connector. + * org.apache.beam.sdk.io.solace.SolaceIO.Write} connector. * *

    This class provides a builder to create instances, but you will probably not need it. The * write connector will create and return instances of {@link Solace.PublishResult}. @@ -311,12 +321,12 @@ public abstract static class PublishResult { public abstract Boolean getPublished(); /** - * The publishing latency in milliseconds. This is the difference between the time the message + * The publishing latency in nanoseconds. This is the difference between the time the message * was created, and the time the message was published. It is only available if the {@link - * CorrelationKey} class is used as correlation key of the messages. + * CorrelationKey} class is used as correlation key of the messages, and null otherwise. */ @SchemaFieldNumber("2") - public abstract @Nullable Long getLatencyMilliseconds(); + public abstract @Nullable Long getLatencyNanos(); /** The error details if the message could not be published. */ @SchemaFieldNumber("3") @@ -332,7 +342,7 @@ public abstract static class Builder { public abstract Builder setPublished(Boolean published); - public abstract Builder setLatencyMilliseconds(Long latencyMs); + public abstract Builder setLatencyNanos(Long latencyNanos); public abstract Builder setError(String error); @@ -354,7 +364,7 @@ public abstract static class CorrelationKey { public abstract String getMessageId(); @SchemaFieldNumber("1") - public abstract long getPublishMonotonicMillis(); + public abstract long getPublishMonotonicNanos(); public static Builder builder() { return new AutoValue_Solace_CorrelationKey.Builder(); @@ -364,7 +374,7 @@ public static Builder builder() { public abstract static class Builder { public abstract Builder setMessageId(String messageId); - public abstract Builder setPublishMonotonicMillis(long millis); + public abstract Builder setPublishMonotonicNanos(long nanos); public abstract CorrelationKey build(); } @@ -414,7 +424,7 @@ public static class SolaceRecordMapper { Destination destination = 
getDestination(msg.getCorrelationId(), msg.getDestination()); return Record.builder() .setMessageId(msg.getApplicationMessageId()) - .setPayload(ByteBuffer.wrap(payloadBytesStream.toByteArray())) + .setPayload(payloadBytesStream.toByteArray()) .setDestination(destination) .setExpiration(msg.getExpiration()) .setPriority(msg.getPriority()) @@ -428,7 +438,7 @@ public static class SolaceRecordMapper { msg.getReplicationGroupMessageId() != null ? msg.getReplicationGroupMessageId().toString() : null) - .setAttachmentBytes(ByteBuffer.wrap(attachmentBytesStream.toByteArray())) + .setAttachmentBytes(attachmentBytesStream.toByteArray()) .build(); } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java index c18a9d110b2a..a421970370da 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java @@ -29,7 +29,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; -import org.apache.beam.sdk.io.solace.broker.MessageReceiver; import org.apache.beam.sdk.io.solace.broker.SempClient; import org.apache.beam.sdk.io.solace.broker.SessionService; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -49,7 +48,6 @@ class UnboundedSolaceReader extends UnboundedReader { private final SempClient sempClient; private @Nullable BytesXMLMessage solaceOriginalRecord; private @Nullable T solaceMappedRecord; - private @Nullable MessageReceiver messageReceiver; private @Nullable SessionService sessionService; AtomicBoolean active = new AtomicBoolean(true); @@ -72,7 +70,7 @@ public UnboundedSolaceReader(UnboundedSolaceSource currentSource) { @Override public boolean start() { 
populateSession(); - populateMessageConsumer(); + checkNotNull(sessionService).getReceiver().start(); return advance(); } @@ -85,22 +83,11 @@ public void populateSession() { } } - private void populateMessageConsumer() { - if (messageReceiver == null) { - messageReceiver = checkNotNull(sessionService).createReceiver(); - messageReceiver.start(); - } - MessageReceiver receiver = checkNotNull(messageReceiver); - if (receiver.isClosed()) { - receiver.start(); - } - } - @Override public boolean advance() { BytesXMLMessage receivedXmlMessage; try { - receivedXmlMessage = checkNotNull(messageReceiver).receive(); + receivedXmlMessage = checkNotNull(sessionService).getReceiver().receive(); } catch (IOException e) { LOG.warn("SolaceIO.Read: Exception when pulling messages from the broker.", e); return false; @@ -125,7 +112,7 @@ public void close() { @Override public Instant getWatermark() { // should be only used by a test receiver - if (checkNotNull(messageReceiver).isEOF()) { + if (checkNotNull(sessionService).getReceiver().isEOF()) { return BoundedWindow.TIMESTAMP_MAX_VALUE; } return watermarkPolicy.getWatermark(); diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java new file mode 100644 index 000000000000..c55d37942c72 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.KV; + +/** + * This class adds pseudo-key with a given cardinality. The downstream steps will use state + * {@literal &} timers to distribute the data and control for the number of parallel workers used + * for writing. + */ +@Internal +public class AddShardKeyDoFn extends DoFn> { + private final int shardCount; + private int shardKey; + + public AddShardKeyDoFn(int shardCount) { + this.shardCount = shardCount; + shardKey = -1; + } + + @ProcessElement + public void processElement( + @Element Solace.Record record, OutputReceiver> c) { + shardKey = (shardKey + 1) % shardCount; + c.output(KV.of(shardKey, record)); + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/RecordToPublishResultDoFn.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/RecordToPublishResultDoFn.java new file mode 100644 index 000000000000..4be5b0a014b3 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/RecordToPublishResultDoFn.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.DoFn; + +/** + * This class just transforms to PublishResult to be able to capture the windowing with the right + * strategy. The output is not used for anything else. + */ +@Internal +public class RecordToPublishResultDoFn extends DoFn { + @ProcessElement + public void processElement( + @Element Solace.Record record, OutputReceiver receiver) { + Solace.PublishResult result = + Solace.PublishResult.builder() + .setPublished(true) + .setMessageId(record.getMessageId()) + .setLatencyNanos(0L) + .build(); + receiver.output(result); + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java index 6c37f879ae7f..d9c37326f83f 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.io.solace.SolaceIO; import org.apache.beam.sdk.io.solace.data.Solace; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; import 
org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; @@ -31,50 +32,33 @@ import org.checkerframework.checker.nullness.qual.Nullable; /** - * The {@link SolaceIO.Write} transform's output return this type, containing both the successful - * publishes ({@link #getSuccessfulPublish()}) and the failed publishes ({@link - * #getFailedPublish()}). + * The {@link SolaceIO.Write} transform's output return this type, containing the successful + * publishes ({@link #getSuccessfulPublish()}). To access failed records, configure the connector + * with {@link SolaceIO.Write#withErrorHandler(ErrorHandler)}. * *

    The streaming writer with DIRECT messages does not return anything, and the output {@link - * PCollection}s will be equal to null. + * PCollection} will be equal to null. */ public final class SolaceOutput implements POutput { private final Pipeline pipeline; - private final TupleTag failedPublishTag; private final TupleTag successfulPublishTag; - private final @Nullable PCollection failedPublish; private final @Nullable PCollection successfulPublish; - public @Nullable PCollection getFailedPublish() { - return failedPublish; - } - public @Nullable PCollection getSuccessfulPublish() { return successfulPublish; } public static SolaceOutput in( - Pipeline pipeline, - @Nullable PCollection failedPublish, - @Nullable PCollection successfulPublish) { - return new SolaceOutput( - pipeline, - SolaceIO.Write.FAILED_PUBLISH_TAG, - SolaceIO.Write.SUCCESSFUL_PUBLISH_TAG, - failedPublish, - successfulPublish); + Pipeline pipeline, @Nullable PCollection successfulPublish) { + return new SolaceOutput(pipeline, SolaceIO.Write.SUCCESSFUL_PUBLISH_TAG, successfulPublish); } private SolaceOutput( Pipeline pipeline, - TupleTag failedPublishTag, TupleTag successfulPublishTag, - @Nullable PCollection failedPublish, @Nullable PCollection successfulPublish) { this.pipeline = pipeline; - this.failedPublishTag = failedPublishTag; this.successfulPublishTag = successfulPublishTag; - this.failedPublish = failedPublish; this.successfulPublish = successfulPublish; } @@ -87,10 +71,6 @@ public Pipeline getPipeline() { public Map, PValue> expand() { ImmutableMap.Builder, PValue> builder = ImmutableMap., PValue>builder(); - if (failedPublish != null) { - builder.put(failedPublishTag, failedPublish); - } - if (successfulPublish != null) { builder.put(successfulPublishTag, successfulPublish); } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceWriteSessionsHandler.java 
b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceWriteSessionsHandler.java new file mode 100644 index 000000000000..109010231d17 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceWriteSessionsHandler.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import static org.apache.beam.sdk.io.solace.SolaceIO.DEFAULT_WRITER_CLIENTS_PER_WORKER; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.google.auto.value.AutoValue; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; + +/** + * All the writer threads belonging to the same factory share the same instance of this class, to + * control for the number of clients that are connected to Solace, and minimize problems with quotas + * and limits. + * + *

    This class maintains a map of all the sessions open in a worker, and controls the size of + * that map, to avoid creating more sessions than Solace could handle. + * + *

    This class is thread-safe and creates a pool of producers per SessionServiceFactory. If there
+ * is only one Write transform in the pipeline, this is effectively a singleton. If there is more
+ * than one, each {@link SessionServiceFactory} instance keeps its own pool of producers.
+ */
+final class SolaceWriteSessionsHandler {
+
+  // Sessions created so far, keyed by (producer index, factory, writer transform UUID).
+  private static final ConcurrentHashMap<SessionConfigurationIndex, SessionService> sessionsMap =
+      new ConcurrentHashMap<>(DEFAULT_WRITER_CLIENTS_PER_WORKER);
+
+  /**
+   * Returns the session registered for the given producer index, factory and writer transform,
+   * creating it and starting its producer if there is no such session yet.
+   */
+  public static SessionService getSessionServiceWithProducer(
+      int producerIndex, SessionServiceFactory sessionServiceFactory, UUID writerTransformUuid) {
+    SessionConfigurationIndex key =
+        SessionConfigurationIndex.builder()
+            .producerIndex(producerIndex)
+            .sessionServiceFactory(sessionServiceFactory)
+            .writerTransformUuid(writerTransformUuid)
+            .build();
+    return sessionsMap.computeIfAbsent(
+        key, SolaceWriteSessionsHandler::createSessionAndStartProducer);
+  }
+
+  /**
+   * Creates a session with the factory stored in the key and initializes its producer.
+   *
+   * @throws IllegalStateException if the factory has no submission mode set
+   */
+  private static SessionService createSessionAndStartProducer(SessionConfigurationIndex key) {
+    SessionServiceFactory factory = key.sessionServiceFactory();
+    // Validate the submission mode before creating the session: if the check failed after the
+    // creation, the new session would neither be stored in the map nor closed.
+    SubmissionMode mode = factory.getSubmissionMode();
+    checkStateNotNull(
+        mode,
+        "SolaceIO.Write: Submission mode is not set. You need to set it to create write sessions.");
+    SessionService sessionService = factory.create();
+    // Start the producer now that the initialization is locked for other threads
+    sessionService.getInitializedProducer(mode);
+    return sessionService;
+  }
+
+  /** Disconnect all the sessions from Solace, and clear the corresponding state. */
+  public static void disconnectFromSolace(
+      SessionServiceFactory factory, int producersCardinality, UUID writerTransformUuid) {
+    for (int i = 0; i < producersCardinality; i++) {
+      SessionConfigurationIndex key =
+          SessionConfigurationIndex.builder()
+              .producerIndex(i)
+              .sessionServiceFactory(factory)
+              .writerTransformUuid(writerTransformUuid)
+              .build();
+
+      // Remove first, so the session can no longer be handed out, and then close it.
+      SessionService sessionService = sessionsMap.remove(key);
+      if (sessionService != null) {
+        sessionService.close();
+      }
+    }
+  }
+
+  /** Key of the sessions map: producer index, session factory and writer transform UUID. */
+  @AutoValue
+  abstract static class SessionConfigurationIndex {
+    abstract int producerIndex();
+
+    abstract SessionServiceFactory sessionServiceFactory();
+
+    abstract UUID writerTransformUuid();
+
+    static Builder builder() {
+      return new AutoValue_SolaceWriteSessionsHandler_SessionConfigurationIndex.Builder();
+    }
+
+    @AutoValue.Builder
+    abstract static class Builder {
+      abstract Builder producerIndex(int producerIndex);
+
+      abstract Builder sessionServiceFactory(SessionServiceFactory sessionServiceFactory);
+
+      abstract Builder writerTransformUuid(UUID writerTransformUuid);
+
+      abstract SessionConfigurationIndex build();
+    }
+  }
+}
diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedBatchedSolaceWriter.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedBatchedSolaceWriter.java
new file mode 100644
index 000000000000..dd4f81eeb082
--- /dev/null
+++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedBatchedSolaceWriter.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import java.io.IOException; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This DoFn is responsible for writing to Solace in batch mode (holding up any messages), and + * emits the corresponding output (success or fail; only for persistent messages), so the + * SolaceIO.Write connector can be composed with other subsequent transforms in the pipeline. + * + *

    The DoFn will create several JCSMP sessions per VM, and the sessions and producers will be + * reused across different threads (if the number of threads is higher than the number of sessions, + * which is probably the most common case). + * + *

    The producer uses the JCSMP send multiple mode to publish a batch of messages together with a + * single API call. The acks from this publication are also processed in batch, and returned as the + * output of the DoFn. + * + *

    The batch size is 50, and this is currently the maximum value supported by Solace. + * + *

    There are no acks if the delivery mode is set to DIRECT. + * + *

    This writer DoFn offers higher throughput than {@link UnboundedStreamingSolaceWriter} but also
+ * higher latency.
+ */
+@Internal
+public final class UnboundedBatchedSolaceWriter extends UnboundedSolaceWriter {
+
+  private static final Logger LOG = LoggerFactory.getLogger(UnboundedBatchedSolaceWriter.class);
+
+  // Processing-time offset of the "bundle_flusher" timer; the timer is pushed back by this amount
+  // every time an element is processed.
+  private static final int ACKS_FLUSHING_INTERVAL_SECS = 10;
+
+  private final Counter sentToBroker =
+      Metrics.counter(UnboundedBatchedSolaceWriter.class, "msgs_sent_to_broker");
+
+  // NOTE(review): this counter is registered under UnboundedSolaceWriter, while the counter above
+  // uses UnboundedBatchedSolaceWriter. Kept as is because the namespace is part of the metric
+  // name; confirm whether the difference is intended.
+  private final Counter batchesRejectedByBroker =
+      Metrics.counter(UnboundedSolaceWriter.class, "batches_rejected");
+
+  // State variables are never explicitly "used"
+  @SuppressWarnings("UnusedVariable")
+  @TimerId("bundle_flusher")
+  private final TimerSpec bundleFlusherTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+  public UnboundedBatchedSolaceWriter(
+      SerializableFunction<Solace.Record, Destination> destinationFn,
+      SessionServiceFactory sessionServiceFactory,
+      DeliveryMode deliveryMode,
+      SubmissionMode submissionMode,
+      int producersMapCardinality,
+      boolean publishLatencyMetrics) {
+    super(
+        destinationFn,
+        sessionServiceFactory,
+        deliveryMode,
+        submissionMode,
+        producersMapCardinality,
+        publishLatencyMetrics);
+  }
+
+  /**
+   * Buffers the record in the current bundle and pushes back the flushing timer. Nothing is sent
+   * to Solace here; the buffered records are published in {@link #finishBundle}.
+   */
+  // The state variable is here just to force a shuffling with a certain cardinality
+  @ProcessElement
+  public void processElement(
+      @Element KV<Integer, Solace.Record> element,
+      @TimerId("bundle_flusher") Timer bundleFlusherTimer,
+      @Timestamp Instant timestamp) {
+
+    setCurrentBundleTimestamp(timestamp);
+
+    Solace.Record record = element.getValue();
+
+    if (record == null) {
+      LOG.error(
+          "SolaceIO.Write: Found null record with key {}. Ignoring record.", element.getKey());
+    } else {
+      addToCurrentBundle(record);
+      // Extend timer for bundle flushing
+      bundleFlusherTimer
+          .offset(Duration.standardSeconds(ACKS_FLUSHING_INTERVAL_SECS))
+          .setRelative();
+    }
+  }
+
+  /**
+   * Publishes the records buffered in this bundle, in batches of at most {@code
+   * SOLACE_BATCH_LIMIT} messages, and then emits the publish results available so far.
+   */
+  @FinishBundle
+  public void finishBundle(FinishBundleContext context) throws IOException {
+    // Take messages in groups of 50 (if there are enough messages)
+    List<Solace.Record> currentBundle = getCurrentBundle();
+    for (int i = 0; i < currentBundle.size(); i += SOLACE_BATCH_LIMIT) {
+      int toIndex = Math.min(i + SOLACE_BATCH_LIMIT, currentBundle.size());
+      // i < toIndex always holds inside the loop, so the sub-list is never empty.
+      publishBatch(currentBundle.subList(i, toIndex));
+    }
+    getCurrentBundle().clear();
+
+    publishResults(BeamContextWrapper.of(context));
+  }
+
+  /** Emits the publish results received since the last flush. */
+  @OnTimer("bundle_flusher")
+  public void flushBundle(OnTimerContext context) throws IOException {
+    publishResults(BeamContextWrapper.of(context));
+  }
+
+  /**
+   * Publishes one batch. If the producer throws, a single failed {@link Solace.PublishResult}
+   * representing the whole batch is added to the results queue, so it is emitted with the rest of
+   * the results.
+   */
+  private void publishBatch(List<Solace.Record> records) {
+    // Reference point to report how long the failed attempt took. The latency of a result is a
+    // duration, so the raw value of System.nanoTime() must not be stored in it.
+    long publishStartNanos = System.nanoTime();
+    try {
+      int entriesPublished =
+          solaceSessionServiceWithProducer()
+              .getInitializedProducer(getSubmissionMode())
+              .publishBatch(
+                  records, shouldPublishLatencyMetrics(), getDestinationFn(), getDeliveryMode());
+      sentToBroker.inc(entriesPublished);
+    } catch (Exception e) {
+      batchesRejectedByBroker.inc();
+      Solace.PublishResult errorPublish =
+          Solace.PublishResult.builder()
+              .setPublished(false)
+              .setMessageId(String.format("BATCH_OF_%d_ENTRIES", records.size()))
+              .setError(
+                  String.format(
+                      "Batch could not be published after several" + " retries. Error: %s",
+                      e.getMessage()))
+              .setLatencyNanos(System.nanoTime() - publishStartNanos)
+              .build();
+      solaceSessionServiceWithProducer().getPublishedResultsQueue().add(errorPublish);
+    }
+  }
+}
diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedSolaceWriter.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedSolaceWriter.java
new file mode 100644
index 000000000000..1c98113c2416
--- /dev/null
+++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedSolaceWriter.java
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.beam.sdk.io.solace.write; + +import static org.apache.beam.sdk.io.solace.SolaceIO.Write.FAILED_PUBLISH_TAG; +import static org.apache.beam.sdk.io.solace.SolaceIO.Write.SUCCESSFUL_PUBLISH_TAG; + +import com.solacesystems.jcsmp.BytesXMLMessage; +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPFactory; +import com.solacesystems.jcsmp.JCSMPSendMultipleEntry; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This DoFn encapsulates common code used both for the {@link UnboundedBatchedSolaceWriter} and + * {@link 
UnboundedStreamingSolaceWriter}.
+ */
+@Internal
+public abstract class UnboundedSolaceWriter
+    extends DoFn<KV<Integer, Solace.Record>, Solace.PublishResult> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(UnboundedSolaceWriter.class);
+
+  // This is the batch limit supported by the send multiple JCSMP API method.
+  static final int SOLACE_BATCH_LIMIT = 50;
+  private final Distribution latencyPublish =
+      Metrics.distribution(SolaceIO.Write.class, "latency_publish_ms");
+
+  private final Distribution latencyErrors =
+      Metrics.distribution(SolaceIO.Write.class, "latency_failed_ms");
+
+  private final SerializableFunction<Solace.Record, Destination> destinationFn;
+
+  private final SessionServiceFactory sessionServiceFactory;
+  private final DeliveryMode deliveryMode;
+  private final SubmissionMode submissionMode;
+  private final int producersMapCardinality;
+  private final boolean publishLatencyMetrics;
+  // Static, hence shared by all the writer instances loaded by the same class loader. It is used
+  // to hand out producer indexes in round-robin order.
+  private static final AtomicInteger bundleProducerIndexCounter = new AtomicInteger();
+  private int currentBundleProducerIndex = 0;
+
+  // Records added with addToCurrentBundle; cleared at the start of every bundle.
+  private final List<Solace.Record> batchToEmit;
+
+  // Earliest timestamp passed to setCurrentBundleTimestamp so far.
+  private @Nullable Instant bundleTimestamp;
+
+  final UUID writerTransformUuid = UUID.randomUUID();
+
+  public UnboundedSolaceWriter(
+      SerializableFunction<Solace.Record, Destination> destinationFn,
+      SessionServiceFactory sessionServiceFactory,
+      DeliveryMode deliveryMode,
+      SubmissionMode submissionMode,
+      int producersMapCardinality,
+      boolean publishLatencyMetrics) {
+    this.destinationFn = destinationFn;
+    this.sessionServiceFactory = sessionServiceFactory;
+    // Make sure that we set the submission mode now that we know which mode has been set by the
+    // user.
+    this.sessionServiceFactory.setSubmissionMode(submissionMode);
+    this.deliveryMode = deliveryMode;
+    this.submissionMode = submissionMode;
+    this.producersMapCardinality = producersMapCardinality;
+    this.publishLatencyMetrics = publishLatencyMetrics;
+    this.batchToEmit = new ArrayList<>();
+  }
+
+  /** Closes the sessions registered for this writer's factory and transform UUID. */
+  @Teardown
+  public void teardown() {
+    SolaceWriteSessionsHandler.disconnectFromSolace(
+        sessionServiceFactory, producersMapCardinality, writerTransformUuid);
+  }
+
+  /** Selects the producer index for the next bundle, in round-robin order. */
+  public void updateProducerIndex() {
+    // floorMod keeps the index in [0, producersMapCardinality) after the counter overflows to
+    // negative values. With the % operator the index would become negative, which would create
+    // sessions under keys that disconnectFromSolace never visits.
+    currentBundleProducerIndex =
+        Math.floorMod(bundleProducerIndexCounter.getAndIncrement(), producersMapCardinality);
+  }
+
+  @StartBundle
+  public void startBundle() {
+    // Pick the next producer in round-robin order, and reuse it for the whole bundle
+    updateProducerIndex();
+    batchToEmit.clear();
+    // NOTE(review): bundleTimestamp is not reset here, so it keeps the earliest timestamp seen
+    // in previous bundles too. Confirm that this is intended.
+  }
+
+  /** Returns the session (with an initialized producer) assigned to the current bundle. */
+  public SessionService solaceSessionServiceWithProducer() {
+    return SolaceWriteSessionsHandler.getSessionServiceWithProducer(
+        currentBundleProducerIndex, sessionServiceFactory, writerTransformUuid);
+  }
+
+  /**
+   * Drains the results queue of the current session. Published results are emitted with {@code
+   * SUCCESSFUL_PUBLISH_TAG} and the failed ones are emitted as {@link BadRecord} with {@code
+   * FAILED_PUBLISH_TAG}. The latency distributions are updated if latency metrics are enabled.
+   */
+  public void publishResults(BeamContextWrapper context) {
+    long sumPublish = 0;
+    long countPublish = 0;
+    long minPublish = Long.MAX_VALUE;
+    long maxPublish = 0;
+
+    long sumFailed = 0;
+    long countFailed = 0;
+    long minFailed = Long.MAX_VALUE;
+    long maxFailed = 0;
+
+    Queue<PublishResult> publishResultsQueue =
+        solaceSessionServiceWithProducer().getPublishedResultsQueue();
+    Solace.PublishResult result = publishResultsQueue.poll();
+
+    if (result != null) {
+      // Results may be emitted before any element has set a timestamp; fall back to now.
+      if (getCurrentBundleTimestamp() == null) {
+        setCurrentBundleTimestamp(Instant.now());
+      }
+    }
+
+    while (result != null) {
+      Long latency = result.getLatencyNanos();
+
+      if (latency == null && shouldPublishLatencyMetrics()) {
+        LOG.error(
+            "SolaceIO.Write: Latency is null but user asked for latency metrics."
+                + " This may be a bug.");
+      }
+
+      if (latency != null) {
+        if (result.getPublished()) {
+          sumPublish += latency;
+          countPublish++;
+          minPublish = Math.min(minPublish, latency);
+          maxPublish = Math.max(maxPublish, latency);
+        } else {
+          sumFailed += latency;
+          countFailed++;
+          minFailed = Math.min(minFailed, latency);
+          maxFailed = Math.max(maxFailed, latency);
+        }
+      }
+      if (result.getPublished()) {
+        context.output(
+            SUCCESSFUL_PUBLISH_TAG, result, getCurrentBundleTimestamp(), GlobalWindow.INSTANCE);
+      } else {
+        try {
+          BadRecord b =
+              BadRecord.fromExceptionInformation(
+                  result,
+                  null,
+                  null,
+                  Optional.ofNullable(result.getError()).orElse("SolaceIO.Write: unknown error."));
+          context.output(FAILED_PUBLISH_TAG, b, getCurrentBundleTimestamp(), GlobalWindow.INSTANCE);
+        } catch (IOException e) {
+          // Not expected: the exception is thrown when the exception argument in the
+          // `BadRecord.fromExceptionInformation` is not null. Log it anyway, so that a failed
+          // record is never dropped without a trace.
+          LOG.error("SolaceIO.Write: Could not emit a failed record as BadRecord.", e);
+        }
+      }
+
+      result = publishResultsQueue.poll();
+    }
+
+    if (shouldPublishLatencyMetrics()) {
+      // Report all latency value in milliseconds
+      if (countPublish > 0) {
+        getPublishLatencyMetric()
+            .update(
+                TimeUnit.NANOSECONDS.toMillis(sumPublish),
+                countPublish,
+                TimeUnit.NANOSECONDS.toMillis(minPublish),
+                TimeUnit.NANOSECONDS.toMillis(maxPublish));
+      }
+
+      if (countFailed > 0) {
+        getFailedLatencyMetric()
+            .update(
+                TimeUnit.NANOSECONDS.toMillis(sumFailed),
+                countFailed,
+                TimeUnit.NANOSECONDS.toMillis(minFailed),
+                TimeUnit.NANOSECONDS.toMillis(maxFailed));
+      }
+    }
+  }
+
+  /**
+   * Builds the JCSMP message for a record: payload, sender timestamp, delivery mode, correlation
+   * key and application message id.
+   *
+   * @param useCorrelationKeyLatency if true, the correlation key is a {@link
+   *     Solace.CorrelationKey} carrying the publish time; otherwise it is just the message id
+   */
+  public BytesXMLMessage createSingleMessage(
+      Solace.Record record, boolean useCorrelationKeyLatency) {
+    JCSMPFactory jcsmpFactory = JCSMPFactory.onlyInstance();
+    BytesXMLMessage msg = jcsmpFactory.createBytesXMLMessage();
+    byte[] payload = record.getPayload();
+    msg.writeBytes(payload);
+
+    Long senderTimestamp = record.getSenderTimestamp();
+    if (senderTimestamp == null) {
+      LOG.error(
+          "SolaceIO.Write: Record with id {} has no sender timestamp. Using current"
+              + " worker clock as timestamp.",
+          record.getMessageId());
+      senderTimestamp = System.currentTimeMillis();
+    }
+    msg.setSenderTimestamp(senderTimestamp);
+    msg.setDeliveryMode(getDeliveryMode());
+    if (useCorrelationKeyLatency) {
+      Solace.CorrelationKey key =
+          Solace.CorrelationKey.builder()
+              .setMessageId(record.getMessageId())
+              .setPublishMonotonicNanos(System.nanoTime())
+              .build();
+      msg.setCorrelationKey(key);
+    } else {
+      // Use only a string as correlation key
+      msg.setCorrelationKey(record.getMessageId());
+    }
+    msg.setApplicationMessageId(record.getMessageId());
+    return msg;
+  }
+
+  /**
+   * Builds one send-multiple entry per record, pairing each message with the destination returned
+   * by the destination function. A batch over {@code SOLACE_BATCH_LIMIT} is logged but still
+   * built.
+   */
+  public JCSMPSendMultipleEntry[] createMessagesArray(
+      Iterable<Solace.Record> records, boolean useCorrelationKeyLatency) {
+    // Solace batch publishing only supports 50 elements max, so it is safe to convert to
+    // list here
+    ArrayList<Solace.Record> recordsList = Lists.newArrayList(records);
+    if (recordsList.size() > SOLACE_BATCH_LIMIT) {
+      LOG.error(
+          "SolaceIO.Write: Trying to create a batch of {}, but Solace supports a"
+              + " maximum of {}. The batch will likely be rejected by Solace.",
+          recordsList.size(),
+          SOLACE_BATCH_LIMIT);
+    }
+
+    JCSMPSendMultipleEntry[] entries = new JCSMPSendMultipleEntry[recordsList.size()];
+    for (int i = 0; i < recordsList.size(); i++) {
+      Solace.Record record = recordsList.get(i);
+      JCSMPSendMultipleEntry entry =
+          JCSMPFactory.onlyInstance()
+              .createSendMultipleEntry(
+                  createSingleMessage(record, useCorrelationKeyLatency),
+                  getDestinationFn().apply(record));
+      entries[i] = entry;
+    }
+
+    return entries;
+  }
+
+  public int getProducersMapCardinality() {
+    return producersMapCardinality;
+  }
+
+  public Distribution getPublishLatencyMetric() {
+    return latencyPublish;
+  }
+
+  public Distribution getFailedLatencyMetric() {
+    return latencyErrors;
+  }
+
+  public boolean shouldPublishLatencyMetrics() {
+    return publishLatencyMetrics;
+  }
+
+  public SerializableFunction<Solace.Record, Destination> getDestinationFn() {
+    return destinationFn;
+  }
+
+  public DeliveryMode getDeliveryMode() {
+    return deliveryMode;
+  }
+
+  public SubmissionMode getSubmissionMode() {
+    return submissionMode;
+  }
+
+  public void addToCurrentBundle(Solace.Record record) {
+    batchToEmit.add(record);
+  }
+
+  public List<Solace.Record> getCurrentBundle() {
+    return batchToEmit;
+  }
+
+  public @Nullable Instant getCurrentBundleTimestamp() {
+    return bundleTimestamp;
+  }
+
+  /** Stores the given timestamp only if none is set yet or if it is earlier than the stored one. */
+  public void setCurrentBundleTimestamp(Instant bundleTimestamp) {
+    if (this.bundleTimestamp == null || bundleTimestamp.isBefore(this.bundleTimestamp)) {
+      this.bundleTimestamp = bundleTimestamp;
+    }
+  }
+
+  /**
+   * Since we need to publish from on timer methods and finish bundle methods, we need a consistent
+   * way to handle both WindowedContext and FinishBundleContext.
+   */
+  static class BeamContextWrapper {
+    // Exactly one of the two contexts is set by the factory methods below.
+    private @Nullable WindowedContext windowedContext;
+    private @Nullable FinishBundleContext finishBundleContext;
+
+    private BeamContextWrapper() {}
+
+    public static BeamContextWrapper of(WindowedContext windowedContext) {
+      BeamContextWrapper beamContextWrapper = new BeamContextWrapper();
+      beamContextWrapper.windowedContext = windowedContext;
+      return beamContextWrapper;
+    }
+
+    public static BeamContextWrapper of(FinishBundleContext finishBundleContext) {
+      BeamContextWrapper beamContextWrapper = new BeamContextWrapper();
+      beamContextWrapper.finishBundleContext = finishBundleContext;
+      return beamContextWrapper;
+    }
+
+    /**
+     * Emits the output with the wrapped context. The timestamp and the window are ignored with a
+     * WindowedContext, and are mandatory with a FinishBundleContext.
+     *
+     * @throws IllegalStateException if a FinishBundleContext is wrapped and the timestamp or the
+     *     window is null, or if no context is wrapped
+     */
+    public <T> void output(
+        TupleTag<T> tag,
+        T output,
+        @Nullable Instant timestamp, // Not required for windowed context
+        @Nullable BoundedWindow window) { // Not required for windowed context
+      if (windowedContext != null) {
+        windowedContext.output(tag, output);
+      } else if (finishBundleContext != null) {
+        if (timestamp == null) {
+          throw new IllegalStateException(
+              "SolaceIO.Write.UnboundedSolaceWriter.Context: Timestamp is required for a"
+                  + " FinishBundleContext.");
+        }
+        if (window == null) {
+          throw new IllegalStateException(
+              "SolaceIO.Write.UnboundedSolaceWriter.Context: BoundedWindow is required for a"
+                  + " FinishBundleContext.");
+        }
+        finishBundleContext.output(tag, output, timestamp, window);
+      } else {
+        throw new IllegalStateException(
+            "SolaceIO.Write.UnboundedSolaceWriter.Context: No context provided");
+      }
+    }
+  }
+}
diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedStreamingSolaceWriter.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedStreamingSolaceWriter.java
new file mode 100644
index 000000000000..6d6d0b27e2bb
--- /dev/null
+++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedStreamingSolaceWriter.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation
(ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This DoFn is responsible for writing to Solace in streaming mode (one message at a time, not + * holding up any message), and emits the corresponding output (success or fail; only for + * persistent messages), so the SolaceIO.Write connector can be composed with other subsequent + * transforms in the pipeline. + * + *

    The DoFn will create several JCSMP sessions per VM, and the sessions and producers will be + * reused across different threads (if the number of threads is higher than the number of sessions, + * which is probably the most common case). + * + *

    The producer uses the JCSMP streaming mode to publish a single message at a time, processing + * the acks from this publication, and returning them as output of the DoFn. + * + *

    There are no acks if the delivery mode is set to DIRECT. + * + *

    This writer DoFn offers lower latency and lower throughput than {@link
+ * UnboundedBatchedSolaceWriter}.
+ */
+@Internal
+public final class UnboundedStreamingSolaceWriter extends UnboundedSolaceWriter {
+
+  private static final Logger LOG = LoggerFactory.getLogger(UnboundedStreamingSolaceWriter.class);
+
+  private final Counter sentToBroker =
+      Metrics.counter(UnboundedStreamingSolaceWriter.class, "msgs_sent_to_broker");
+
+  private final Counter rejectedByBroker =
+      Metrics.counter(UnboundedStreamingSolaceWriter.class, "msgs_rejected_by_broker");
+
+  // We use a state variable to force a shuffling and ensure the cardinality of the processing
+  @SuppressWarnings("UnusedVariable")
+  @StateId("current_key")
+  private final StateSpec<ValueState<Integer>> currentKeySpec = StateSpecs.value();
+
+  public UnboundedStreamingSolaceWriter(
+      SerializableFunction<Solace.Record, Destination> destinationFn,
+      SessionServiceFactory sessionServiceFactory,
+      DeliveryMode deliveryMode,
+      SolaceIO.SubmissionMode submissionMode,
+      int producersMapCardinality,
+      boolean publishLatencyMetrics) {
+    super(
+        destinationFn,
+        sessionServiceFactory,
+        deliveryMode,
+        submissionMode,
+        producersMapCardinality,
+        publishLatencyMetrics);
+  }
+
+  /**
+   * Publishes the record right away with the producer assigned to the current bundle. If the
+   * producer throws, a failed {@link Solace.PublishResult} for the record is added to the results
+   * queue, so it is emitted in {@link #finishBundle}.
+   */
+  @ProcessElement
+  public void processElement(
+      @Element KV<Integer, Solace.Record> element,
+      @Timestamp Instant timestamp,
+      @AlwaysFetched @StateId("current_key") ValueState<Integer> currentKeyState) {
+
+    setCurrentBundleTimestamp(timestamp);
+
+    Integer currentKey = currentKeyState.read();
+    Integer elementKey = element.getKey();
+    Solace.Record record = element.getValue();
+
+    // The state is written only when the stored key differs from the key of the element.
+    if (currentKey == null || !currentKey.equals(elementKey)) {
+      currentKeyState.write(elementKey);
+    }
+
+    if (record == null) {
+      LOG.error("SolaceIO.Write: Found null record with key {}. Ignoring record.", elementKey);
+      return;
+    }
+
+    // Reference point to report how long the failed attempt took. The latency of a result is a
+    // duration, so the raw value of System.nanoTime() must not be stored in it.
+    long publishStartNanos = System.nanoTime();
+
+    // The publish method will retry, let's send a failure message if all the retries fail
+    try {
+      solaceSessionServiceWithProducer()
+          .getInitializedProducer(getSubmissionMode())
+          .publishSingleMessage(
+              record,
+              getDestinationFn().apply(record),
+              shouldPublishLatencyMetrics(),
+              getDeliveryMode());
+      sentToBroker.inc();
+    } catch (Exception e) {
+      rejectedByBroker.inc();
+      Solace.PublishResult errorPublish =
+          Solace.PublishResult.builder()
+              .setPublished(false)
+              .setMessageId(record.getMessageId())
+              .setError(
+                  String.format(
+                      "Message could not be published after several" + " retries. Error: %s",
+                      e.getMessage()))
+              .setLatencyNanos(System.nanoTime() - publishStartNanos)
+              .build();
+      solaceSessionServiceWithProducer().getPublishedResultsQueue().add(errorPublish);
+    }
+  }
+
+  /** Emits the publish results available at the end of the bundle. */
+  @FinishBundle
+  public void finishBundle(FinishBundleContext context) {
+    publishResults(BeamContextWrapper.of(context));
+  }
+}
diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java
index ec0ae7194686..38b4953a5984 100644
--- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java
+++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java
@@ -17,14 +17,24 @@
  */
 package org.apache.beam.sdk.io.solace;
 
+import com.google.auto.value.AutoValue;
 import com.solacesystems.jcsmp.JCSMPProperties;
+import java.util.Queue;
+import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode;
+import org.apache.beam.sdk.io.solace.broker.MessageProducer;
 import org.apache.beam.sdk.io.solace.broker.MessageReceiver;
 import org.apache.beam.sdk.io.solace.broker.SessionService;
+import org.apache.beam.sdk.io.solace.data.Solace.PublishResult;
 
-public class MockEmptySessionService extends SessionService {
+@AutoValue
+public abstract class
MockEmptySessionService extends SessionService { String exceptionMessage = "This is an empty client, use a MockSessionService instead."; + public static MockEmptySessionService create() { + return new AutoValue_MockEmptySessionService(); + } + @Override public void close() { throw new UnsupportedOperationException(exceptionMessage); @@ -36,7 +46,17 @@ public boolean isClosed() { } @Override - public MessageReceiver createReceiver() { + public MessageReceiver getReceiver() { + throw new UnsupportedOperationException(exceptionMessage); + } + + @Override + public MessageProducer getInitializedProducer(SubmissionMode mode) { + throw new UnsupportedOperationException(exceptionMessage); + } + + @Override + public Queue getPublishedResultsQueue() { throw new UnsupportedOperationException(exceptionMessage); } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockProducer.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockProducer.java new file mode 100644 index 000000000000..271310359577 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockProducer.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPException; +import java.time.Instant; +import java.util.List; +import org.apache.beam.sdk.io.solace.broker.MessageProducer; +import org.apache.beam.sdk.io.solace.broker.PublishResultHandler; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.transforms.SerializableFunction; + +public abstract class MockProducer implements MessageProducer { + final PublishResultHandler handler; + + public MockProducer(PublishResultHandler handler) { + this.handler = handler; + } + + @Override + public int publishBatch( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode) { + for (Record record : records) { + this.publishSingleMessage( + record, destinationFn.apply(record), useCorrelationKeyLatency, deliveryMode); + } + return records.size(); + } + + @Override + public boolean isClosed() { + return false; + } + + @Override + public void close() {} + + public static class MockSuccessProducer extends MockProducer { + public MockSuccessProducer(PublishResultHandler handler) { + super(handler); + } + + @Override + public void publishSingleMessage( + Record msg, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + DeliveryMode deliveryMode) { + if (useCorrelationKeyLatency) { + handler.responseReceivedEx( + Solace.PublishResult.builder() + .setPublished(true) + .setMessageId(msg.getMessageId()) + .build()); + } else { + handler.responseReceivedEx(msg.getMessageId()); + } + } + } + + public static class MockFailedProducer extends MockProducer { + public MockFailedProducer(PublishResultHandler handler) { + super(handler); + } + + @Override + public void publishSingleMessage( + Record msg, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + 
DeliveryMode deliveryMode) { + if (useCorrelationKeyLatency) { + handler.handleErrorEx( + Solace.PublishResult.builder() + .setPublished(false) + .setMessageId(msg.getMessageId()) + .setError("Some error") + .build(), + new JCSMPException("Some JCSMPException"), + Instant.now().toEpochMilli()); + } else { + handler.handleErrorEx( + msg.getMessageId(), + new JCSMPException("Some JCSMPException"), + Instant.now().toEpochMilli()); + } + } + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java index a4d6a42ef302..bd52dee7ea86 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java @@ -17,38 +17,63 @@ */ package org.apache.beam.sdk.io.solace; +import com.google.auto.value.AutoValue; import com.solacesystems.jcsmp.BytesXMLMessage; import com.solacesystems.jcsmp.JCSMPProperties; import java.io.IOException; -import java.io.Serializable; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.apache.beam.sdk.io.solace.MockProducer.MockSuccessProducer; import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.MessageProducer; import org.apache.beam.sdk.io.solace.broker.MessageReceiver; +import org.apache.beam.sdk.io.solace.broker.PublishResultHandler; import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; import org.apache.beam.sdk.transforms.SerializableFunction; import org.checkerframework.checker.nullness.qual.Nullable; -public class MockSessionService extends SessionService { +@AutoValue +public abstract class MockSessionService extends SessionService { + 
public static int ackWindowSizeForTesting = 87; + public static boolean callbackOnReactor = true; - private final SerializableFunction getRecordFn; - private MessageReceiver messageReceiver = null; - private final int minMessagesReceived; - private final @Nullable SubmissionMode mode; - - public MockSessionService( - SerializableFunction getRecordFn, - int minMessagesReceived, - @Nullable SubmissionMode mode) { - this.getRecordFn = getRecordFn; - this.minMessagesReceived = minMessagesReceived; - this.mode = mode; + public abstract @Nullable SerializableFunction recordFn(); + + public abstract int minMessagesReceived(); + + public abstract @Nullable SubmissionMode mode(); + + public abstract Function mockProducerFn(); + + private final Queue publishedResultsReceiver = new ConcurrentLinkedQueue<>(); + + public static Builder builder() { + return new AutoValue_MockSessionService.Builder() + .minMessagesReceived(0) + .mockProducerFn(MockSuccessProducer::new); } - public MockSessionService( - SerializableFunction getRecordFn, int minMessagesReceived) { - this(getRecordFn, minMessagesReceived, null); + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder recordFn( + @Nullable SerializableFunction recordFn); + + public abstract Builder minMessagesReceived(int minMessagesReceived); + + public abstract Builder mode(@Nullable SubmissionMode mode); + + public abstract Builder mockProducerFn( + Function mockProducerFn); + + public abstract MockSessionService build(); } + private MessageReceiver messageReceiver = null; + private MockProducer messageProducer = null; + @Override public void close() {} @@ -58,17 +83,41 @@ public boolean isClosed() { } @Override - public MessageReceiver createReceiver() { + public MessageReceiver getReceiver() { if (messageReceiver == null) { - messageReceiver = new MockReceiver(getRecordFn, minMessagesReceived); + messageReceiver = new MockReceiver(recordFn(), minMessagesReceived()); } return messageReceiver; } + 
@Override + public MessageProducer getInitializedProducer(SubmissionMode mode) { + if (messageProducer == null) { + messageProducer = mockProducerFn().apply(new PublishResultHandler(publishedResultsReceiver)); + } + return messageProducer; + } + + @Override + public Queue getPublishedResultsQueue() { + return publishedResultsReceiver; + } + @Override public void connect() {} - public static class MockReceiver implements MessageReceiver, Serializable { + @Override + public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties) { + // Let's override some properties that will be overriden by the connector + // Opposite of the mode, to test that is overriden + baseProperties.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, callbackOnReactor); + + baseProperties.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, ackWindowSizeForTesting); + + return baseProperties; + } + + public static class MockReceiver implements MessageReceiver { private final AtomicInteger counter = new AtomicInteger(); private final SerializableFunction getRecordFn; private final int minMessagesReceived; @@ -100,16 +149,4 @@ public boolean isEOF() { return counter.get() >= minMessagesReceived; } } - - @Override - public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties) { - // Let's override some properties that will be overriden by the connector - // Opposite of the mode, to test that is overriden - baseProperties.setProperty( - JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, mode == SubmissionMode.HIGHER_THROUGHPUT); - - baseProperties.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, 87); - - return baseProperties; - } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java index 603a30ad2c90..9c17ca604201 100644 --- 
a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java @@ -17,22 +17,78 @@ */ package org.apache.beam.sdk.io.solace; +import com.google.auto.value.AutoValue; +import com.solacesystems.jcsmp.BytesXMLMessage; +import org.apache.beam.sdk.io.solace.MockProducer.MockFailedProducer; +import org.apache.beam.sdk.io.solace.MockProducer.MockSuccessProducer; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; import org.apache.beam.sdk.io.solace.broker.SessionService; import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.checkerframework.checker.nullness.qual.Nullable; -public class MockSessionServiceFactory extends SessionServiceFactory { - SessionService sessionService; +@AutoValue +public abstract class MockSessionServiceFactory extends SessionServiceFactory { + public abstract @Nullable SubmissionMode mode(); - public MockSessionServiceFactory(SessionService clientService) { - this.sessionService = clientService; + public abstract @Nullable SerializableFunction recordFn(); + + public abstract int minMessagesReceived(); + + public abstract SessionServiceType sessionServiceType(); + + public static Builder builder() { + return new AutoValue_MockSessionServiceFactory.Builder() + .minMessagesReceived(0) + .sessionServiceType(SessionServiceType.WITH_SUCCEEDING_PRODUCER); } public static SessionServiceFactory getDefaultMock() { - return new MockSessionServiceFactory(new MockEmptySessionService()); + return MockSessionServiceFactory.builder().build(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder mode(@Nullable SubmissionMode mode); + + public abstract Builder recordFn( + @Nullable SerializableFunction recordFn); + + public abstract Builder minMessagesReceived(int minMessagesReceived); + + public 
abstract Builder sessionServiceType(SessionServiceType sessionServiceType); + + public abstract MockSessionServiceFactory build(); } @Override public SessionService create() { - return sessionService; + switch (sessionServiceType()) { + case EMPTY: + return MockEmptySessionService.create(); + case WITH_SUCCEEDING_PRODUCER: + return MockSessionService.builder() + .recordFn(recordFn()) + .minMessagesReceived(minMessagesReceived()) + .mode(mode()) + .mockProducerFn(MockSuccessProducer::new) + .build(); + case WITH_FAILING_PRODUCER: + return MockSessionService.builder() + .recordFn(recordFn()) + .minMessagesReceived(minMessagesReceived()) + .mode(mode()) + .mockProducerFn(MockFailedProducer::new) + .build(); + default: + throw new RuntimeException( + String.format("Unknown sessionServiceType: %s", sessionServiceType().name())); + } + } + + public enum SessionServiceType { + EMPTY, + WITH_SUCCEEDING_PRODUCER, + WITH_FAILING_PRODUCER } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java similarity index 72% rename from sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOTest.java rename to sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java index cc1fa1d667aa..c718c55e1b48 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOTest.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java @@ -31,10 +31,12 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; +import org.apache.beam.sdk.io.solace.MockSessionServiceFactory.SessionServiceType; 
import org.apache.beam.sdk.io.solace.SolaceIO.Read; import org.apache.beam.sdk.io.solace.SolaceIO.Read.Configuration; import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; @@ -49,6 +51,7 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; @@ -61,7 +64,7 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class SolaceIOTest { +public class SolaceIOReadTest { @Rule public final transient TestPipeline pipeline = TestPipeline.create(); @@ -69,7 +72,6 @@ private Read getDefaultRead() { return SolaceIO.read() .from(Solace.Queue.fromName("queue")) .withSempClientFactory(MockSempClientFactory.getDefaultMock()) - .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()) .withMaxNumConnections(1); } @@ -77,7 +79,6 @@ private Read getDefaultReadForTopic() { return SolaceIO.read() .from(Solace.Topic.fromName("topic")) .withSempClientFactory(MockSempClientFactory.getDefaultMock()) - .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()) .withMaxNumConnections(1); } @@ -102,20 +103,18 @@ private static UnboundedSolaceSource getSource(Read spec, TestPi @Test public void testReadMessages() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + 
SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().minMessagesReceived(3).recordFn(recordFn).build(); // Expected data List expected = new ArrayList<>(); @@ -137,20 +136,18 @@ public void testReadMessages() { @Test public void testReadMessagesWithDeduplication() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); @@ -172,19 +169,18 @@ public void testReadMessagesWithDeduplication() { @Test public void testReadMessagesWithoutDeduplication() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); - return getOrNull(index, messages); - }, - 3); + 
SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); + return getOrNull(index, messages); + }; + SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); @@ -206,32 +202,38 @@ public void testReadMessagesWithoutDeduplication() { @Test public void testReadMessagesWithDeduplicationOnReplicationGroupMessageId() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage( - "payload_test0", null, null, new ReplicationGroupMessageIdImpl(2L, 1L)), - SolaceDataUtils.getBytesXmlMessage( - "payload_test1", null, null, new ReplicationGroupMessageIdImpl(2L, 2L)), - SolaceDataUtils.getBytesXmlMessage( - "payload_test2", null, null, new ReplicationGroupMessageIdImpl(2L, 2L))); - return getOrNull(index, messages); - }, - 3); + + String id0 = UUID.randomUUID().toString(); + String id1 = UUID.randomUUID().toString(); + String id2 = UUID.randomUUID().toString(); + + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage( + "payload_test0", id0, null, new ReplicationGroupMessageIdImpl(2L, 1L)), + SolaceDataUtils.getBytesXmlMessage( + "payload_test1", id1, null, new ReplicationGroupMessageIdImpl(2L, 2L)), + SolaceDataUtils.getBytesXmlMessage( + "payload_test2", id2, null, new ReplicationGroupMessageIdImpl(2L, 2L))); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + 
MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); expected.add( SolaceDataUtils.getSolaceRecord( - "payload_test0", null, new ReplicationGroupMessageIdImpl(2L, 1L))); + "payload_test0", id0, new ReplicationGroupMessageIdImpl(2L, 1L))); + expected.add( + SolaceDataUtils.getSolaceRecord( + "payload_test1", id1, new ReplicationGroupMessageIdImpl(2L, 2L))); expected.add( SolaceDataUtils.getSolaceRecord( - "payload_test1", null, new ReplicationGroupMessageIdImpl(2L, 2L))); + "payload_test2", id2, new ReplicationGroupMessageIdImpl(2L, 2L))); // Run the pipeline PCollection events = @@ -248,19 +250,18 @@ public void testReadMessagesWithDeduplicationOnReplicationGroupMessageId() { @Test public void testReadWithCoderAndParseFnAndTimestampFn() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); + return getOrNull(index, messages); + }; + SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); @@ -304,7 +305,10 @@ public void testSplitsForExclusiveQueue() throws Exception { SolaceIO.read() .from(Solace.Queue.fromName("queue")) .withSempClientFactory(new MockSempClientFactory(mockSempClient)) - 
.withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); + .withSessionServiceFactory( + MockSessionServiceFactory.builder() + .sessionServiceType(SessionServiceType.EMPTY) + .build()); int desiredNumSplits = 5; @@ -316,7 +320,10 @@ public void testSplitsForExclusiveQueue() throws Exception { @Test public void testSplitsForNonExclusiveQueueWithMaxNumConnections() throws Exception { - Read spec = getDefaultRead().withMaxNumConnections(3); + Read spec = + getDefaultRead() + .withMaxNumConnections(3) + .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); int desiredNumSplits = 5; @@ -328,7 +335,10 @@ public void testSplitsForNonExclusiveQueueWithMaxNumConnections() throws Excepti @Test public void testSplitsForNonExclusiveQueueWithMaxNumConnectionsRespectDesired() throws Exception { - Read spec = getDefaultRead().withMaxNumConnections(10); + Read spec = + getDefaultRead() + .withMaxNumConnections(10) + .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); int desiredNumSplits = 5; UnboundedSolaceSource initialSource = getSource(spec, pipeline); @@ -346,7 +356,9 @@ public void testCreateQueueForTopic() { .build(); Read spec = - getDefaultReadForTopic().withSempClientFactory(new MockSempClientFactory(mockSempClient)); + getDefaultReadForTopic() + .withSempClientFactory(new MockSempClientFactory(mockSempClient)) + .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); spec.expand(PBegin.in(TestPipeline.create())); // check if createQueueForTopic was executed assertEquals(1, createQueueForTopicFnCounter.get()); @@ -358,22 +370,22 @@ public void testCheckpointMark() throws Exception { AtomicInteger countAckMessages = new AtomicInteger(0); // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - messages.add( - SolaceDataUtils.getBytesXmlMessage( - "payload_test" + i, "45" + i, 
(num) -> countAckMessages.incrementAndGet())); - } - countConsumedMessages.incrementAndGet(); - return getOrNull(index, messages); - }, - 10); + + SerializableFunction recordFn = + index -> { + List messages = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + messages.add( + SolaceDataUtils.getBytesXmlMessage( + "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); + } + countConsumedMessages.incrementAndGet(); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(10).build(); + Read spec = getDefaultRead().withSessionServiceFactory(fakeSessionServiceFactory); UnboundedSolaceSource initialSource = getSource(spec, pipeline); @@ -407,21 +419,20 @@ public void testCheckpointMarkAndFinalizeSeparately() throws Exception { AtomicInteger countAckMessages = new AtomicInteger(0); // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - messages.add( - SolaceDataUtils.getBytesXmlMessage( - "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); - } - countConsumedMessages.incrementAndGet(); - return getOrNull(index, messages); - }, - 10); + SerializableFunction recordFn = + index -> { + List messages = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + messages.add( + SolaceDataUtils.getBytesXmlMessage( + "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); + } + countConsumedMessages.incrementAndGet(); + return getOrNull(index, messages); + }; + SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(10).build(); Read spec = getDefaultRead() @@ -467,22 +478,21 @@ public void 
testCheckpointMarkSafety() throws Exception { AtomicInteger countAckMessages = new AtomicInteger(0); // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = new ArrayList<>(); - for (int i = 0; i < messagesToProcess; i++) { - messages.add( - SolaceDataUtils.getBytesXmlMessage( - "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); - } - countConsumedMessages.incrementAndGet(); - return getOrNull(index, messages); - }, - 10); + SerializableFunction recordFn = + index -> { + List messages = new ArrayList<>(); + for (int i = 0; i < messagesToProcess; i++) { + messages.add( + SolaceDataUtils.getBytesXmlMessage( + "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); + } + countConsumedMessages.incrementAndGet(); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(10).build(); + Read spec = getDefaultRead() .withSessionServiceFactory(fakeSessionServiceFactory) @@ -558,20 +568,18 @@ public void testDestinationTopicQueueCreation() { @Test public void testTopicEncoding() { - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new 
MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Run PCollection events = diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOWriteTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOWriteTest.java new file mode 100644 index 000000000000..e92657c3c3d2 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOWriteTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace; + +import static org.apache.beam.sdk.values.TypeDescriptors.strings; + +import com.solacesystems.jcsmp.DeliveryMode; +import java.util.List; +import java.util.Objects; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.io.solace.MockSessionServiceFactory.SessionServiceType; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.SolaceIO.WriterType; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.io.solace.data.SolaceDataUtils; +import org.apache.beam.sdk.io.solace.write.SolaceOutput; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SolaceIOWriteTest { + + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + private final List keys = ImmutableList.of("450", "451", "452"); + private final List payloads = ImmutableList.of("payload0", "payload1", "payload2"); + + private 
PCollection getRecords(Pipeline p) { + TestStream.Builder> kvBuilder = + TestStream.create(KvCoder.of(AvroCoder.of(String.class), AvroCoder.of(String.class))) + .advanceWatermarkTo(Instant.EPOCH); + + assert keys.size() == payloads.size(); + + for (int k = 0; k < keys.size(); k++) { + kvBuilder = + kvBuilder + .addElements(KV.of(keys.get(k), payloads.get(k))) + .advanceProcessingTime(Duration.standardSeconds(60)); + } + + TestStream> testStream = kvBuilder.advanceWatermarkToInfinity(); + PCollection> kvs = p.apply("Test stream", testStream); + + return kvs.apply( + "To Record", + MapElements.into(TypeDescriptor.of(Record.class)) + .via(kv -> SolaceDataUtils.getSolaceRecord(kv.getValue(), kv.getKey()))); + } + + private SolaceOutput getWriteTransform( + SubmissionMode mode, + WriterType writerType, + Pipeline p, + ErrorHandler errorHandler) { + SessionServiceFactory fakeSessionServiceFactory = + MockSessionServiceFactory.builder().mode(mode).build(); + + PCollection records = getRecords(p); + return records.apply( + "Write to Solace", + SolaceIO.write() + .to(Solace.Queue.fromName("queue")) + .withSubmissionMode(mode) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withSessionServiceFactory(fakeSessionServiceFactory) + .withErrorHandler(errorHandler)); + } + + private static PCollection getIdsPCollection(SolaceOutput output) { + return output + .getSuccessfulPublish() + .apply( + "Get message ids", MapElements.into(strings()).via(Solace.PublishResult::getMessageId)); + } + + @Test + public void testWriteLatencyStreaming() throws Exception { + SubmissionMode mode = SubmissionMode.LOWER_LATENCY; + WriterType writerType = WriterType.STREAMING; + + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + 
errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + + pipeline.run(); + } + + @Test + public void testWriteThroughputStreaming() throws Exception { + SubmissionMode mode = SubmissionMode.HIGHER_THROUGHPUT; + WriterType writerType = WriterType.STREAMING; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + + pipeline.run(); + } + + @Test + public void testWriteLatencyBatched() throws Exception { + SubmissionMode mode = SubmissionMode.LOWER_LATENCY; + WriterType writerType = WriterType.BATCHED; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + pipeline.run(); + } + + @Test + public void testWriteThroughputBatched() throws Exception { + SubmissionMode mode = SubmissionMode.HIGHER_THROUGHPUT; + WriterType writerType = WriterType.BATCHED; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + pipeline.run(); + } + + @Test + public void testWriteWithFailedRecords() throws Exception { + SubmissionMode mode = SubmissionMode.HIGHER_THROUGHPUT; + WriterType writerType = WriterType.BATCHED; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new 
ErrorSinkTransform()); + + SessionServiceFactory fakeSessionServiceFactory = + MockSessionServiceFactory.builder() + .mode(mode) + .sessionServiceType(SessionServiceType.WITH_FAILING_PRODUCER) + .build(); + + PCollection records = getRecords(pipeline); + SolaceOutput output = + records.apply( + "Write to Solace", + SolaceIO.write() + .to(Solace.Queue.fromName("queue")) + .withSubmissionMode(mode) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withSessionServiceFactory(fakeSessionServiceFactory) + .withErrorHandler(errorHandler)); + + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).empty(); + errorHandler.close(); + PAssert.thatSingleton(Objects.requireNonNull(errorHandler.getOutput())) + .isEqualTo((long) payloads.size()); + pipeline.run(); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java index 0c6f88a7c9d5..357734f18aad 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java @@ -31,9 +31,8 @@ public class OverrideWriterPropertiesTest { @Test public void testOverrideForHigherThroughput() { SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.HIGHER_THROUGHPUT; - MockSessionService service = new MockSessionService(null, 0, mode); + MockSessionService service = MockSessionService.builder().mode(mode).build(); - // Test HIGHER_THROUGHPUT mode JCSMPProperties props = service.initializeWriteSessionProperties(mode); assertEquals(false, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); assertEquals( @@ -44,13 +43,26 @@ public void testOverrideForHigherThroughput() { @Test public void testOverrideForLowerLatency() { SolaceIO.SubmissionMode mode = 
SolaceIO.SubmissionMode.LOWER_LATENCY; - MockSessionService service = new MockSessionService(null, 0, mode); + MockSessionService service = MockSessionService.builder().mode(mode).build(); - // Test HIGHER_THROUGHPUT mode JCSMPProperties props = service.initializeWriteSessionProperties(mode); assertEquals(true, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); assertEquals( Long.valueOf(50), Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); } + + @Test + public void testDontOverrideForCustom() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.CUSTOM; + MockSessionService service = MockSessionService.builder().mode(mode).build(); + + JCSMPProperties props = service.initializeWriteSessionProperties(mode); + assertEquals( + MockSessionService.callbackOnReactor, + props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); + assertEquals( + Long.valueOf(MockSessionService.ackWindowSizeForTesting), + Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); + } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java index 5134bd131d73..9e04c4cfd276 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java @@ -100,7 +100,7 @@ public static Solace.Record getSolaceRecord( : DEFAULT_REPLICATION_GROUP_ID.toString(); return Solace.Record.builder() - .setPayload(ByteBuffer.wrap(payload.getBytes(StandardCharsets.UTF_8))) + .setPayload(payload.getBytes(StandardCharsets.UTF_8)) .setMessageId(messageId) .setDestination( Solace.Destination.builder() @@ -116,7 +116,7 @@ public static Solace.Record getSolaceRecord( .setTimeToLive(1000L) .setSenderTimestamp(null) 
.setReplicationGroupMessageId(replicationGroupMessageIdString) - .setAttachmentBytes(ByteBuffer.wrap(new byte[0])) + .setAttachmentBytes(new byte[0]) .build(); } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClient.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClient.java new file mode 100644 index 000000000000..637cecdcfd15 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClient.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.it; + +import com.google.api.client.http.HttpRequestFactory; +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.sdk.io.solace.broker.BasicAuthSempClient; +import org.apache.beam.sdk.io.solace.broker.SempBasicAuthClientExecutor; +import org.apache.beam.sdk.util.SerializableSupplier; + +/** + * Example class showing how the {@link BasicAuthSempClient} can be extended or have functionalities + * overridden. 
In this case, the modified method is {@link + * BasicAuthSempClient#getBacklogBytes(String)}, which queries multiple SEMP endpoints to collect + * accurate backlog metrics. For usage, see {@link SolaceIOMultipleSempIT}. + */ +public class BasicAuthMultipleSempClient extends BasicAuthSempClient { + private final List sempBacklogBasicAuthClientExecutors; + + public BasicAuthMultipleSempClient( + String mainHost, + List backlogHosts, + String username, + String password, + String vpnName, + SerializableSupplier httpRequestFactorySupplier) { + super(mainHost, username, password, vpnName, httpRequestFactorySupplier); + sempBacklogBasicAuthClientExecutors = + backlogHosts.stream() + .map( + host -> + new SempBasicAuthClientExecutor( + host, username, password, vpnName, httpRequestFactorySupplier.get())) + .collect(Collectors.toList()); + } + + @Override + public long getBacklogBytes(String queueName) throws IOException { + long backlog = 0; + for (SempBasicAuthClientExecutor client : sempBacklogBasicAuthClientExecutors) { + backlog += client.getBacklogBytes(queueName); + } + return backlog; + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClientFactory.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClientFactory.java new file mode 100644 index 000000000000..0a548c10555c --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClientFactory.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.it; + +import com.google.api.client.http.HttpRequestFactory; +import com.google.api.client.http.javanet.NetHttpTransport; +import com.google.auto.value.AutoValue; +import java.util.List; +import org.apache.beam.sdk.io.solace.broker.SempClient; +import org.apache.beam.sdk.io.solace.broker.SempClientFactory; +import org.apache.beam.sdk.util.SerializableSupplier; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Example class showing how to implement a custom {@link SempClientFactory} with custom client. For + * usage, see {@link SolaceIOMultipleSempIT}. + */ +@AutoValue +public abstract class BasicAuthMultipleSempClientFactory implements SempClientFactory { + + public abstract String mainHost(); + + public abstract List backlogHosts(); + + public abstract String username(); + + public abstract String password(); + + public abstract String vpnName(); + + public abstract @Nullable SerializableSupplier httpRequestFactorySupplier(); + + public static Builder builder() { + return new AutoValue_BasicAuthMultipleSempClientFactory.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + /** Set Solace host, format: [Protocol://]Host[:Port]. */ + public abstract Builder mainHost(String host); + + public abstract Builder backlogHosts(List hosts); + + /** Set Solace username. */ + public abstract Builder username(String username); + /** Set Solace password. 
*/ + public abstract Builder password(String password); + + /** Set Solace vpn name. */ + public abstract Builder vpnName(String vpnName); + + abstract Builder httpRequestFactorySupplier( + SerializableSupplier httpRequestFactorySupplier); + + public abstract BasicAuthMultipleSempClientFactory build(); + } + + @Override + public SempClient create() { + return new BasicAuthMultipleSempClient( + mainHost(), + backlogHosts(), + username(), + password(), + vpnName(), + getHttpRequestFactorySupplier()); + } + + @SuppressWarnings("return") + private @NonNull SerializableSupplier getHttpRequestFactorySupplier() { + SerializableSupplier httpRequestSupplier = httpRequestFactorySupplier(); + return httpRequestSupplier != null + ? httpRequestSupplier + : () -> new NetHttpTransport().createRequestFactory(); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java index 1a2a056efd45..ee5d206533dc 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java @@ -17,49 +17,71 @@ */ package org.apache.beam.sdk.io.solace.it; +import static org.apache.beam.sdk.io.solace.it.SolaceContainerManager.TOPIC_NAME; +import static org.apache.beam.sdk.values.TypeDescriptors.strings; import static org.junit.Assert.assertEquals; +import com.solacesystems.jcsmp.DeliveryMode; import java.io.IOException; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; import org.apache.beam.sdk.io.solace.SolaceIO; import org.apache.beam.sdk.io.solace.broker.BasicAuthJcsmpSessionServiceFactory; import org.apache.beam.sdk.io.solace.broker.BasicAuthSempClientFactory; +import org.apache.beam.sdk.io.solace.data.Solace; import 
org.apache.beam.sdk.io.solace.data.Solace.Queue; +import org.apache.beam.sdk.io.solace.data.SolaceDataUtils; +import org.apache.beam.sdk.io.solace.write.SolaceOutput; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.testutils.metrics.MetricsReader; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.Duration; +import org.joda.time.Instant; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.FixMethodOrder; import org.junit.Rule; import org.junit.Test; +import org.junit.runners.MethodSorters; +@FixMethodOrder(MethodSorters.NAME_ASCENDING) public class SolaceIOIT { private static final String NAMESPACE = SolaceIOIT.class.getName(); private static final String READ_COUNT = "read_count"; + private static final String WRITE_COUNT = "write_count"; private static SolaceContainerManager solaceContainerManager; - private static final TestPipelineOptions readPipelineOptions; + private static final String queueName = "test_queue"; + private static final TestPipelineOptions pipelineOptions; + private static final long PUBLISH_MESSAGE_COUNT = 20; static { - readPipelineOptions = PipelineOptionsFactory.create().as(TestPipelineOptions.class); - readPipelineOptions.setBlockOnRun(false); - readPipelineOptions.as(TestPipelineOptions.class).setBlockOnRun(false); - 
readPipelineOptions.as(StreamingOptions.class).setStreaming(false); + pipelineOptions = PipelineOptionsFactory.create().as(TestPipelineOptions.class); + pipelineOptions.as(StreamingOptions.class).setStreaming(true); + // For the read connector tests, we need to make sure that p.run() does not block + pipelineOptions.setBlockOnRun(false); + pipelineOptions.as(TestPipelineOptions.class).setBlockOnRun(false); } - @Rule public final TestPipeline readPipeline = TestPipeline.fromOptions(readPipelineOptions); + @Rule public final TestPipeline pipeline = TestPipeline.fromOptions(pipelineOptions); @BeforeClass public static void setup() throws IOException { solaceContainerManager = new SolaceContainerManager(); solaceContainerManager.start(); + solaceContainerManager.createQueueWithSubscriptionTopic(queueName); } @AfterClass @@ -69,20 +91,17 @@ public static void afterClass() { } } + // The order of the following tests matters. The first test publishes some messages in a Solace + // queue, and those messages are read by the second test. If another writer test is run before + // the read test, that will alter the count for the read test and will make it fail. @Test - public void testRead() { - String queueName = "test_queue"; - solaceContainerManager.createQueueWithSubscriptionTopic(queueName); - - // todo this is very slow, needs to be replaced with the SolaceIO.write connector. 
- int publishMessagesCount = 20; - for (int i = 0; i < publishMessagesCount; i++) { - solaceContainerManager.sendToTopic( - "{\"field_str\":\"value\",\"field_int\":123}", - ImmutableList.of("Solace-Message-ID:m" + i)); - } + public void test01WriteStreaming() { + testWriteConnector(SolaceIO.WriterType.STREAMING); + } - readPipeline + @Test + public void test02Read() { + pipeline .apply( "Read from Solace", SolaceIO.read() @@ -105,12 +124,83 @@ public void testRead() { .build())) .apply("Count", ParDo.of(new CountingFn<>(NAMESPACE, READ_COUNT))); - PipelineResult pipelineResult = readPipeline.run(); + PipelineResult pipelineResult = pipeline.run(); + // We need enough time for Beam to pull all messages from the queue, but we need a timeout too, + // as the Read connector will keep attempting to read forever. pipelineResult.waitUntilFinish(Duration.standardSeconds(15)); MetricsReader metricsReader = new MetricsReader(pipelineResult, NAMESPACE); long actualRecordsCount = metricsReader.getCounterMetric(READ_COUNT); - assertEquals(publishMessagesCount, actualRecordsCount); + assertEquals(PUBLISH_MESSAGE_COUNT, actualRecordsCount); + } + + @Test + public void test03WriteBatched() { + testWriteConnector(SolaceIO.WriterType.BATCHED); + } + + private void testWriteConnector(SolaceIO.WriterType writerType) { + Pipeline p = createWriterPipeline(writerType); + + PipelineResult pipelineResult = p.run(); + pipelineResult.waitUntilFinish(); + MetricsReader metricsReader = new MetricsReader(pipelineResult, NAMESPACE); + long actualRecordsCount = metricsReader.getCounterMetric(WRITE_COUNT); + assertEquals(PUBLISH_MESSAGE_COUNT, actualRecordsCount); + } + + private Pipeline createWriterPipeline(SolaceIO.WriterType writerType) { + TestStream.Builder> kvBuilder = + TestStream.create(KvCoder.of(AvroCoder.of(String.class), AvroCoder.of(String.class))) + .advanceWatermarkTo(Instant.EPOCH); + + for (int i = 0; i < PUBLISH_MESSAGE_COUNT; i++) { + String key = "Solace-Message-ID:m" + i; + 
String payload = String.format("{\"field_str\":\"value\",\"field_int\":123%d}", i); + kvBuilder = + kvBuilder + .addElements(KV.of(key, payload)) + .advanceProcessingTime(Duration.standardSeconds(60)); + } + + TestStream> testStream = kvBuilder.advanceWatermarkToInfinity(); + + PCollection> kvs = + pipeline.apply(String.format("Test stream %s", writerType), testStream); + + PCollection records = + kvs.apply( + String.format("To Record %s", writerType), + MapElements.into(TypeDescriptor.of(Solace.Record.class)) + .via(kv -> SolaceDataUtils.getSolaceRecord(kv.getValue(), kv.getKey()))); + + SolaceOutput result = + records.apply( + String.format("Write to Solace %s", writerType), + SolaceIO.write() + .to(Solace.Topic.fromName(TOPIC_NAME)) + .withSubmissionMode(SolaceIO.SubmissionMode.TESTING) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withNumberOfClientsPerWorker(1) + .withNumShards(1) + .withSessionServiceFactory( + BasicAuthJcsmpSessionServiceFactory.builder() + .host("localhost:" + solaceContainerManager.jcsmpPortMapped) + .username(SolaceContainerManager.USERNAME) + .password(SolaceContainerManager.PASSWORD) + .vpnName(SolaceContainerManager.VPN_NAME) + .build())); + result + .getSuccessfulPublish() + .apply( + String.format("Get ids %s", writerType), + MapElements.into(strings()).via(Solace.PublishResult::getMessageId)) + .apply( + String.format("Count %s", writerType), + ParDo.of(new CountingFn<>(NAMESPACE, WRITE_COUNT))); + + return pipeline; } private static class CountingFn extends DoFn { diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOMultipleSempIT.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOMultipleSempIT.java new file mode 100644 index 000000000000..77d00b4e41ec --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOMultipleSempIT.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.it; + +import static org.apache.beam.sdk.io.solace.it.SolaceContainerManager.TOPIC_NAME; +import static org.apache.beam.sdk.values.TypeDescriptors.strings; +import static org.junit.Assert.assertEquals; + +import com.solacesystems.jcsmp.DeliveryMode; +import java.io.IOException; +import java.util.Arrays; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.SolaceIO.WriterType; +import org.apache.beam.sdk.io.solace.broker.BasicAuthJcsmpSessionServiceFactory; +import org.apache.beam.sdk.io.solace.broker.SempClientFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Queue; +import org.apache.beam.sdk.io.solace.data.SolaceDataUtils; +import org.apache.beam.sdk.io.solace.write.SolaceOutput; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.StreamingOptions; +import 
org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.testutils.metrics.MetricsReader; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; + +public class SolaceIOMultipleSempIT { + private static final String NAMESPACE = SolaceIOMultipleSempIT.class.getName(); + private static final String READ_COUNT = "read_count"; + private static final String QUEUE_NAME = "test_queue"; + private static final long PUBLISH_MESSAGE_COUNT = 20; + private static final TestPipelineOptions pipelineOptions; + private static SolaceContainerManager solaceContainerManager; + + static { + pipelineOptions = PipelineOptionsFactory.create().as(TestPipelineOptions.class); + pipelineOptions.as(StreamingOptions.class).setStreaming(true); + // For the read connector tests, we need to make sure that p.run() does not block + pipelineOptions.setBlockOnRun(false); + pipelineOptions.as(TestPipelineOptions.class).setBlockOnRun(false); + } + + @Rule public final TestPipeline pipeline = TestPipeline.fromOptions(pipelineOptions); + + @BeforeClass + public static void setup() throws IOException { + solaceContainerManager = new SolaceContainerManager(); + solaceContainerManager.start(); + solaceContainerManager.createQueueWithSubscriptionTopic(QUEUE_NAME); + } + + @AfterClass + public static void afterClass() { + if (solaceContainerManager != null) { + solaceContainerManager.stop(); + } + } + + /** + * This test verifies the functionality of reading data from a Solace queue using the + * 
SolaceIO.read() transform. This test does not actually test functionalities of {@link + * BasicAuthMultipleSempClientFactory}, but it demonstrates how to integrate a custom + * implementation of {@link SempClientFactory}, in this case, {@link + * BasicAuthMultipleSempClientFactory}, to handle authentication and configuration interactions + * with the Solace message broker. + */ + @Test + public void test01writeAndReadWithMultipleSempClientFactory() { + Pipeline writerPipeline = + createWriterPipeline(WriterType.BATCHED, solaceContainerManager.jcsmpPortMapped); + writerPipeline + .apply( + "Read from Solace", + SolaceIO.read() + .from(Queue.fromName(QUEUE_NAME)) + .withMaxNumConnections(1) + .withDeduplicateRecords(true) + .withSempClientFactory( + BasicAuthMultipleSempClientFactory.builder() + .backlogHosts( + Arrays.asList( + "http://localhost:" + solaceContainerManager.sempPortMapped, + "http://localhost:" + solaceContainerManager.sempPortMapped)) + .mainHost("http://localhost:" + solaceContainerManager.sempPortMapped) + .username("admin") + .password("admin") + .vpnName(SolaceContainerManager.VPN_NAME) + .build()) + .withSessionServiceFactory( + BasicAuthJcsmpSessionServiceFactory.builder() + .host("localhost:" + solaceContainerManager.jcsmpPortMapped) + .username(SolaceContainerManager.USERNAME) + .password(SolaceContainerManager.PASSWORD) + .vpnName(SolaceContainerManager.VPN_NAME) + .build())) + .apply("Count", ParDo.of(new CountingFn<>(NAMESPACE, READ_COUNT))); + + PipelineResult pipelineResult = writerPipeline.run(); + // We need enough time for Beam to pull all messages from the queue, but we need a timeout too, + // as the Read connector will keep attempting to read forever. 
+ pipelineResult.waitUntilFinish(Duration.standardSeconds(15)); + + MetricsReader metricsReader = new MetricsReader(pipelineResult, NAMESPACE); + long actualRecordsCount = metricsReader.getCounterMetric(READ_COUNT); + assertEquals(PUBLISH_MESSAGE_COUNT, actualRecordsCount); + } + + private Pipeline createWriterPipeline( + SolaceIO.WriterType writerType, int solaceContainerJcsmpPort) { + TestStream.Builder> kvBuilder = + TestStream.create(KvCoder.of(AvroCoder.of(String.class), AvroCoder.of(String.class))) + .advanceWatermarkTo(Instant.EPOCH); + + for (int i = 0; i < PUBLISH_MESSAGE_COUNT; i++) { + String key = "Solace-Message-ID:m" + solaceContainerJcsmpPort + i; + String payload = String.format("{\"field_str\":\"value\",\"field_int\":123%d}", i); + kvBuilder = + kvBuilder + .addElements(KV.of(key, payload)) + .advanceProcessingTime(Duration.standardSeconds(60)); + } + + TestStream> testStream = kvBuilder.advanceWatermarkToInfinity(); + + PCollection> kvs = + pipeline.apply(String.format("Test stream %s", writerType), testStream); + + PCollection records = + kvs.apply( + String.format("To Record %s", writerType), + MapElements.into(TypeDescriptor.of(Solace.Record.class)) + .via(kv -> SolaceDataUtils.getSolaceRecord(kv.getValue(), kv.getKey()))); + + SolaceOutput result = + records.apply( + String.format("Write to Solace %s", writerType), + SolaceIO.write() + .to(Solace.Topic.fromName(TOPIC_NAME)) + .withSubmissionMode(SolaceIO.SubmissionMode.TESTING) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withNumberOfClientsPerWorker(1) + .withNumShards(1) + .withSessionServiceFactory( + BasicAuthJcsmpSessionServiceFactory.builder() + .host("localhost:" + solaceContainerJcsmpPort) + .username(SolaceContainerManager.USERNAME) + .password(SolaceContainerManager.PASSWORD) + .vpnName(SolaceContainerManager.VPN_NAME) + .build())); + result + .getSuccessfulPublish() + .apply( + String.format("Get ids %s", writerType), + 
MapElements.into(strings()).via(Solace.PublishResult::getMessageId)); + + return pipeline; + } + + private static class CountingFn extends DoFn { + + private final Counter elementCounter; + + CountingFn(String namespace, String name) { + elementCounter = Metrics.counter(namespace, name); + } + + @ProcessElement + public void processElement(@Element T record, OutputReceiver c) { + elementCounter.inc(1L); + c.output(record); + } + } +} diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 5f4e195f227f..3094ea47d6ad 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -202,10 +202,10 @@ private Schema.Field beamField(FieldMetaData fieldDescriptor) { @SuppressWarnings("rawtypes") @Override - public @NonNull List fieldValueGetters( - @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { + public @NonNull List> fieldValueGetters( + @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { return schemaFieldDescriptors(targetTypeDescriptor.getRawType(), schema).keySet().stream() - .map(FieldExtractor::new) + .>map(FieldExtractor::new) .collect(Collectors.toList()); } @@ -242,10 +242,12 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return FieldValueTypeInformation.forSetter(factoryMethods.get(0), ""); + return FieldValueTypeInformation.forSetter( + TypeDescriptor.of(type), factoryMethods.get(0), ""); } else { try { - return FieldValueTypeInformation.forField(type.getDeclaredField(fieldName), 0); + return FieldValueTypeInformation.forField( + TypeDescriptor.of(type), type.getDeclaredField(fieldName), 0); } catch 
(NoSuchFieldException e) { throw new IllegalArgumentException(e); } @@ -373,7 +375,7 @@ private & TEnum> FieldType beamType(FieldValueMetaDat } } - private static class FieldExtractor> + private static class FieldExtractor implements FieldValueGetter { private final FieldT field; @@ -383,8 +385,9 @@ private FieldExtractor(FieldT field) { @Override public @Nullable Object get(T thrift) { - if (!(thrift instanceof TUnion) || thrift.isSet(field)) { - final Object value = thrift.getFieldValue(field); + TBase t = (TBase) thrift; + if (!(thrift instanceof TUnion) || t.isSet(field)) { + final Object value = t.getFieldValue(field); if (value instanceof Enum) { return ((Enum) value).ordinal(); } else { diff --git a/sdks/java/javadoc/build.gradle b/sdks/java/javadoc/build.gradle index c0622b173043..284cef130bd3 100644 --- a/sdks/java/javadoc/build.gradle +++ b/sdks/java/javadoc/build.gradle @@ -62,7 +62,7 @@ task aggregateJavadoc(type: Javadoc) { source exportedJavadocProjects.collect { project(it).sourceSets.main.allJava } classpath = files(exportedJavadocProjects.collect { project(it).sourceSets.main.compileClasspath }) destinationDir = file("${buildDir}/docs/javadoc") - failOnError = true + failOnError = false exclude "org/apache/beam/examples/*" exclude "org/apache/beam/fn/harness/*" diff --git a/sdks/java/managed/build.gradle b/sdks/java/managed/build.gradle index add0d7f3cc0d..c6e868872246 100644 --- a/sdks/java/managed/build.gradle +++ b/sdks/java/managed/build.gradle @@ -28,6 +28,7 @@ ext.summary = """Library that provides managed IOs.""" dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation project(path: ":model:pipeline", configuration: "shadow") implementation library.java.vendored_guava_32_1_2_jre implementation library.java.slf4j_api diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index 
911e25cdda14..8e7e0862eff4 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -17,11 +17,14 @@ */ package org.apache.beam.sdk.managed; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; + import com.google.auto.value.AutoValue; import java.util.ArrayList; import java.util.List; import java.util.Map; import javax.annotation.Nullable; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; @@ -83,17 +86,20 @@ public class Managed { // TODO: Dynamically generate a list of supported transforms public static final String ICEBERG = "iceberg"; public static final String KAFKA = "kafka"; + public static final String BIGQUERY = "bigquery"; // Supported SchemaTransforms public static final Map READ_TRANSFORMS = ImmutableMap.builder() - .put(ICEBERG, ManagedTransformConstants.ICEBERG_READ) - .put(KAFKA, ManagedTransformConstants.KAFKA_READ) + .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_READ)) + .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ)) + .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ)) .build(); public static final Map WRITE_TRANSFORMS = ImmutableMap.builder() - .put(ICEBERG, ManagedTransformConstants.ICEBERG_WRITE) - .put(KAFKA, ManagedTransformConstants.KAFKA_WRITE) + .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_WRITE)) + .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE)) + .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE)) .build(); /** @@ -101,7 +107,9 @@ public class Managed { * supported managed sources are: * *

      - *
    • {@link Managed#ICEBERG} : Read from Apache Iceberg + *
    • {@link Managed#ICEBERG} : Read from Apache Iceberg tables + *
    • {@link Managed#KAFKA} : Read from Apache Kafka topics + *
    • {@link Managed#BIGQUERY} : Read from GCP BigQuery tables *
    */ public static ManagedTransform read(String source) { @@ -121,7 +129,9 @@ public static ManagedTransform read(String source) { * managed sinks are: * *
      - *
    • {@link Managed#ICEBERG} : Write to Apache Iceberg + *
    • {@link Managed#ICEBERG} : Write to Apache Iceberg tables + *
    • {@link Managed#KAFKA} : Write to Apache Kafka topics + *
    • {@link Managed#BIGQUERY} : Write to GCP BigQuery tables *
    */ public static ManagedTransform write(String sink) { diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java index 6f97983d3260..b705306b9478 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java @@ -117,7 +117,7 @@ protected void validate() { "Please specify a config or a config URL, but not both."); } - public @Nullable String resolveUnderlyingConfig() { + private Map resolveUnderlyingConfig() { String yamlTransformConfig = getConfig(); // If YAML string is empty, then attempt to read from YAML file if (Strings.isNullOrEmpty(yamlTransformConfig)) { @@ -131,7 +131,8 @@ protected void validate() { throw new RuntimeException(e); } } - return yamlTransformConfig; + + return YamlUtils.yamlStringToMap(yamlTransformConfig); } } @@ -152,34 +153,34 @@ protected SchemaTransform from(ManagedConfig managedConfig) { static class ManagedSchemaTransform extends SchemaTransform { private final ManagedConfig managedConfig; - private final Row underlyingTransformConfig; + private final Row underlyingRowConfig; private final SchemaTransformProvider underlyingTransformProvider; ManagedSchemaTransform( ManagedConfig managedConfig, SchemaTransformProvider underlyingTransformProvider) { // parse config before expansion to check if it matches underlying transform's config schema Schema transformConfigSchema = underlyingTransformProvider.configurationSchema(); - Row underlyingTransformConfig; + Row underlyingRowConfig; try { - underlyingTransformConfig = getRowConfig(managedConfig, transformConfigSchema); + underlyingRowConfig = getRowConfig(managedConfig, transformConfigSchema); } catch (Exception e) { throw new IllegalArgumentException( "Encountered an error when retrieving a 
Row configuration", e); } - this.managedConfig = managedConfig; - this.underlyingTransformConfig = underlyingTransformConfig; + this.underlyingRowConfig = underlyingRowConfig; this.underlyingTransformProvider = underlyingTransformProvider; + this.managedConfig = managedConfig; } @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { LOG.debug( - "Building transform \"{}\" with Row configuration: {}", + "Building transform \"{}\" with configuration: {}", underlyingTransformProvider.identifier(), - underlyingTransformConfig); + underlyingRowConfig); - return input.apply(underlyingTransformProvider.from(underlyingTransformConfig)); + return input.apply(underlyingTransformProvider.from(underlyingRowConfig)); } public ManagedConfig getManagedConfig() { @@ -201,16 +202,14 @@ Row getConfigurationRow() { } } + // May return an empty row (perhaps the underlying transform doesn't have any required + // parameters) @VisibleForTesting static Row getRowConfig(ManagedConfig config, Schema transformSchema) { - // May return an empty row (perhaps the underlying transform doesn't have any required - // parameters) - String yamlConfig = config.resolveUnderlyingConfig(); - Map configMap = YamlUtils.yamlStringToMap(yamlConfig); - - // The config Row object will be used to build the underlying SchemaTransform. - // If a mapping for the SchemaTransform exists, we use it to update parameter names and align - // with the underlying config schema + Map configMap = config.resolveUnderlyingConfig(); + // Build a config Row that will be used to build the underlying SchemaTransform. 
+ // If a mapping for the SchemaTransform exists, we use it to update parameter names to align + // with the underlying SchemaTransform config schema Map mapping = MAPPINGS.get(config.getTransformIdentifier()); if (mapping != null && configMap != null) { Map remappedConfig = new HashMap<>(); @@ -227,7 +226,7 @@ static Row getRowConfig(ManagedConfig config, Schema transformSchema) { return YamlUtils.toBeamRow(configMap, transformSchema, false); } - // We load providers seperately, after construction, to prevent the + // We load providers separately, after construction, to prevent the // 'ManagedSchemaTransformProvider' from being initialized in a recursive loop // when being loaded using 'AutoValue'. synchronized Map getAllProviders() { diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java index 51d0b67b4b89..30476a30d373 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java @@ -17,7 +17,10 @@ */ package org.apache.beam.sdk.managed; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; + import java.util.Map; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; /** @@ -41,21 +44,33 @@ public class ManagedTransformConstants { // Standard input PCollection tag public static final String INPUT = "input"; - public static final String ICEBERG_READ = "beam:schematransform:org.apache.beam:iceberg_read:v1"; - public static final String ICEBERG_WRITE = - "beam:schematransform:org.apache.beam:iceberg_write:v1"; - public static final String KAFKA_READ = "beam:schematransform:org.apache.beam:kafka_read:v1"; - public static final String KAFKA_WRITE = 
"beam:schematransform:org.apache.beam:kafka_write:v1"; - private static final Map KAFKA_READ_MAPPINGS = ImmutableMap.builder().put("data_format", "format").build(); private static final Map KAFKA_WRITE_MAPPINGS = ImmutableMap.builder().put("data_format", "format").build(); + private static final Map BIGQUERY_READ_MAPPINGS = + ImmutableMap.builder() + .put("table", "table_spec") + .put("fields", "selected_fields") + .build(); + + private static final Map BIGQUERY_WRITE_MAPPINGS = + ImmutableMap.builder() + .put("at_least_once", "use_at_least_once_semantics") + .put("triggering_frequency", "triggering_frequency_seconds") + .build(); + public static final Map> MAPPINGS = ImmutableMap.>builder() - .put(KAFKA_READ, KAFKA_READ_MAPPINGS) - .put(KAFKA_WRITE, KAFKA_WRITE_MAPPINGS) + .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ), KAFKA_READ_MAPPINGS) + .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE), KAFKA_WRITE_MAPPINGS) + .put( + getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ), + BIGQUERY_READ_MAPPINGS) + .put( + getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE), + BIGQUERY_WRITE_MAPPINGS) .build(); } diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java index e9edf8751e34..a287ec6260ce 100644 --- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java +++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java @@ -88,8 +88,7 @@ public void testGetConfigRowFromYamlFile() throws URISyntaxException { .withFieldValue("extra_integer", 123) .build(); Row configRow = - ManagedSchemaTransformProvider.getRowConfig( - config, new TestSchemaTransformProvider().configurationSchema()); + ManagedSchemaTransformProvider.getRowConfig(config, 
TestSchemaTransformProvider.SCHEMA); assertEquals(expectedRow, configRow); } diff --git a/sdks/java/testing/nexmark/build.gradle b/sdks/java/testing/nexmark/build.gradle index a09ed9238991..0a09c357ed57 100644 --- a/sdks/java/testing/nexmark/build.gradle +++ b/sdks/java/testing/nexmark/build.gradle @@ -119,11 +119,7 @@ def getNexmarkArgs = { } } else { def dataflowWorkerJar = project.findProperty('dataflowWorkerJar') ?: project(":runners:google-cloud-dataflow-java:worker").shadowJar.archivePath - // Provide job with a customizable worker jar. - // With legacy worker jar, containerImage is set to empty (i.e. to use the internal build). - // More context and discussions can be found in PR#6694. nexmarkArgsList.add("--dataflowWorkerJar=${dataflowWorkerJar}".toString()) - nexmarkArgsList.add('--workerHarnessContainerImage=') def nexmarkProfile = project.findProperty(nexmarkProfilingProperty) ?: "" if (nexmarkProfile.equals("true")) { diff --git a/sdks/java/transform-service/build.gradle b/sdks/java/transform-service/build.gradle index 91e6185152da..a50da5a9210a 100644 --- a/sdks/java/transform-service/build.gradle +++ b/sdks/java/transform-service/build.gradle @@ -48,3 +48,7 @@ dependencies { testImplementation library.java.mockito_core testImplementation project(path: ":runners:java-fn-execution") } + +shadowJar { + outputs.upToDateWhen { false } +} \ No newline at end of file diff --git a/sdks/python/apache_beam/__init__.py b/sdks/python/apache_beam/__init__.py index 6e08083bc0de..af88934b0e71 100644 --- a/sdks/python/apache_beam/__init__.py +++ b/sdks/python/apache_beam/__init__.py @@ -70,17 +70,11 @@ import warnings if sys.version_info.major == 3: - if sys.version_info.minor <= 7 or sys.version_info.minor >= 13: + if sys.version_info.minor <= 8 or sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' 
% (sys.version_info.major, sys.version_info.minor)) - elif sys.version_info.minor == 8: - warnings.warn( - 'Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') pass else: raise RuntimeError( diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index ff5fb5bef7ac..5262e6adf8a6 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -500,7 +500,9 @@ def encode_special_deterministic(self, value, stream): self.encode_to_stream(value.value, stream, True) except Exception as e: raise TypeError(self._deterministic_encoding_error_msg(value)) from e - elif hasattr(value, "__getstate__"): + elif (hasattr(value, "__getstate__") and + # https://github.com/apache/beam/issues/33020 + type(value).__reduce__ == object.__reduce__): if not hasattr(value, "__setstate__"): raise TypeError( "Unable to deterministically encode '%s' of type '%s', " @@ -1975,7 +1977,7 @@ class DecimalCoderImpl(StreamCoderImpl): def encode_to_stream(self, value, out, nested): # type: (decimal.Decimal, create_OutputStream, bool) -> None - scale = -value.as_tuple().exponent + scale = -value.as_tuple().exponent # type: ignore[operator] int_value = int(value.scaleb(scale)) out.write_var_int64(scale) self.BIG_INT_CODER_IMPL.encode_to_stream(int_value, out, nested) diff --git a/sdks/python/apache_beam/coders/coders_property_based_test.py b/sdks/python/apache_beam/coders/coders_property_based_test.py index be18dd3586b0..9279fc31c099 100644 --- a/sdks/python/apache_beam/coders/coders_property_based_test.py +++ b/sdks/python/apache_beam/coders/coders_property_based_test.py @@ -141,7 +141,7 @@ def test_row_coder(self, data: st.DataObject): coders_registry.register_coder(RowType, RowCoder) # TODO(https://github.com/apache/beam/issues/23002): Apply nulls for these - row = 
RowType( # type: ignore + row = RowType( **{ name: data.draw(SCHEMA_TYPES_TO_STRATEGY[type_]) for name, diff --git a/sdks/python/apache_beam/coders/coders_test.py b/sdks/python/apache_beam/coders/coders_test.py index 4d8e8fe9bcb8..dc9780e36be3 100644 --- a/sdks/python/apache_beam/coders/coders_test.py +++ b/sdks/python/apache_beam/coders/coders_test.py @@ -23,11 +23,13 @@ import proto import pytest +import apache_beam as beam from apache_beam import typehints from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message from apache_beam.coders import coders from apache_beam.coders.avro_record import AvroRecord from apache_beam.coders.typecoders import registry as coders_registry +from apache_beam.testing.test_pipeline import TestPipeline class PickleCoderTest(unittest.TestCase): @@ -242,6 +244,20 @@ def test_to_type_hint(self): assert coder.to_type_hint() is bytes +class NumpyIntAsKeyTest(unittest.TestCase): + def test_numpy_int(self): + # this type is not supported as the key + import numpy as np + + with self.assertRaises(TypeError): + with TestPipeline() as p: + indata = p | "Create" >> beam.Create([(a, int(a)) + for a in np.arange(3)]) + + # Apply CombinePerkey to sum values for each key. 
+ _ = indata | "CombinePerKey" >> beam.CombinePerKey(sum) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 7dcfae83f10e..4bd9698dd57b 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -53,7 +53,7 @@ except ImportError: dataclasses = None # type: ignore -MyNamedTuple = collections.namedtuple('A', ['x', 'y']) +MyNamedTuple = collections.namedtuple('A', ['x', 'y']) # type: ignore[name-match] MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)]) diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py index e44cc429eac1..c5a0d1025c6d 100644 --- a/sdks/python/apache_beam/dataframe/convert.py +++ b/sdks/python/apache_beam/dataframe/convert.py @@ -17,11 +17,9 @@ import inspect import warnings import weakref +from collections.abc import Iterable from typing import Any -from typing import Dict -from typing import Iterable from typing import Optional -from typing import Tuple from typing import Union import pandas as pd @@ -172,7 +170,7 @@ def to_pcollection( always_return_tuple=False, yield_elements='schemas', include_indexes=False, - pipeline=None) -> Union[pvalue.PCollection, Tuple[pvalue.PCollection, ...]]: + pipeline=None) -> Union[pvalue.PCollection, tuple[pvalue.PCollection, ...]]: """Converts one or more deferred dataframe-like objects back to a PCollection. 
This method creates and applies the actual Beam operations that compute @@ -252,7 +250,7 @@ def extract_input(placeholder): df for df in dataframes if df._expr._id not in TO_PCOLLECTION_CACHE ] if len(new_dataframes): - new_results: Dict[Any, pvalue.PCollection] = { + new_results: dict[Any, pvalue.PCollection] = { p: extract_input(p) for p in placeholders } | label >> transforms._DataframeExpressionsTransform( diff --git a/sdks/python/apache_beam/dataframe/doctests.py b/sdks/python/apache_beam/dataframe/doctests.py index 33faa6b58599..57ee8009ba44 100644 --- a/sdks/python/apache_beam/dataframe/doctests.py +++ b/sdks/python/apache_beam/dataframe/doctests.py @@ -45,8 +45,6 @@ import traceback from io import StringIO from typing import Any -from typing import Dict -from typing import List import numpy as np import pandas as pd @@ -146,7 +144,7 @@ class _InMemoryResultRecorder(object): """ # Class-level value to survive pickling. - _ALL_RESULTS = {} # type: Dict[str, List[Any]] + _ALL_RESULTS: dict[str, list[Any]] = {} def __init__(self): self._id = id(self) @@ -729,15 +727,15 @@ def wrapper(fn): Args: optionflags (int): Passed through to doctests. - extraglobs (Dict[str,Any]): Passed through to doctests. + extraglobs (dict[str,Any]): Passed through to doctests. use_beam (bool): If true, run a Beam pipeline with partitioned input to verify the examples, else use PartitioningSession to simulate distributed execution. - skip (Dict[str,str]): A set of examples to skip entirely. + skip (dict[str,str]): A set of examples to skip entirely. If a key is '*', an example will be skipped in all test scenarios. - wont_implement_ok (Dict[str,str]): A set of examples that are allowed to + wont_implement_ok (dict[str,str]): A set of examples that are allowed to raise WontImplementError. - not_implemented_ok (Dict[str,str]): A set of examples that are allowed to + not_implemented_ok (dict[str,str]): A set of examples that are allowed to raise NotImplementedError. 
Returns: diff --git a/sdks/python/apache_beam/dataframe/expressions.py b/sdks/python/apache_beam/dataframe/expressions.py index 91d237c7de96..2ef172b8dad3 100644 --- a/sdks/python/apache_beam/dataframe/expressions.py +++ b/sdks/python/apache_beam/dataframe/expressions.py @@ -17,10 +17,10 @@ import contextlib import random import threading +from collections.abc import Callable +from collections.abc import Iterable from typing import Any -from typing import Callable from typing import Generic -from typing import Iterable from typing import Optional from typing import TypeVar @@ -36,12 +36,12 @@ class Session(object): def __init__(self, bindings=None): self._bindings = dict(bindings or {}) - def evaluate(self, expr): # type: (Expression) -> Any + def evaluate(self, expr: 'Expression') -> Any: if expr not in self._bindings: self._bindings[expr] = expr.evaluate_at(self) return self._bindings[expr] - def lookup(self, expr): # type: (Expression) -> Any + def lookup(self, expr: 'Expression') -> Any: return self._bindings[expr] @@ -251,9 +251,9 @@ def preserves_partition_by(self) -> partitionings.Partitioning: class PlaceholderExpression(Expression): """An expression whose value must be explicitly bound in the session.""" def __init__( - self, # type: PlaceholderExpression - proxy, # type: T - reference=None, # type: Any + self, + proxy: T, + reference: Any = None, ): """Initialize a placeholder expression. @@ -282,11 +282,7 @@ def preserves_partition_by(self): class ConstantExpression(Expression): """An expression whose value is known at pipeline construction time.""" - def __init__( - self, # type: ConstantExpression - value, # type: T - proxy=None # type: Optional[T] - ): + def __init__(self, value: T, proxy: Optional[T] = None): """Initialize a constant expression. 
Args: @@ -319,14 +315,15 @@ def preserves_partition_by(self): class ComputedExpression(Expression): """An expression whose value must be computed at pipeline execution time.""" def __init__( - self, # type: ComputedExpression - name, # type: str - func, # type: Callable[...,T] - args, # type: Iterable[Expression] - proxy=None, # type: Optional[T] - _id=None, # type: Optional[str] - requires_partition_by=partitionings.Index(), # type: partitionings.Partitioning - preserves_partition_by=partitionings.Singleton(), # type: partitionings.Partitioning + self, + name: str, + func: Callable[..., T], + args: Iterable[Expression], + proxy: Optional[T] = None, + _id: Optional[str] = None, + requires_partition_by: partitionings.Partitioning = partitionings.Index(), + preserves_partition_by: partitionings.Partitioning = partitionings. + Singleton(), ): """Initialize a computed expression. diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py index 90f34d45dd98..8e206fc5e037 100644 --- a/sdks/python/apache_beam/dataframe/frame_base.py +++ b/sdks/python/apache_beam/dataframe/frame_base.py @@ -17,15 +17,13 @@ import functools import operator import re +from collections.abc import Callable from inspect import cleandoc from inspect import getfullargspec from inspect import isclass from inspect import ismodule from inspect import unwrap from typing import Any -from typing import Callable -from typing import Dict -from typing import List from typing import Optional from typing import Tuple from typing import Union @@ -38,7 +36,7 @@ class DeferredBase(object): - _pandas_type_map: Dict[Union[type, None], type] = {} + _pandas_type_map: dict[Union[type, None], type] = {} def __init__(self, expr): self._expr = expr @@ -229,7 +227,7 @@ def _elementwise_function( def _proxy_function( func: Union[Callable, str], name: Optional[str] = None, - restrictions: Optional[Dict[str, Union[Any, List[Any]]]] = None, + restrictions: 
Optional[dict[str, Union[Any, list[Any]]]] = None, inplace: bool = False, base: Optional[type] = None, *, @@ -606,6 +604,21 @@ def wrap(func): " :skipif: True"), re.sub(r"^", " ", content, flags=re.MULTILINE), ]) + elif "Examples" in content and ">>>" in content: + # some new examples don't have the correct heading + # this catches those examples + split_content = content.split("Examples") + content = '\n\n'.join([ + split_content[0], + "Examples\n", + # Indent the code snippet under a doctest heading, + # add skipif option. This makes sure our doctest + # framework doesn't run these pandas tests. + (".. doctest::\n" + " :skipif: True"), + re.sub(r"^", " ", content, flags=re.MULTILINE), + split_content[1] + ]) else: content = content.replace('DataFrame', 'DeferredDataFrame').replace( 'Series', 'DeferredSeries') diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py index 421430ec972c..ccd01f35f87b 100644 --- a/sdks/python/apache_beam/dataframe/frames.py +++ b/sdks/python/apache_beam/dataframe/frames.py @@ -38,7 +38,6 @@ import math import re import warnings -from typing import List from typing import Optional import numpy as np @@ -2660,7 +2659,7 @@ def get(self, key, default_value=None): @frame_base.populate_defaults(pd.DataFrame) @frame_base.maybe_inplace def set_index(self, keys, **kwargs): - """``keys`` must be a ``str`` or ``List[str]``. Passing an Index or Series + """``keys`` must be a ``str`` or ``list[str]``. Passing an Index or Series is not yet supported (`Issue 20759 `_).""" if isinstance(keys, str): @@ -4574,7 +4573,7 @@ def value_counts(self, **kwargs): tshift = frame_base.wont_implement_method( DataFrameGroupBy, 'tshift', reason="deprecated") -def _maybe_project_func(projection: Optional[List[str]]): +def _maybe_project_func(projection: Optional[list[str]]): """ Returns identity func if projection is empty or None, else returns a function that projects the specified columns. 
""" if projection: @@ -4967,7 +4966,7 @@ def func(*args): else: raise frame_base.WontImplementError( - "others must be None, DeferredSeries, or List[DeferredSeries] " + "others must be None, DeferredSeries, or list[DeferredSeries] " f"(encountered {type(others)}). Other types are not supported " "because they make this operation sensitive to the order of the " "data.", reason="order-sensitive") diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py index 076ab504adde..f99b77e446a8 100644 --- a/sdks/python/apache_beam/dataframe/frames_test.py +++ b/sdks/python/apache_beam/dataframe/frames_test.py @@ -18,7 +18,6 @@ import sys import unittest import warnings -from typing import Dict import numpy as np import pandas as pd @@ -1025,6 +1024,17 @@ def test_series_fillna_series_as_value(self): self._run_test(lambda df, df2: df.A.fillna(df2.A), df, df2) + def test_dataframe_column_fillna_constant_as_value(self): + from apache_beam.dataframe import convert + from apache_beam.testing.util import assert_that + from apache_beam.testing.util import equal_to + with beam.Pipeline(None) as p: + pcoll = ( + p | beam.Create([1.0, np.nan, -1.0]) | beam.Select(x=lambda x: x)) + df = convert.to_dataframe(pcoll) + df_new = df['x'].fillna(0) + assert_that(convert.to_pcollection(df_new), equal_to([1.0, 0.0, -1.0])) + @unittest.skipIf(PD_VERSION >= (2, 0), 'append removed in Pandas 2.0') def test_append_verify_integrity(self): df1 = pd.DataFrame({'A': range(10), 'B': range(10)}, index=range(10)) @@ -1696,7 +1706,7 @@ def test_pivot_no_index_provided_on_multiindex(self): 'describe')) -def numeric_only_kwargs_for_pandas_2(agg_type: str) -> Dict[str, bool]: +def numeric_only_kwargs_for_pandas_2(agg_type: str) -> dict[str, bool]: """Get proper arguments for numeric_only. 
Behavior for numeric_only in these methods changed in Pandas 2 to default diff --git a/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py b/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py index ce36dbeb09ad..a8139675ad39 100644 --- a/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py +++ b/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py @@ -18,7 +18,7 @@ """ import re -from typing import Mapping +from collections.abc import Mapping import pandas as pd diff --git a/sdks/python/apache_beam/dataframe/partitionings.py b/sdks/python/apache_beam/dataframe/partitionings.py index 0ff09e111480..1fe760fe8589 100644 --- a/sdks/python/apache_beam/dataframe/partitionings.py +++ b/sdks/python/apache_beam/dataframe/partitionings.py @@ -15,9 +15,8 @@ # limitations under the License. import random +from collections.abc import Iterable from typing import Any -from typing import Iterable -from typing import Tuple from typing import TypeVar import numpy as np @@ -47,7 +46,7 @@ def __le__(self, other): return not self.is_subpartitioning_of(other) def partition_fn(self, df: Frame, - num_partitions: int) -> Iterable[Tuple[Any, Frame]]: + num_partitions: int) -> Iterable[tuple[Any, Frame]]: """A callable that actually performs the partitioning of a Frame df. 
This will be invoked via a FlatMap in conjunction with a GroupKey to diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index e70229f21f77..f849ab11e77c 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -24,12 +24,10 @@ # pytype: skip-file import warnings +from collections.abc import Sequence from typing import Any -from typing import Dict from typing import NamedTuple from typing import Optional -from typing import Sequence -from typing import Tuple from typing import TypeVar from typing import Union @@ -170,7 +168,7 @@ def element_typehint_from_dataframe_proxy( fields = [(column, dtype_to_fieldtype(dtype)) for (column, dtype) in output_columns] - field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] + field_options: Optional[dict[str, Sequence[tuple[str, Any]]]] if include_indexes: field_options = { index_name: [(INDEX_OPTION_NAME, None)] diff --git a/sdks/python/apache_beam/dataframe/transforms.py b/sdks/python/apache_beam/dataframe/transforms.py index 852b49c4e2ed..59e5eec05d2f 100644 --- a/sdks/python/apache_beam/dataframe/transforms.py +++ b/sdks/python/apache_beam/dataframe/transforms.py @@ -16,12 +16,8 @@ import collections import logging -from typing import TYPE_CHECKING +from collections.abc import Mapping from typing import Any -from typing import Dict -from typing import List -from typing import Mapping -from typing import Tuple from typing import TypeVar from typing import Union @@ -30,18 +26,16 @@ import apache_beam as beam from apache_beam import transforms from apache_beam.dataframe import expressions +from apache_beam.dataframe import frame_base from apache_beam.dataframe import frames # pylint: disable=unused-import from apache_beam.dataframe import partitionings +from apache_beam.pvalue import PCollection from apache_beam.utils import windowed_value __all__ = [ 'DataframeTransform', ] -if TYPE_CHECKING: - # pylint: 
disable=ungrouped-imports - from apache_beam.pvalue import PCollection - T = TypeVar('T') TARGET_PARTITION_SIZE = 1 << 23 # 8M @@ -108,15 +102,15 @@ def expand(self, input_pcolls): from apache_beam.dataframe import convert # Convert inputs to a flat dict. - input_dict = _flatten(input_pcolls) # type: Dict[Any, PCollection] + input_dict: dict[Any, PCollection] = _flatten(input_pcolls) proxies = _flatten(self._proxy) if self._proxy is not None else { tag: None for tag in input_dict } - input_frames = { + input_frames: dict[Any, frame_base.DeferredFrame] = { k: convert.to_dataframe(pc, proxies[k]) for k, pc in input_dict.items() - } # type: Dict[Any, DeferredFrame] # noqa: F821 + } # noqa: F821 # Apply the function. frames_input = _substitute(input_pcolls, input_frames) @@ -152,9 +146,9 @@ def expand(self, inputs): def _apply_deferred_ops( self, - inputs, # type: Dict[expressions.Expression, PCollection] - outputs, # type: Dict[Any, expressions.Expression] - ): # -> Dict[Any, PCollection] + inputs: dict[expressions.Expression, PCollection], + outputs: dict[Any, expressions.Expression], + ) -> dict[Any, PCollection]: """Construct a Beam graph that evaluates a set of expressions on a set of input PCollections. @@ -395,7 +389,11 @@ def expr_to_stages(expr): if stage is None: # No stage available, compute this expression as part of a new stage. - stage = Stage(expr.args(), expr.requires_partition_by()) + stage = Stage([ + arg for arg in expr.args() + if not isinstance(arg, expressions.ConstantExpression) + ], + expr.requires_partition_by()) for arg in expr.args(): # For each argument, declare that it is also available in # this new stage. @@ -581,11 +579,9 @@ def _concat(parts): def _flatten( - valueish, # type: Union[T, List[T], Tuple[T], Dict[Any, T]] - root=(), # type: Tuple[Any, ...] - ): - # type: (...) -> Mapping[Tuple[Any, ...], T] - + valueish: Union[T, list[T], tuple[T], dict[Any, T]], + root: tuple[Any, ...] 
= (), +) -> Mapping[tuple[Any, ...], T]: """Given a nested structure of dicts, tuples, and lists, return a flat dictionary where the values are the leafs and the keys are the "paths" to these leaves. diff --git a/sdks/python/apache_beam/examples/complete/estimate_pi.py b/sdks/python/apache_beam/examples/complete/estimate_pi.py index 089767d2a99e..530a270308d9 100644 --- a/sdks/python/apache_beam/examples/complete/estimate_pi.py +++ b/sdks/python/apache_beam/examples/complete/estimate_pi.py @@ -30,9 +30,8 @@ import json import logging import random +from collections.abc import Iterable from typing import Any -from typing import Iterable -from typing import Tuple import apache_beam as beam from apache_beam.io import WriteToText @@ -40,7 +39,7 @@ from apache_beam.options.pipeline_options import SetupOptions -@beam.typehints.with_output_types(Tuple[int, int, int]) +@beam.typehints.with_output_types(tuple[int, int, int]) @beam.typehints.with_input_types(int) def run_trials(runs): """Run trials and return a 3-tuple representing the results. @@ -62,8 +61,8 @@ def run_trials(runs): return runs, inside_runs, 0 -@beam.typehints.with_output_types(Tuple[int, int, float]) -@beam.typehints.with_input_types(Iterable[Tuple[int, int, Any]]) +@beam.typehints.with_output_types(tuple[int, int, float]) +@beam.typehints.with_input_types(Iterable[tuple[int, int, Any]]) def combine_results(results): """Combiner function to sum up trials and compute the estimate. 
diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index 6b5573aa4569..0a8c55d17d3a 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -25,7 +25,6 @@ import unittest import uuid from typing import TYPE_CHECKING -from typing import List import pytest import pytz @@ -53,7 +52,7 @@ if TYPE_CHECKING: import google.cloud.bigtable.instance -EXISTING_INSTANCES: List['google.cloud.bigtable.instance.Instance'] = [] +EXISTING_INSTANCES: list['google.cloud.bigtable.instance.Instance'] = [] LABEL_KEY = 'python-bigtable-beam' label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC) label_stamp_micros = _microseconds_from_datetime(label_stamp) diff --git a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py index 65ea7990a2d8..9d71ac32aff2 100644 --- a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py +++ b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py @@ -59,7 +59,7 @@ import logging import re import sys -from typing import Iterable +from collections.abc import Iterable from typing import Optional from typing import Text import uuid diff --git a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py index 3ce7836b491a..8a959138d3da 100644 --- a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py +++ b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py @@ -30,7 +30,6 @@ import argparse import logging import sys -import typing import apache_beam as beam from apache_beam import coders @@ -71,7 +70,7 @@ def is_deterministic(self): # Annotate the get_players function so that the typehint system knows that the # input to the CombinePerKey operation is a key-value pair of a Player object # 
and an integer. -@with_output_types(typing.Tuple[Player, int]) +@with_output_types(tuple[Player, int]) def get_players(descriptor): name, points = descriptor.split(',') return Player(name), int(points) diff --git a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py index 5eb57c8fc080..69c2eacc593d 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py @@ -27,10 +27,8 @@ import argparse import logging -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import apache_beam as beam import torch @@ -45,14 +43,14 @@ from transformers import AutoTokenizer -def add_mask_to_last_word(text: str) -> Tuple[str, str]: +def add_mask_to_last_word(text: str) -> tuple[str, str]: text_list = text.split() return text, ' '.join(text_list[:-2] + ['', text_list[-1]]) def tokenize_sentence( - text_and_mask: Tuple[str, str], - tokenizer: AutoTokenizer) -> Tuple[str, Dict[str, torch.Tensor]]: + text_and_mask: tuple[str, str], + tokenizer: AutoTokenizer) -> tuple[str, dict[str, torch.Tensor]]: text, masked_text = text_and_mask tokenized_sentence = tokenizer.encode_plus(masked_text, return_tensors="pt") @@ -81,7 +79,7 @@ def __init__(self, tokenizer: AutoTokenizer): super().__init__() self.tokenizer = tokenizer - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction_result = element inputs = prediction_result.example logits = prediction_result.inference['logits'] diff --git a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py 
index 9005ea5d11d7..7d4899cc38d9 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py @@ -28,8 +28,7 @@ import argparse import logging -from typing import Iterable -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import KeyedModelHandler @@ -49,7 +48,7 @@ class PostProcessor(beam.DoFn): Hugging Face Pipeline for Question Answering returns a dictionary with score, start and end index of answer and the answer. """ - def process(self, result: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, result: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction = result predicted_answer = prediction.inference['answer'] yield text + ';' + predicted_answer diff --git a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py index 18f697f673bf..0e62ab865431 100644 --- a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py +++ b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py @@ -28,9 +28,8 @@ import argparse import logging -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import numpy as np @@ -47,7 +46,7 @@ def tokenize_sentence(text: str, - tokenizer: RobertaTokenizer) -> Tuple[str, torch.Tensor]: + tokenizer: RobertaTokenizer) -> tuple[str, torch.Tensor]: tokenized_sentence = tokenizer.encode(text, add_special_tokens=True) # Workaround to manually remove batch dim until we have the feature to @@ -63,7 +62,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: 
tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = np.argmax(prediction_result.inference, axis=0) yield filename + ';' + str(prediction) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py index d627001bcb82..c24a6d0a910e 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py @@ -21,9 +21,8 @@ import io import logging import os -from typing import Iterator +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -41,7 +40,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -122,13 +121,13 @@ def run( model_class = models.mobilenet_v2 model_params = {'num_classes': 1000} - def preprocess(image_name: str) -> Tuple[str, torch.Tensor]: + def preprocess(image_name: str) -> tuple[str, torch.Tensor]: image_name, image = read_image( image_file_name=image_name, path_to_dir=known_args.images_dir) return (image_name, preprocess_image(image)) - def postprocess(element: Tuple[str, PredictionResult]) -> str: + def postprocess(element: tuple[str, PredictionResult]) -> str: filename, prediction_result = element prediction = torch.argmax(prediction_result.inference, dim=0) return filename + ',' + str(prediction.item()) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py index 2a4e6e9a9bc6..787341263fde 100644 --- 
a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py @@ -62,10 +62,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -84,7 +83,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -116,7 +115,7 @@ class PostProcessor(beam.DoFn): Return filename, prediction and the model id used to perform the prediction """ - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = torch.argmax(prediction_result.inference, dim=0) yield filename, prediction, prediction_result.model_id diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py index cdecb826d6e3..5e5f77a679c3 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py @@ -21,10 +21,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -138,7 +137,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, 
Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -161,7 +160,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction_labels = prediction_result.inference['labels'] classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py index 9de10e73e11b..a616998d2c73 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py @@ -26,10 +26,8 @@ import argparse import logging -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import apache_beam as beam import torch @@ -45,14 +43,14 @@ from transformers import BertTokenizer -def add_mask_to_last_word(text: str) -> Tuple[str, str]: +def add_mask_to_last_word(text: str) -> tuple[str, str]: text_list = text.split() return text, ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]]) def tokenize_sentence( - text_and_mask: Tuple[str, str], - bert_tokenizer: BertTokenizer) -> Tuple[str, Dict[str, torch.Tensor]]: + text_and_mask: tuple[str, str], + bert_tokenizer: BertTokenizer) -> tuple[str, dict[str, torch.Tensor]]: text, masked_text = text_and_mask tokenized_sentence = bert_tokenizer.encode_plus( masked_text, return_tensors="pt") @@ -84,7 +82,7 @@ def __init__(self, bert_tokenizer: BertTokenizer): super().__init__() 
self.bert_tokenizer = bert_tokenizer - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction_result = element inputs = prediction_result.example logits = prediction_result.inference['logits'] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py index f0b5462d5335..18c4c3e653b4 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py @@ -24,10 +24,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -143,7 +142,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -168,15 +167,15 @@ def filter_empty_lines(text: str) -> Iterator[str]: class KeyExamplesForEachModelType(beam.DoFn): """Duplicate data to run against each model type""" def process( - self, element: Tuple[torch.Tensor, - str]) -> Iterable[Tuple[str, torch.Tensor]]: + self, element: tuple[torch.Tensor, + str]) -> Iterable[tuple[str, torch.Tensor]]: yield 'v1', element[0] yield 'v2', element[0] class PostProcessor(beam.DoFn): def process( - self, element: Tuple[str, PredictionResult]) -> Tuple[torch.Tensor, str]: + self, element: tuple[str, PredictionResult]) -> tuple[torch.Tensor, str]: model, prediction_result = element prediction_labels = 
prediction_result.inference['labels'] classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels] diff --git a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py index a6e4dc2bdb03..755eff17c163 100644 --- a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py +++ b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py @@ -22,9 +22,9 @@ import argparse import logging import time -from typing import Iterable +from collections.abc import Iterable +from collections.abc import Sequence from typing import Optional -from typing import Sequence import apache_beam as beam from apache_beam.ml.inference import base diff --git a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py index 3aa2f362fa64..0a527e88dec2 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py @@ -31,7 +31,7 @@ import argparse import os -from typing import Iterable +from collections.abc import Iterable import pandas diff --git a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py index 5392cdf7ddae..d7d08e294e9d 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py @@ -27,9 +27,7 @@ import argparse import logging import os -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import KeyedModelHandler @@ -42,7 +40,7 @@ from apache_beam.runners.runner import PipelineResult -def 
process_input(row: str) -> Tuple[int, List[int]]: +def process_input(row: str) -> tuple[int, list[int]]: data = row.split(',') label, pixels = int(data[0]), data[1:] pixels = [int(pixel) for pixel in pixels] @@ -53,7 +51,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference yield '{},{}'.format(label, prediction) diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py index a0f249dcfbf0..b44d775f4ad3 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py @@ -17,8 +17,8 @@ import argparse import logging -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator import numpy diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py index 6cf746e77cd2..bf85bb1aef16 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py @@ -17,8 +17,7 @@ import argparse import logging -from typing import Iterable -from typing import Tuple +from collections.abc import Iterable import numpy @@ -33,7 +32,7 @@ from apache_beam.runners.runner import PipelineResult -def process_input(row: str) -> Tuple[int, numpy.ndarray]: +def process_input(row: str) -> tuple[int, numpy.ndarray]: data = row.split(',') label, 
pixels = int(data[0]), data[1:] pixels = [int(pixel) for pixel in pixels] @@ -46,7 +45,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = numpy.argmax(prediction_result.inference, axis=0) yield '{},{}'.format(label, prediction) diff --git a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py index 1faf502c71af..677d36b9b767 100644 --- a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py +++ b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py @@ -22,9 +22,8 @@ import argparse import io import os -from typing import Iterable +from collections.abc import Iterable from typing import Optional -from typing import Tuple import numpy as np @@ -134,14 +133,14 @@ def attach_im_size_to_key( - data: Tuple[str, Image.Image]) -> Tuple[Tuple[str, int, int], Image.Image]: + data: tuple[str, Image.Image]) -> tuple[tuple[str, int, int], Image.Image]: filename, image = data width, height = image.size return ((filename, width, height), image) def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -168,7 +167,7 @@ class PostProcessor(beam.DoFn): an integer that we can transform into actual string class using COCO_OBJ_DET_CLASSES as reference. 
""" - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: key, prediction_result = element filename, im_width, im_height = key num_detections = prediction_result.inference[0] diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py index 09a70caa4ede..5df0b51e36d7 100644 --- a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py @@ -32,10 +32,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import tensorflow as tf @@ -60,7 +59,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: def read_and_process_image( image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, tf.Tensor]: + path_to_dir: Optional[str] = None) -> tuple[str, tf.Tensor]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -97,7 +96,7 @@ def convert_image_to_example_proto(tensor: tf.Tensor) -> tf.train.Example: class ProcessInferenceToString(beam.DoFn): def process( - self, element: Tuple[str, + self, element: tuple[str, prediction_log_pb2.PredictionLog]) -> Iterable[str]: """ Args: diff --git a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py index 73126569e988..20312e7d3c88 100644 --- a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py +++ 
b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py @@ -27,9 +27,7 @@ import argparse import io import logging -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam import tensorflow as tf @@ -102,13 +100,13 @@ def parse_known_args(argv): COLUMNS = ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses'] -def read_image(image_file_name: str) -> Tuple[str, bytes]: +def read_image(image_file_name: str) -> tuple[str, bytes]: with FileSystems().open(image_file_name, 'r') as file: data = io.BytesIO(file.read()).getvalue() return image_file_name, data -def preprocess_image(data: bytes) -> List[float]: +def preprocess_image(data: bytes) -> list[float]: """Preprocess the image, resizing it and normalizing it before converting to a list. """ @@ -119,7 +117,7 @@ def preprocess_image(data: bytes) -> List[float]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: img_name, prediction_result = element prediction_vals = prediction_result.inference index = prediction_vals.index(max(prediction_vals)) diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 3cf7d04cb03e..2708c0f3d1a1 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -25,7 +25,7 @@ import argparse import logging -from typing import Iterable +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import PredictionResult diff --git a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py index 963187fd210d..498511a5a2cf 100644 --- 
a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py +++ b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py @@ -17,10 +17,8 @@ import argparse import logging -from typing import Callable -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Callable +from collections.abc import Iterable from typing import Union import numpy @@ -48,7 +46,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference yield '{},{}'.format(label, prediction) @@ -89,7 +87,7 @@ def parse_known_args(argv): def load_sklearn_iris_test_data( data_type: Callable, split: bool = True, - seed: int = 999) -> List[Union[numpy.array, pandas.DataFrame]]: + seed: int = 999) -> list[Union[numpy.array, pandas.DataFrame]]: """ Loads test data from the sklearn Iris dataset in a given format, either in a single or multiple batches. diff --git a/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py b/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py index 1cdd266c3df4..9b4889017077 100644 --- a/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py +++ b/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py @@ -26,7 +26,6 @@ import logging import sys -import typing import apache_beam as beam from apache_beam.io.kafka import ReadFromKafka @@ -97,7 +96,7 @@ def convert_kafka_record_to_dictionary(record): topic='projects/pubsub-public-data/topics/taxirides-realtime'). with_output_types(bytes) | beam.Map(lambda x: (b'', x)).with_output_types( - typing.Tuple[bytes, bytes]) # Kafka write transforms expects KVs. 
+ tuple[bytes, bytes]) # Kafka write transforms expects KVs. | beam.WindowInto(beam.window.FixedWindows(window_size)) | WriteToKafka( producer_config={'bootstrap.servers': bootstrap_servers}, diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index 715011d302d2..c849af4a00b3 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -1143,6 +1143,60 @@ def model_multiple_pcollections_flatten(contents, output_path): merged | beam.io.WriteToText(output_path) +def model_multiple_pcollections_flatten_with(contents, output_path): + """Merging a PCollection with FlattenWith.""" + some_hash_fn = lambda s: ord(s[0]) + partition_fn = lambda element, partitions: some_hash_fn(element) % partitions + import apache_beam as beam + with TestPipeline() as pipeline: # Use TestPipeline for testing. + + # Partition into deciles + partitioned = pipeline | beam.Create(contents) | beam.Partition( + partition_fn, 3) + pcoll1 = partitioned[0] + pcoll2 = partitioned[1] + pcoll3 = partitioned[2] + SomeTransform = lambda: beam.Map(lambda x: x) + SomeOtherTransform = lambda: beam.Map(lambda x: x) + + # Flatten them back into 1 + + # A collection of PCollection objects can be represented simply + # as a tuple (or list) of PCollections. + # (The SDK for Python has no separate type to store multiple + # PCollection objects, whether containing the same or different + # types.) 
+ # [START model_multiple_pcollections_flatten_with] + merged = ( + pcoll1 + | SomeTransform() + | beam.FlattenWith(pcoll2, pcoll3) + | SomeOtherTransform()) + # [END model_multiple_pcollections_flatten_with] + merged | beam.io.WriteToText(output_path) + + +def model_multiple_pcollections_flatten_with_transform(contents, output_path): + """Merging output of PTransform with FlattenWith.""" + some_hash_fn = lambda s: ord(s[0]) + partition_fn = lambda element, partitions: some_hash_fn(element) % partitions + import apache_beam as beam + with TestPipeline() as pipeline: # Use TestPipeline for testing. + + pcoll = pipeline | beam.Create(contents) + SomeTransform = lambda: beam.Map(lambda x: x) + SomeOtherTransform = lambda: beam.Map(lambda x: x) + + # [START model_multiple_pcollections_flatten_with_transform] + merged = ( + pcoll + | SomeTransform() + | beam.FlattenWith(beam.Create(['x', 'y', 'z'])) + | SomeOtherTransform()) + # [END model_multiple_pcollections_flatten_with_transform] + merged | beam.io.WriteToText(output_path) + + def model_multiple_pcollections_partition(contents, output_path): """Splitting a PCollection with Partition.""" some_hash_fn = lambda s: ord(s[0]) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index e8cb8960cf4d..54a57673b5f4 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -917,6 +917,19 @@ def test_model_multiple_pcollections_flatten(self): snippets.model_multiple_pcollections_flatten(contents, result_path) self.assertEqual(contents, self.get_output(result_path)) + def test_model_multiple_pcollections_flatten_with(self): + contents = ['a', 'b', 'c', 'd', 'e', 'f'] + result_path = self.create_temp_file() + snippets.model_multiple_pcollections_flatten_with(contents, result_path) + self.assertEqual(contents, self.get_output(result_path)) + + def 
test_model_multiple_pcollections_flatten_with_transform(self): + contents = ['a', 'b', 'c', 'd', 'e', 'f'] + result_path = self.create_temp_file() + snippets.model_multiple_pcollections_flatten_with_transform( + contents, result_path) + self.assertEqual(contents + ['x', 'y', 'z'], self.get_output(result_path)) + def test_model_multiple_pcollections_partition(self): contents = [17, 42, 64, 32, 0, 99, 53, 89] result_path = self.create_temp_file() diff --git a/sdks/python/apache_beam/examples/wordcount.py b/sdks/python/apache_beam/examples/wordcount.py index 31407aec6c40..a9138647581c 100644 --- a/sdks/python/apache_beam/examples/wordcount.py +++ b/sdks/python/apache_beam/examples/wordcount.py @@ -45,6 +45,7 @@ from apache_beam.io import WriteToText from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult class WordExtractingDoFn(beam.DoFn): @@ -63,7 +64,7 @@ def process(self, element): return re.findall(r'[\w\']+', element, re.UNICODE) -def run(argv=None, save_main_session=True): +def run(argv=None, save_main_session=True) -> PipelineResult: """Main entry point; defines and runs the wordcount pipeline.""" parser = argparse.ArgumentParser() parser.add_argument( @@ -83,27 +84,31 @@ def run(argv=None, save_main_session=True): pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = save_main_session - # The pipeline will be run on exiting the with block. - with beam.Pipeline(options=pipeline_options) as p: + pipeline = beam.Pipeline(options=pipeline_options) - # Read the text file[pattern] into a PCollection. - lines = p | 'Read' >> ReadFromText(known_args.input) + # Read the text file[pattern] into a PCollection. 
+ lines = pipeline | 'Read' >> ReadFromText(known_args.input) - counts = ( - lines - | 'Split' >> (beam.ParDo(WordExtractingDoFn()).with_output_types(str)) - | 'PairWithOne' >> beam.Map(lambda x: (x, 1)) - | 'GroupAndSum' >> beam.CombinePerKey(sum)) + counts = ( + lines + | 'Split' >> (beam.ParDo(WordExtractingDoFn()).with_output_types(str)) + | 'PairWithOne' >> beam.Map(lambda x: (x, 1)) + | 'GroupAndSum' >> beam.CombinePerKey(sum)) - # Format the counts into a PCollection of strings. - def format_result(word, count): - return '%s: %d' % (word, count) + # Format the counts into a PCollection of strings. + def format_result(word, count): + return '%s: %d' % (word, count) - output = counts | 'Format' >> beam.MapTuple(format_result) + output = counts | 'Format' >> beam.MapTuple(format_result) - # Write the output using a "Write" transform that has side effects. - # pylint: disable=expression-not-assigned - output | 'Write' >> WriteToText(known_args.output) + # Write the output using a "Write" transform that has side effects. + # pylint: disable=expression-not-assigned + output | 'Write' >> WriteToText(known_args.output) + + # Execute the pipeline and return the result. + result = pipeline.run() + result.wait_until_finish() + return result if __name__ == '__main__': diff --git a/sdks/python/apache_beam/examples/wordcount_xlang_sql.py b/sdks/python/apache_beam/examples/wordcount_xlang_sql.py index 9d7d756f223f..632e90303010 100644 --- a/sdks/python/apache_beam/examples/wordcount_xlang_sql.py +++ b/sdks/python/apache_beam/examples/wordcount_xlang_sql.py @@ -24,7 +24,7 @@ import argparse import logging import re -import typing +from typing import NamedTuple import apache_beam as beam from apache_beam import coders @@ -41,7 +41,7 @@ # # Here we create and register a simple NamedTuple with a single str typed # field named 'word' which we will use below. 
-MyRow = typing.NamedTuple('MyRow', [('word', str)]) +MyRow = NamedTuple('MyRow', [('word', str)]) coders.registry.register_coder(MyRow, coders.RowCoder) diff --git a/sdks/python/apache_beam/internal/dill_pickler.py b/sdks/python/apache_beam/internal/dill_pickler.py index 7f7ac5b214fa..e1d6b7e74e49 100644 --- a/sdks/python/apache_beam/internal/dill_pickler.py +++ b/sdks/python/apache_beam/internal/dill_pickler.py @@ -46,9 +46,15 @@ settings = {'dill_byref': None} -if sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1": - # Let's make dill 0.3.1.1 support Python 3.11. +patch_save_code = sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1" + +def get_normalized_path(path): + """Returns a normalized path. This function is intended to be overridden.""" + return path + + +if patch_save_code: # The following function is based on 'save_code' from 'dill' # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) # Copyright (c) 2008-2015 California Institute of Technology. @@ -66,6 +72,7 @@ @dill.register(CodeType) def save_code(pickler, obj): + co_filename = get_normalized_path(obj.co_filename) if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args) args = ( obj.co_argcount, @@ -78,7 +85,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_qualname, obj.co_firstlineno, @@ -100,7 +107,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_qualname, obj.co_firstlineno, @@ -120,7 +127,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_firstlineno, obj.co_linetable, @@ -138,7 +145,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, diff --git a/sdks/python/apache_beam/internal/metrics/cells.py 
b/sdks/python/apache_beam/internal/metrics/cells.py index c7b546258a70..989dc7183045 100644 --- a/sdks/python/apache_beam/internal/metrics/cells.py +++ b/sdks/python/apache_beam/internal/metrics/cells.py @@ -28,7 +28,6 @@ from typing import TYPE_CHECKING from typing import Optional -from apache_beam.metrics.cells import MetricAggregator from apache_beam.metrics.cells import MetricCell from apache_beam.metrics.cells import MetricCellFactory from apache_beam.utils.histogram import Histogram @@ -50,10 +49,10 @@ class HistogramCell(MetricCell): """ def __init__(self, bucket_type): self._bucket_type = bucket_type - self.data = HistogramAggregator(bucket_type).identity_element() + self.data = HistogramData.identity_element(bucket_type) def reset(self): - self.data = HistogramAggregator(self._bucket_type).identity_element() + self.data = HistogramData.identity_element(self._bucket_type) def combine(self, other: 'HistogramCell') -> 'HistogramCell': result = HistogramCell(self._bucket_type) @@ -148,22 +147,6 @@ def combine(self, other: Optional['HistogramData']) -> 'HistogramData': return HistogramData(self.histogram.combine(other.histogram)) - -class HistogramAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. - - Aggregator for Histogram metric data during pipeline execution. - - Values aggregated should be ``HistogramData`` objects. 
- """ - def __init__(self, bucket_type: 'BucketType') -> None: - self._bucket_type = bucket_type - - def identity_element(self) -> HistogramData: - return HistogramData(Histogram(self._bucket_type)) - - def combine(self, x: HistogramData, y: HistogramData) -> HistogramData: - return x.combine(y) - - def result(self, x: HistogramData) -> HistogramResult: - return HistogramResult(x.get_cumulative()) + @staticmethod + def identity_element(bucket_type) -> 'HistogramData': + return HistogramData(Histogram(bucket_type)) diff --git a/sdks/python/apache_beam/internal/pickler_test.py b/sdks/python/apache_beam/internal/pickler_test.py index 824c4c59c0ce..c26a8ee3e653 100644 --- a/sdks/python/apache_beam/internal/pickler_test.py +++ b/sdks/python/apache_beam/internal/pickler_test.py @@ -94,6 +94,11 @@ def test_pickle_rlock(self): self.assertIsInstance(loads(dumps(rlock_instance)), rlock_type) + def test_save_paths(self): + f = loads(dumps(lambda x: x)) + co_filename = f.__code__.co_filename + self.assertTrue(co_filename.endswith('pickler_test.py')) + @unittest.skipIf(NO_MAPPINGPROXYTYPE, 'test if MappingProxyType introduced') def test_dump_and_load_mapping_proxy(self): self.assertEqual( diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 77b20117e702..2d25010da486 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -82,7 +82,7 @@ class AvroBase(object): - _temp_files = [] # type: List[str] + _temp_files: List[str] = [] def __init__(self, methodName='runTest'): super().__init__(methodName) diff --git a/sdks/python/apache_beam/io/aws/s3filesystem.py b/sdks/python/apache_beam/io/aws/s3filesystem.py index e181beac4a58..ffbce5893a96 100644 --- a/sdks/python/apache_beam/io/aws/s3filesystem.py +++ b/sdks/python/apache_beam/io/aws/s3filesystem.py @@ -315,10 +315,14 @@ def delete(self, paths): if exceptions: raise BeamIOError("Delete operation failed", exceptions) - def 
report_lineage(self, path, lineage): + def report_lineage(self, path, lineage, level=None): try: - components = s3io.parse_s3_path(path, get_account=True) + components = s3io.parse_s3_path(path, object_optional=True) except ValueError: # report lineage is fail-safe return + if level == FileSystem.LineageLevel.TOP_LEVEL or \ + (len(components) > 1 and components[-1] == ''): + # bucket only + components = components[:-1] lineage.add('s3', *components) diff --git a/sdks/python/apache_beam/io/aws/s3filesystem_test.py b/sdks/python/apache_beam/io/aws/s3filesystem_test.py index 60e6f319b2c9..87403f482bd2 100644 --- a/sdks/python/apache_beam/io/aws/s3filesystem_test.py +++ b/sdks/python/apache_beam/io/aws/s3filesystem_test.py @@ -265,6 +265,15 @@ def test_rename(self, unused_mock_arg): src_dest_pairs = list(zip(sources, destinations)) s3io_mock.rename_files.assert_called_once_with(src_dest_pairs) + def test_lineage(self): + self._verify_lineage("s3://bucket/", ("bucket", )) + self._verify_lineage("s3://bucket/foo/bar.txt", ("bucket", "foo/bar.txt")) + + def _verify_lineage(self, uri, expected_segments): + lineage_mock = mock.MagicMock() + self.fs.report_lineage(uri, lineage_mock) + lineage_mock.add.assert_called_once_with("s3", *expected_segments) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py index bb56fa09d370..4495245dc54a 100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem.py @@ -317,10 +317,15 @@ def delete(self, paths): if exceptions: raise BeamIOError("Delete operation failed", exceptions) - def report_lineage(self, path, lineage): + def report_lineage(self, path, lineage, level=None): try: - components = blobstorageio.parse_azfs_path(path, get_account=True) + components = blobstorageio.parse_azfs_path( + path, blob_optional=True, 
get_account=True) except ValueError: # report lineage is fail-safe return + if level == FileSystem.LineageLevel.TOP_LEVEL \ + or(len(components) > 1 and components[-1] == ''): + # bucket only + components = components[:-1] lineage.add('abs', *components) diff --git a/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py b/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py index cee459f5b8a2..138fe5f78b20 100644 --- a/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py +++ b/sdks/python/apache_beam/io/azure/blobstoragefilesystem_test.py @@ -320,6 +320,18 @@ def test_rename(self, unused_mock_blobstorageio): src_dest_pairs = list(zip(sources, destinations)) blobstorageio_mock.rename_files.assert_called_once_with(src_dest_pairs) + def test_lineage(self): + self._verify_lineage( + "azfs://storageaccount/container/", ("storageaccount", "container")) + self._verify_lineage( + "azfs://storageaccount/container/foo/bar.txt", + ("storageaccount", "container", "foo/bar.txt")) + + def _verify_lineage(self, uri, expected_segments): + lineage_mock = mock.MagicMock() + self.fs.report_lineage(uri, lineage_mock) + lineage_mock.add.assert_called_once_with("abs", *expected_segments) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py b/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py index cfbb411b4e5f..7f3a16e02aa3 100644 --- a/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py @@ -245,6 +245,128 @@ def test_write_with_beam_rows(self): | StorageWriteToBigQuery(table=table_id)) hamcrest_assert(p, bq_matcher) + def test_write_with_beam_rows_cdc(self): + table = 'write_with_beam_rows_cdc' + table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table) + + expected_data_on_bq = [ + # (name, value) + { + "name": "cdc_test", + "value": 5, + } + ] + + 
rows_with_cdc = [ + beam.Row( + row_mutation_info=beam.Row( + mutation_type="UPSERT", change_sequence_number="AAA/2"), + record=beam.Row(name="cdc_test", value=5)), + beam.Row( + row_mutation_info=beam.Row( + mutation_type="UPSERT", change_sequence_number="AAA/1"), + record=beam.Row(name="cdc_test", value=3)) + ] + + bq_matcher = BigqueryFullResultMatcher( + project=self.project, + query="SELECT * FROM {}.{}".format(self.dataset_id, table), + data=self.parse_expected_data(expected_data_on_bq)) + + with beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create(rows_with_cdc) + | beam.io.WriteToBigQuery( + table=table_id, + method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API, + use_at_least_once=True, + use_cdc_writes=True, + primary_key=["name"])) + hamcrest_assert(p, bq_matcher) + + def test_write_with_dicts_cdc(self): + table = 'write_with_dicts_cdc' + table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table) + + expected_data_on_bq = [ + # (name, value) + { + "name": "cdc_test", + "value": 5, + } + ] + + data_with_cdc = [ + # record: (name, value) + { + 'row_mutation_info': { + 'mutation_type': 'UPSERT', 'change_sequence_number': 'AAA/2' + }, + 'record': { + 'name': 'cdc_test', 'value': 5 + } + }, + { + 'row_mutation_info': { + 'mutation_type': 'UPSERT', 'change_sequence_number': 'AAA/1' + }, + 'record': { + 'name': 'cdc_test', 'value': 3 + } + } + ] + + schema = { + "fields": [ + # include both record and mutation info fields as part of the schema + { + "name": "row_mutation_info", + "type": "STRUCT", + "fields": [ + # setting both fields are required + { + "name": "mutation_type", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "change_sequence_number", + "type": "STRING", + "mode": "REQUIRED" + } + ] + }, + { + "name": "record", + "type": "STRUCT", + "fields": [{ + "name": "name", "type": "STRING" + }, { + "name": "value", "type": "INTEGER" + }] + } + ] + } + + bq_matcher = BigqueryFullResultMatcher( + project=self.project, + 
query="SELECT * FROM {}.{}".format(self.dataset_id, table), + data=self.parse_expected_data(expected_data_on_bq)) + + with beam.Pipeline(argv=self.args) as p: + _ = ( + p + | beam.Create(data_with_cdc) + | beam.io.WriteToBigQuery( + table=table_id, + method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API, + use_at_least_once=True, + use_cdc_writes=True, + schema=schema, + primary_key=["name"])) + hamcrest_assert(p, bq_matcher) + def test_write_to_dynamic_destinations(self): base_table_spec = '{}.dynamic_dest_'.format(self.dataset_id) spec_with_project = '{}:{}'.format(self.project, base_table_spec) diff --git a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py index 7829baba8b69..abe9530787e8 100644 --- a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py @@ -107,7 +107,7 @@ def start_db_container(self, retries): for i in range(retries): try: self.db = PostgresContainer( - 'debezium/example-postgres:latest', + 'quay.io/debezium/example-postgres:latest', user=self.username, password=self.password, dbname=self.database) diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py index c708e117c3a1..f9d4303c8c78 100644 --- a/sdks/python/apache_beam/io/filebasedsink.py +++ b/sdks/python/apache_beam/io/filebasedsink.py @@ -280,9 +280,31 @@ def _check_state_for_finalize_write(self, writer_results, num_shards): src_files.append(src) dst_files.append(dst) - FileSystems.report_sink_lineage(dst) + + self._report_sink_lineage(dst_glob, dst_files) return src_files, dst_files, delete_files, num_skipped + def _report_sink_lineage(self, dst_glob, dst_files): + """ + Report sink Lineage. Report every file if number of files no more than 100, + otherwise only report at directory level. 
+ """ + if len(dst_files) <= 100: + for dst in dst_files: + FileSystems.report_sink_lineage(dst) + else: + dst = dst_glob + # dst_glob has a wildcard for shard number (see _shard_name_template) + sep = dst_glob.find('*') + if sep > 0: + dst = dst[:sep] + try: + dst, _ = FileSystems.split(dst) + except ValueError: + return # lineage report is fail-safe + + FileSystems.report_sink_lineage(dst) + @check_accessible(['file_path_prefix']) def finalize_write( self, init_result, writer_results, unused_pre_finalize_results): diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index efd863810ed7..a02bc6de32c7 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -39,6 +39,7 @@ from apache_beam.io import range_trackers from apache_beam.io.filesystem import CompressionTypes from apache_beam.io.filesystem import FileMetadata +from apache_beam.io.filesystem import FileSystem from apache_beam.io.filesystems import FileSystems from apache_beam.io.restriction_trackers import OffsetRange from apache_beam.options.value_provider import StaticValueProvider @@ -168,10 +169,38 @@ def _get_concat_source(self) -> concat_source.ConcatSource: min_bundle_size=self._min_bundle_size, splittable=splittable) single_file_sources.append(single_file_source) - FileSystems.report_source_lineage(file_name) + + self._report_source_lineage(files_metadata) self._concat_source = concat_source.ConcatSource(single_file_sources) + return self._concat_source + def _report_source_lineage(self, files_metadata): + """ + Report source Lineage. 
Depending on the number of files, report the full file + name, only the directory, or only the top level + """ + if len(files_metadata) <= 100: + for file_metadata in files_metadata: + FileSystems.report_source_lineage(file_metadata.path) + else: + size_track = set() + for file_metadata in files_metadata: + if len(size_track) >= 100: + FileSystems.report_source_lineage( + file_metadata.path, level=FileSystem.LineageLevel.TOP_LEVEL) + return + + try: + base, _ = FileSystems.split(file_metadata.path) + except ValueError: + pass + else: + size_track.add(base) + + for base in size_track: + FileSystems.report_source_lineage(base) + def open_file(self, file_name): return FileSystems.open( file_name, @@ -343,6 +372,7 @@ def __init__( self._min_bundle_size = min_bundle_size self._splittable = splittable self._compression_type = compression_type + self._size_track = None def process(self, element: Union[str, FileMetadata], *args, **kwargs) -> Iterable[Tuple[FileMetadata, OffsetRange]]: @@ -352,7 +382,8 @@ def process(self, element: Union[str, FileMetadata], *args, match_results = FileSystems.match([element]) metadata_list = match_results[0].metadata_list for metadata in metadata_list: - FileSystems.report_source_lineage(metadata.path) + self._report_source_lineage(metadata.path) + splittable = ( self._splittable and _determine_splittability_from_compression_type( metadata.path, self._compression_type)) @@ -366,6 +397,28 @@ def process(self, element: Union[str, FileMetadata], *args, metadata, OffsetRange(0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY)) + def _report_source_lineage(self, path): + """ + Report source Lineage. Due to the size limit of Beam metrics, report the full + file name or only the top level, depending on the number of files. + + * Number of files<=100, report full file paths; + + * Otherwise, report top level only.
+ """ + if self._size_track is None: + self._size_track = set() + elif len(self._size_track) == 0: + FileSystems.report_source_lineage( + path, level=FileSystem.LineageLevel.TOP_LEVEL) + return + + self._size_track.add(path) + FileSystems.report_source_lineage(path) + + if len(self._size_track) >= 100: + self._size_track.clear() + class _ReadRange(DoFn): def __init__( diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index d9b2a2040675..111206a18a28 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -94,7 +94,6 @@ import uuid from collections import namedtuple from functools import partial -from typing import TYPE_CHECKING from typing import Any from typing import BinaryIO # pylint: disable=unused-import from typing import Callable @@ -115,15 +114,13 @@ from apache_beam.options.value_provider import ValueProvider from apache_beam.transforms.periodicsequence import PeriodicImpulse from apache_beam.transforms.userstate import CombiningValueStateSpec +from apache_beam.transforms.window import BoundedWindow from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import IntervalWindow from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import Timestamp -if TYPE_CHECKING: - from apache_beam.transforms.window import BoundedWindow - __all__ = [ 'EmptyMatchTreatment', 'MatchFiles', @@ -382,8 +379,7 @@ def create_metadata( mime_type="application/octet-stream", compression_type=CompressionTypes.AUTO) - def open(self, fh): - # type: (BinaryIO) -> None + def open(self, fh: BinaryIO) -> None: raise NotImplementedError def write(self, record): @@ -575,8 +571,7 @@ class signature or an instance of FileSink to this parameter. If none is self._max_num_writers_per_bundle = max_writers_per_bundle @staticmethod - def _get_sink_fn(input_sink): - # type: (...) 
-> Callable[[Any], FileSink] + def _get_sink_fn(input_sink) -> Callable[[Any], FileSink]: if isinstance(input_sink, type) and issubclass(input_sink, FileSink): return lambda x: input_sink() elif isinstance(input_sink, FileSink): @@ -588,8 +583,7 @@ def _get_sink_fn(input_sink): return lambda x: TextSink() @staticmethod - def _get_destination_fn(destination): - # type: (...) -> Callable[[Any], str] + def _get_destination_fn(destination) -> Callable[[Any], str]: if isinstance(destination, ValueProvider): return lambda elm: destination.get() elif callable(destination): @@ -757,12 +751,8 @@ def _check_orphaned_files(self, writer_key): class _WriteShardedRecordsFn(beam.DoFn): - - def __init__(self, - base_path, - sink_fn, # type: Callable[[Any], FileSink] - shards # type: int - ): + def __init__( + self, base_path, sink_fn: Callable[[Any], FileSink], shards: int): self.base_path = base_path self.sink_fn = sink_fn self.shards = shards @@ -805,17 +795,13 @@ def process( class _AppendShardedDestination(beam.DoFn): - def __init__( - self, - destination, # type: Callable[[Any], str] - shards # type: int - ): + def __init__(self, destination: Callable[[Any], str], shards: int): self.destination_fn = destination self.shards = shards # We start the shards for a single destination at an arbitrary point. 
- self._shard_counter = collections.defaultdict( - lambda: random.randrange(self.shards)) # type: DefaultDict[str, int] + self._shard_counter: DefaultDict[str, int] = collections.defaultdict( + lambda: random.randrange(self.shards)) def _next_shard_for_destination(self, destination): self._shard_counter[destination] = ((self._shard_counter[destination] + 1) % @@ -835,8 +821,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn): SPILLED_RECORDS = 'spilled_records' WRITTEN_FILES = 'written_files' - _writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]] - _file_names = None # type: Dict[Tuple[str, BoundedWindow], str] + _writers_and_sinks: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, + FileSink]] = None + _file_names: Dict[Tuple[str, BoundedWindow], str] = None def __init__( self, diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index bdc25dcf0fe5..840fdf3309e7 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -934,7 +934,11 @@ def delete(self, paths): """ raise NotImplementedError - def report_lineage(self, path, unused_lineage): + class LineageLevel: + FILE = 'FILE' + TOP_LEVEL = 'TOP_LEVEL' + + def report_lineage(self, path, unused_lineage, level=None): """ Report Lineage metrics for path. diff --git a/sdks/python/apache_beam/io/filesystems.py b/sdks/python/apache_beam/io/filesystems.py index ccbeac640765..87f45f3308ee 100644 --- a/sdks/python/apache_beam/io/filesystems.py +++ b/sdks/python/apache_beam/io/filesystems.py @@ -391,13 +391,27 @@ def get_chunk_size(path): return filesystem.CHUNK_SIZE @staticmethod - def report_source_lineage(path): - """Report source :class:`~apache_beam.metrics.metric.Lineage`.""" + def report_source_lineage(path, level=None): + """ + Report source :class:`~apache_beam.metrics.metric.Lineage`. + + Args: + path: string path to be reported. + level: the level of file path.
default to + :class:`~apache_beam.io.filesystem.FileSystem.LineageLevel`.FILE. + """ filesystem = FileSystems.get_filesystem(path) - filesystem.report_lineage(path, Lineage.sources()) + filesystem.report_lineage(path, Lineage.sources(), level=level) @staticmethod - def report_sink_lineage(path): - """Report sink :class:`~apache_beam.metrics.metric.Lineage`.""" + def report_sink_lineage(path, level=None): + """ + Report sink :class:`~apache_beam.metrics.metric.Lineage`. + + Args: + path: string path to be reported. + level: the level of file path. default to + :class:`~apache_beam.io.filesystem.FileSystem.LineageLevel`.FILE. + """ filesystem = FileSystems.get_filesystem(path) - filesystem.report_lineage(path, Lineage.sinks()) + filesystem.report_lineage(path, Lineage.sinks(), level=level) diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 2cb64742f26c..11e0d098b2f3 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -1930,6 +1930,8 @@ def __init__( load_job_project_id=None, max_insert_payload_size=MAX_INSERT_PAYLOAD_SIZE, num_streaming_keys=DEFAULT_SHARDS_PER_DESTINATION, + use_cdc_writes: bool = False, + primary_key: List[str] = None, expansion_service=None): """Initialize a WriteToBigQuery transform. @@ -2095,6 +2097,15 @@ def __init__( GCP expansion service. Used for STORAGE_WRITE_API method. max_insert_payload_size: The maximum byte size for a BigQuery legacy streaming insert payload. + use_cdc_writes: Configure the usage of CDC writes on BigQuery. + The argument can be used by passing True and the Beam Rows will be + sent as they are to the BigQuery sink which expects a 'record' + and 'row_mutation_info' properties. + Used for STORAGE_WRITE_API, working on 'at least once' mode. + primary_key: When using CDC write on BigQuery and + CREATE_IF_NEEDED mode for the underlying tables a list of column names + is required to be configured as the primary key.
Used for + STORAGE_WRITE_API, working on 'at least once' mode. """ self._table = table self._dataset = dataset @@ -2136,6 +2147,8 @@ def __init__( self.load_job_project_id = load_job_project_id self._max_insert_payload_size = max_insert_payload_size self._num_streaming_keys = num_streaming_keys + self._use_cdc_writes = use_cdc_writes + self._primary_key = primary_key # Dict/schema methods were moved to bigquery_tools, but keep references # here for backward compatibility. @@ -2289,8 +2302,9 @@ def find_in_nested_dict(schema): use_at_least_once=self.use_at_least_once, with_auto_sharding=self.with_auto_sharding, num_storage_api_streams=self._num_storage_api_streams, + use_cdc_writes=self._use_cdc_writes, + primary_key=self._primary_key, expansion_service=self.expansion_service) - else: raise ValueError(f"Unsupported method {method_to_use}") @@ -2518,6 +2532,10 @@ class StorageWriteToBigQuery(PTransform): # fields for rows sent to Storage API with dynamic destinations DESTINATION = "destination" RECORD = "record" + # field names for rows sent to Storage API for CDC functionality + CDC_INFO = "row_mutation_info" + CDC_MUTATION_TYPE = "mutation_type" + CDC_SQN = "change_sequence_number" # magic string to tell Java that these rows are going to dynamic destinations DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS" @@ -2532,6 +2550,8 @@ def __init__( use_at_least_once=False, with_auto_sharding=False, num_storage_api_streams=0, + use_cdc_writes: bool = False, + primary_key: List[str] = None, expansion_service=None): self._table = table self._table_side_inputs = table_side_inputs @@ -2542,6 +2562,8 @@ def __init__( self._use_at_least_once = use_at_least_once self._with_auto_sharding = with_auto_sharding self._num_storage_api_streams = num_storage_api_streams + self._use_cdc_writes = use_cdc_writes + self._primary_key = primary_key self._expansion_service = expansion_service or BeamJarExpansionService( 'sdks:java:io:google-cloud-platform:expansion-service:build') @@ -2552,11 
+2574,11 @@ def expand(self, input): is_rows = True except TypeError as exn: raise ValueError( - "A schema is required in order to prepare rows" + "A schema is required in order to prepare rows " "for writing with STORAGE_WRITE_API.") from exn elif callable(self._schema): raise NotImplementedError( - "Writing with dynamic schemas is not" + "Writing with dynamic schemas is not " "supported for this write method.") elif isinstance(self._schema, vp.ValueProvider): schema = self._schema.get() @@ -2624,6 +2646,8 @@ def expand(self, input): auto_sharding=self._with_auto_sharding, num_streams=self._num_storage_api_streams, use_at_least_once_semantics=self._use_at_least_once, + use_cdc_writes=self._use_cdc_writes, + primary_key=self._primary_key, error_handling={ 'output': StorageWriteToBigQuery.FAILED_ROWS_WITH_ERRORS })) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index a7311ad6d063..3145fb511068 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -777,26 +777,10 @@ def process( GlobalWindows.windowed_value((destination, job_reference))) def finish_bundle(self): - dataset_locations = {} - for windowed_value in self.pending_jobs: - table_ref = bigquery_tools.parse_table_reference(windowed_value.value[0]) - project_dataset = (table_ref.projectId, table_ref.datasetId) - job_ref = windowed_value.value[1] - # In some cases (e.g. when the load job op returns a 409 ALREADY_EXISTS), - # the returned job reference may not include a location. In such cases, - # we need to override with the dataset's location. 
- job_location = job_ref.location - if not job_location and project_dataset not in dataset_locations: - job_location = self.bq_wrapper.get_table_location( - table_ref.projectId, table_ref.datasetId, table_ref.tableId) - dataset_locations[project_dataset] = job_location - self.bq_wrapper.wait_for_bq_job( - job_ref, - sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS, - location=job_location) + job_ref, sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS) return self.pending_jobs diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index e4c0e34d9c1f..10453d9c8baf 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -427,7 +427,6 @@ def test_records_traverse_transform_with_mocks(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference @@ -483,7 +482,6 @@ def test_load_job_id_used(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'loadJobProject' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference @@ -521,7 +519,6 @@ def test_load_job_id_use_for_copy_job(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'loadJobProject' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -577,12 +574,10 @@ def test_wait_for_load_job_completion(self, sleep_mock): job_1.jobReference = bigquery_api.JobReference() job_1.jobReference.projectId = 'project1' job_1.jobReference.jobId = 'jobId1' - job_1.jobReference.location = 'US' job_2 = bigquery_api.Job() job_2.jobReference = bigquery_api.JobReference() job_2.jobReference.projectId = 'project1' 
job_2.jobReference.jobId = 'jobId2' - job_2.jobReference.location = 'US' job_1_waiting = mock.Mock() job_1_waiting.status.state = 'RUNNING' @@ -622,12 +617,10 @@ def test_one_load_job_failed_after_waiting(self, sleep_mock): job_1.jobReference = bigquery_api.JobReference() job_1.jobReference.projectId = 'project1' job_1.jobReference.jobId = 'jobId1' - job_1.jobReference.location = 'US' job_2 = bigquery_api.Job() job_2.jobReference = bigquery_api.JobReference() job_2.jobReference.projectId = 'project1' job_2.jobReference.jobId = 'jobId2' - job_2.jobReference.location = 'US' job_1_waiting = mock.Mock() job_1_waiting.status.state = 'RUNNING' @@ -664,7 +657,6 @@ def test_multiple_partition_files(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -750,7 +742,6 @@ def test_multiple_partition_files_write_dispositions( job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -793,7 +784,6 @@ def test_triggering_frequency(self, is_streaming, with_auto_sharding): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index c7128e7899ec..48da929a07b2 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -32,6 +32,7 @@ import io import json import logging +import re import sys import time import uuid @@ -558,9 +559,22 @@ def _insert_load_job( )) return self._start_job(request, stream=source_stream).jobReference + 
@staticmethod + def _parse_location_from_exc(content, job_id): + """Parse job location from Exception content.""" + if isinstance(content, bytes): + content = content.decode('ascii', 'replace') + # search for "Already Exists: Job :." + m = re.search(r"Already Exists: Job \S+\:(\S+)\." + job_id, content) + if not m: + _LOGGER.warning( + "Not able to parse BigQuery load job location for %s", job_id) + return None + return m.group(1) + def _start_job( self, - request, # type: bigquery.BigqueryJobsInsertRequest + request: 'bigquery.BigqueryJobsInsertRequest', stream=None, ): """Inserts a BigQuery job. @@ -585,11 +599,17 @@ def _start_job( return response except HttpError as exn: if exn.status_code == 409: + jobId = request.job.jobReference.jobId _LOGGER.info( "BigQuery job %s already exists, will not retry inserting it: %s", request.job.jobReference, exn) - return request.job + job_location = self._parse_location_from_exc(exn.content, jobId) + response = request.job + if not response.jobReference.location and job_location: + # Request not constructed with location + response.jobReference.location = job_location + return response else: _LOGGER.info( "Failed to insert job %s: %s", request.job.jobReference, exn) @@ -631,8 +651,7 @@ def _start_query_job( return self._start_job(request) - def wait_for_bq_job( - self, job_reference, sleep_duration_sec=5, max_retries=0, location=None): + def wait_for_bq_job(self, job_reference, sleep_duration_sec=5, max_retries=0): """Poll job until it is DONE. Args: @@ -640,7 +659,6 @@ def wait_for_bq_job( sleep_duration_sec: Specifies the delay in seconds between retries. max_retries: The total number of times to retry. If equals to 0, the function waits forever. - location: Fall back on this location if job_reference doesn't have one. 
Raises: `RuntimeError`: If the job is FAILED or the number of retries has been @@ -650,9 +668,7 @@ def wait_for_bq_job( while True: retry += 1 job = self.get_job( - job_reference.projectId, - job_reference.jobId, - job_reference.location or location) + job_reference.projectId, job_reference.jobId, job_reference.location) _LOGGER.info('Job %s status: %s', job.id, job.status.state) if job.status.state == 'DONE' and job.status.errorResult: raise RuntimeError( @@ -1786,9 +1802,11 @@ def generate_bq_job_name(job_name, step_id, job_type, random=None): def check_schema_equal( - left, right, *, ignore_descriptions=False, ignore_field_order=False): - # type: (Union[bigquery.TableSchema, bigquery.TableFieldSchema], Union[bigquery.TableSchema, bigquery.TableFieldSchema], bool, bool) -> bool - + left: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'], + right: Union['bigquery.TableSchema', 'bigquery.TableFieldSchema'], + *, + ignore_descriptions: bool = False, + ignore_field_order: bool = False) -> bool: """Check whether schemas are equivalent. 
This comparison function differs from using == to compare TableSchema diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py index 053b02d325a5..325f70ddfd96 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py @@ -366,10 +366,14 @@ def delete(self, paths): if exceptions: raise BeamIOError("Delete operation failed", exceptions) - def report_lineage(self, path, lineage): + def report_lineage(self, path, lineage, level=None): try: - bucket, blob = gcsio.parse_gcs_path(path) + components = gcsio.parse_gcs_path(path, object_optional=True) except ValueError: # report lineage is fail-safe return - lineage.add('gcs', bucket, blob) + if level == FileSystem.LineageLevel.TOP_LEVEL \ + or(len(components) > 1 and components[-1] == ''): + # bucket only + components = components[:-1] + lineage.add('gcs', *components) diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py b/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py index 1206529faf01..ec7fa94b05fd 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem_test.py @@ -375,6 +375,15 @@ def test_delete_error(self, mock_gcsio): self.fs.delete(files) gcsio_mock.delete_batch.assert_called() + def test_lineage(self): + self._verify_lineage("gs://bucket/", ("bucket", )) + self._verify_lineage("gs://bucket/foo/bar.txt", ("bucket", "foo/bar.txt")) + + def _verify_lineage(self, uri, expected_segments): + lineage_mock = mock.MagicMock() + self.fs.report_lineage(uri, lineage_mock) + lineage_mock.add.assert_called_once_with("gcs", *expected_segments) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 22a33fa13c63..8056de51f43f 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ 
-137,8 +137,10 @@ def create_storage_client(pipeline_options, use_credentials=True): class GcsIO(object): """Google Cloud Storage I/O client.""" - def __init__(self, storage_client=None, pipeline_options=None): - # type: (Optional[storage.Client], Optional[Union[dict, PipelineOptions]]) -> None + def __init__( + self, + storage_client: Optional[storage.Client] = None, + pipeline_options: Optional[Union[dict, PipelineOptions]] = None) -> None: if pipeline_options is None: pipeline_options = PipelineOptions() elif isinstance(pipeline_options, dict): diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index b6f801c63f79..9e006dbeda93 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -126,7 +126,7 @@ def _from_proto_str(proto_msg: bytes) -> 'PubsubMessage': """ msg = pubsub.types.PubsubMessage.deserialize(proto_msg) # Convert ScalarMapContainer to dict. - attributes = dict((key, msg.attributes[key]) for key in msg.attributes) + attributes = dict(msg.attributes) return PubsubMessage( msg.data, attributes, @@ -151,10 +151,8 @@ def _to_proto_str(self, for_publish=False): https://cloud.google.com/pubsub/docs/reference/rpc/google.pubsub.v1#google.pubsub.v1.PubsubMessage containing the payload of this object. 
""" - msg = pubsub.types.PubsubMessage() if len(self.data) > (10_000_000): raise ValueError('A pubsub message data field must not exceed 10MB') - msg.data = self.data if self.attributes: if len(self.attributes) > 100: @@ -167,19 +165,25 @@ def _to_proto_str(self, for_publish=False): if len(value) > 1024: raise ValueError( 'A pubsub message attribute value must not exceed 1024 bytes') - msg.attributes[key] = value + message_id = None + publish_time = None if not for_publish: if self.message_id: - msg.message_id = self.message_id + message_id = self.message_id if self.publish_time: - msg.publish_time = self.publish_time + publish_time = self.publish_time if len(self.ordering_key) > 1024: raise ValueError( 'A pubsub message ordering key must not exceed 1024 bytes.') - msg.ordering_key = self.ordering_key + msg = pubsub.types.PubsubMessage( + data=self.data, + attributes=self.attributes, + message_id=message_id, + publish_time=publish_time, + ordering_key=self.ordering_key) serialized = pubsub.types.PubsubMessage.serialize(msg) if len(serialized) > (10_000_000): raise ValueError( @@ -193,7 +197,7 @@ def _from_message(msg: Any) -> 'PubsubMessage': https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html """ # Convert ScalarMapContainer to dict. 
- attributes = dict((key, msg.attributes[key]) for key in msg.attributes) + attributes = dict(msg.attributes) pubsubmessage = PubsubMessage(msg.data, attributes) if msg.message_id: pubsubmessage.message_id = msg.message_id diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 2e3e9b301618..73ba8d6abdb6 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -901,7 +901,8 @@ def test_write_messages_with_attributes_error(self, mock_pubsub): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True - with self.assertRaisesRegex(Exception, r'Type hint violation'): + with self.assertRaisesRegex(Exception, + r'requires.*PubsubMessage.*applied.*str'): with TestPipeline(options=options) as p: _ = ( p diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index 3fef1f5fee35..11570680a2f3 100644 --- a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -125,12 +125,14 @@ def default_io_expansion_service(classpath=None): ('read_query', typing.Optional[str]), ('write_statement', typing.Optional[str]), ('fetch_size', typing.Optional[np.int16]), + ('disable_autocommit', typing.Optional[bool]), ('output_parallelization', typing.Optional[bool]), ('autosharding', typing.Optional[bool]), ('partition_column', typing.Optional[str]), ('partitions', typing.Optional[np.int16]), ('max_connections', typing.Optional[np.int16]), - ('driver_jars', typing.Optional[str])], + ('driver_jars', typing.Optional[str]), + ('write_batch_size', typing.Optional[np.int64])], ) DEFAULT_JDBC_CLASSPATH = ['org.postgresql:postgresql:42.2.16'] @@ -186,6 +188,7 @@ def __init__( driver_jars=None, expansion_service=None, classpath=None, + write_batch_size=None, ): """ Initializes a write operation to Jdbc. @@ -217,6 +220,9 @@ def __init__( package (e.g. "org.postgresql:postgresql:42.3.1"). 
By default, this argument includes a Postgres SQL JDBC driver. + :param write_batch_size: sets the maximum size in number of SQL statement + for the batch. + default is {@link JdbcIO.DEFAULT_BATCH_SIZE} """ classpath = classpath or DEFAULT_JDBC_CLASSPATH super().__init__( @@ -234,8 +240,10 @@ def __init__( connection_properties=connection_properties, connection_init_sqls=connection_init_sqls, write_statement=statement, + write_batch_size=write_batch_size, read_query=None, fetch_size=None, + disable_autocommit=None, output_parallelization=None, autosharding=autosharding, max_connections=max_connections, @@ -286,6 +294,7 @@ def __init__( username, password, query=None, + disable_autocommit=None, output_parallelization=None, fetch_size=None, partition_column=None, @@ -305,6 +314,7 @@ def __init__( :param username: database username :param password: database password :param query: sql query to be executed + :param disable_autocommit: disable autocommit on read :param output_parallelization: is output parallelization on :param fetch_size: how many rows to fetch :param partition_column: enable partitioned reads by splitting on this @@ -348,8 +358,10 @@ def __init__( connection_properties=connection_properties, connection_init_sqls=connection_init_sqls, write_statement=None, + write_batch_size=None, read_query=query, fetch_size=fetch_size, + disable_autocommit=disable_autocommit, output_parallelization=output_parallelization, autosharding=None, max_connections=max_connections, diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd index a8f4003d8980..c583dabeb0c0 100644 --- a/sdks/python/apache_beam/metrics/cells.pxd +++ b/sdks/python/apache_beam/metrics/cells.pxd @@ -33,6 +33,7 @@ cdef class CounterCell(MetricCell): cpdef bint update(self, value) except -1 +# Not using AbstractMetricCell so that data can be typed. 
cdef class DistributionCell(MetricCell): cdef readonly DistributionData data @@ -40,14 +41,18 @@ cdef class DistributionCell(MetricCell): cdef inline bint _update(self, value) except -1 -cdef class GaugeCell(MetricCell): - cdef readonly object data +cdef class AbstractMetricCell(MetricCell): + cdef readonly object data_class + cdef public object data + cdef bint _update_locked(self, value) except -1 -cdef class StringSetCell(MetricCell): - cdef readonly set data +cdef class GaugeCell(AbstractMetricCell): + pass - cdef inline bint _update(self, value) except -1 + +cdef class StringSetCell(AbstractMetricCell): + pass cdef class DistributionData(object): diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index 407106342fb8..5802c6914eb2 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -23,12 +23,13 @@ # pytype: skip-file +import logging import threading import time from datetime import datetime -from typing import Any +from typing import Iterable from typing import Optional -from typing import SupportsInt +from typing import Set try: import cython @@ -40,13 +41,11 @@ class fake_cython: globals()['cython'] = fake_cython __all__ = [ - 'MetricAggregator', - 'MetricCell', - 'MetricCellFactory', - 'DistributionResult', - 'GaugeResult' + 'MetricCell', 'MetricCellFactory', 'DistributionResult', 'GaugeResult' ] +_LOGGER = logging.getLogger(__name__) + class MetricCell(object): """For internal use only; no backwards-compatibility guarantees. 
@@ -105,11 +104,11 @@ class CounterCell(MetricCell): """ def __init__(self, *args): super().__init__(*args) - self.value = CounterAggregator.identity_element() + self.value = 0 def reset(self): # type: () -> None - self.value = CounterAggregator.identity_element() + self.value = 0 def combine(self, other): # type: (CounterCell) -> CounterCell @@ -170,11 +169,11 @@ class DistributionCell(MetricCell): """ def __init__(self, *args): super().__init__(*args) - self.data = DistributionAggregator.identity_element() + self.data = DistributionData.identity_element() def reset(self): # type: () -> None - self.data = DistributionAggregator.identity_element() + self.data = DistributionData.identity_element() def combine(self, other): # type: (DistributionCell) -> DistributionCell @@ -216,47 +215,65 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id): ptransform=transform_id) -class GaugeCell(MetricCell): +class AbstractMetricCell(MetricCell): """For internal use only; no backwards-compatibility guarantees. - Tracks the current value and delta for a gauge metric. - - Each cell tracks the state of a metric independently per context per bundle. - Therefore, each metric has a different cell in each bundle, that is later - aggregated. + Tracks the current value and delta for a metric with a data class. This class is thread safe. 
""" - def __init__(self, *args): - super().__init__(*args) - self.data = GaugeAggregator.identity_element() + def __init__(self, data_class): + super().__init__() + self.data_class = data_class + self.data = self.data_class.identity_element() def reset(self): - self.data = GaugeAggregator.identity_element() + self.data = self.data_class.identity_element() - def combine(self, other): - # type: (GaugeCell) -> GaugeCell - result = GaugeCell() + def combine(self, other: 'AbstractMetricCell') -> 'AbstractMetricCell': + result = type(self)() # type: ignore[call-arg] result.data = self.data.combine(other.data) return result def set(self, value): - self.update(value) + with self._lock: + self._update_locked(value) def update(self, value): - # type: (SupportsInt) -> None - value = int(value) with self._lock: - # Set the value directly without checking timestamp, because - # this value is naturally the latest value. - self.data.value = value - self.data.timestamp = time.time() + self._update_locked(value) + + def _update_locked(self, value): + raise NotImplementedError(type(self)) def get_cumulative(self): - # type: () -> GaugeData with self._lock: return self.data.get_cumulative() + def to_runner_api_monitoring_info_impl(self, name, transform_id): + raise NotImplementedError(type(self)) + + +class GaugeCell(AbstractMetricCell): + """For internal use only; no backwards-compatibility guarantees. + + Tracks the current value and delta for a gauge metric. + + Each cell tracks the state of a metric independently per context per bundle. + Therefore, each metric has a different cell in each bundle, that is later + aggregated. + + This class is thread safe. + """ + def __init__(self): + super().__init__(GaugeData) + + def _update_locked(self, value): + # Set the value directly without checking timestamp, because + # this value is naturally the latest value. 
+ self.data.value = int(value) + self.data.timestamp = time.time() + def to_runner_api_monitoring_info_impl(self, name, transform_id): from apache_beam.metrics import monitoring_infos return monitoring_infos.int64_user_gauge( @@ -266,7 +283,7 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id): ptransform=transform_id) -class StringSetCell(MetricCell): +class StringSetCell(AbstractMetricCell): """For internal use only; no backwards-compatibility guarantees. Tracks the current value for a StringSet metric. @@ -277,50 +294,23 @@ class StringSetCell(MetricCell): This class is thread safe. """ - def __init__(self, *args): - super().__init__(*args) - self.data = StringSetAggregator.identity_element() + def __init__(self): + super().__init__(StringSetData) def add(self, value): self.update(value) - def update(self, value): - # type: (str) -> None - if cython.compiled: - # We will hold the GIL throughout the entire _update. - self._update(value) - else: - with self._lock: - self._update(value) - - def _update(self, value): + def _update_locked(self, value): self.data.add(value) - def get_cumulative(self): - # type: () -> set - with self._lock: - return set(self.data) - - def combine(self, other): - # type: (StringSetCell) -> StringSetCell - combined = StringSetAggregator().combine(self.data, other.data) - result = StringSetCell() - result.data = combined - return result - def to_runner_api_monitoring_info_impl(self, name, transform_id): from apache_beam.metrics import monitoring_infos - return monitoring_infos.user_set_string( name.namespace, name.name, self.get_cumulative(), ptransform=transform_id) - def reset(self): - # type: () -> None - self.data = StringSetAggregator.identity_element() - class DistributionResult(object): """The result of a Distribution metric.""" @@ -444,6 +434,10 @@ def get_cumulative(self): # type: () -> GaugeData return GaugeData(self.value, timestamp=self.timestamp) + def get_result(self): + # type: () -> GaugeResult + return 
GaugeResult(self.get_cumulative()) + def combine(self, other): # type: (Optional[GaugeData]) -> GaugeData if other is None: @@ -459,6 +453,11 @@ def singleton(value, timestamp=None): # type: (Optional[int], Optional[int]) -> GaugeData return GaugeData(value, timestamp=timestamp) + @staticmethod + def identity_element(): + # type: () -> GaugeData + return GaugeData(0, timestamp=0) + class DistributionData(object): """For internal use only; no backwards-compatibility guarantees. @@ -505,6 +504,9 @@ def get_cumulative(self): # type: () -> DistributionData return DistributionData(self.sum, self.count, self.min, self.max) + def get_result(self) -> DistributionResult: + return DistributionResult(self.get_cumulative()) + def combine(self, other): # type: (Optional[DistributionData]) -> DistributionData if other is None: @@ -521,108 +523,110 @@ def singleton(value): # type: (int) -> DistributionData return DistributionData(value, 1, value, value) - -class MetricAggregator(object): - """For internal use only; no backwards-compatibility guarantees. - - Base interface for aggregating metric data during pipeline execution.""" - def identity_element(self): - # type: () -> Any - - """Returns the identical element of an Aggregation. - - For the identity element, it must hold that - Aggregator.combine(any_element, identity_element) == any_element. - """ - raise NotImplementedError - - def combine(self, x, y): - # type: (Any, Any) -> Any - raise NotImplementedError - - def result(self, x): - # type: (Any) -> Any - raise NotImplementedError + @staticmethod + def identity_element(): + # type: () -> DistributionData + return DistributionData(0, 0, 2**63 - 1, -2**63) -class CounterAggregator(MetricAggregator): +class StringSetData(object): """For internal use only; no backwards-compatibility guarantees. - Aggregator for Counter metric data during pipeline execution. + The data structure that holds data about a StringSet metric. - Values aggregated should be ``int`` objects. 
- """ - @staticmethod - def identity_element(): - # type: () -> int - return 0 - - def combine(self, x, y): - # type: (SupportsInt, SupportsInt) -> int - return int(x) + int(y) + StringSet metrics are restricted to set of strings only. - def result(self, x): - # type: (SupportsInt) -> int - return int(x) + This object is not thread safe, so it's not supposed to be modified + by other than the StringSetCell that contains it. + The summation of all string length for a StringSetData cannot exceed 1 MB. + Further addition of elements are dropped. + """ -class DistributionAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. + _STRING_SET_SIZE_LIMIT = 1_000_000 - Aggregator for Distribution metric data during pipeline execution. + def __init__(self, string_set: Optional[Set] = None, string_size: int = 0): + self.string_set = string_set or set() + if not string_size: + string_size = 0 + for s in self.string_set: + string_size += len(s) + self.string_size = string_size - Values aggregated should be ``DistributionData`` objects. 
- """ - @staticmethod - def identity_element(): - # type: () -> DistributionData - return DistributionData(0, 0, 2**63 - 1, -2**63) + def __eq__(self, other: object) -> bool: + if isinstance(other, StringSetData): + return ( + self.string_size == other.string_size and + self.string_set == other.string_set) + else: + return False - def combine(self, x, y): - # type: (DistributionData, DistributionData) -> DistributionData - return x.combine(y) + def __hash__(self) -> int: + return hash(self.string_set) - def result(self, x): - # type: (DistributionData) -> DistributionResult - return DistributionResult(x.get_cumulative()) + def __repr__(self) -> str: + return 'StringSetData{}:{}'.format(self.string_set, self.string_size) + def get_cumulative(self) -> "StringSetData": + return StringSetData(set(self.string_set), self.string_size) -class GaugeAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. + def get_result(self) -> Set[str]: + return set(self.string_set) - Aggregator for Gauge metric data during pipeline execution. + def add(self, *strings): + """ + Add strings into this StringSetData and return the result StringSetData. + Reuse the original StringSetData's set. + """ + self.string_size = self.add_until_capacity( + self.string_set, self.string_size, strings) + return self - Values aggregated should be ``GaugeData`` objects. - """ - @staticmethod - def identity_element(): - # type: () -> GaugeData - return GaugeData(0, timestamp=0) + def combine(self, other: "StringSetData") -> "StringSetData": + """ + Combines this StringSetData with other, both original StringSetData are left + intact. 
+ """ + if other is None: + return self - def combine(self, x, y): - # type: (GaugeData, GaugeData) -> GaugeData - result = x.combine(y) - return result + if not other.string_set: + return self + elif not self.string_set: + return other - def result(self, x): - # type: (GaugeData) -> GaugeResult - return GaugeResult(x.get_cumulative()) + combined = set(self.string_set) + string_size = self.add_until_capacity( + combined, self.string_size, other.string_set) + return StringSetData(combined, string_size) + @classmethod + def add_until_capacity( + cls, combined: set, current_size: int, others: Iterable[str]): + """ + Add strings into set until reach capacity. Return the all string size of + added set. + """ + if current_size > cls._STRING_SET_SIZE_LIMIT: + return current_size + + for string in others: + if string not in combined: + combined.add(string) + current_size += len(string) + if current_size > cls._STRING_SET_SIZE_LIMIT: + _LOGGER.warning( + "StringSet metrics reaches capacity. Further incoming elements " + "won't be recorded. 
Current size: %d, last element size: %d.", + current_size, + len(string)) + break + return current_size -class StringSetAggregator(MetricAggregator): @staticmethod - def identity_element(): - # type: () -> set - return set() - - def combine(self, x, y): - # type: (set, set) -> set - if len(x) == 0: - return y - elif len(y) == 0: - return x - else: - return set.union(x, y) + def singleton(value: str) -> "StringSetData": + return StringSetData({value}) - def result(self, x): - return x + @staticmethod + def identity_element() -> "StringSetData": + return StringSetData() diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py index 052ff051bf96..d1ee37b8ed82 100644 --- a/sdks/python/apache_beam/metrics/cells_test.py +++ b/sdks/python/apache_beam/metrics/cells_test.py @@ -26,6 +26,7 @@ from apache_beam.metrics.cells import GaugeCell from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import StringSetCell +from apache_beam.metrics.cells import StringSetData from apache_beam.metrics.metricbase import MetricName @@ -176,9 +177,9 @@ def test_not_leak_mutable_set(self): c.add('test') c.add('another') s = c.get_cumulative() - self.assertEqual(s, set(('test', 'another'))) + self.assertEqual(s, StringSetData({'test', 'another'}, 11)) s.add('yet another') - self.assertEqual(c.get_cumulative(), set(('test', 'another'))) + self.assertEqual(c.get_cumulative(), StringSetData({'test', 'another'}, 11)) def test_combine_appropriately(self): s1 = StringSetCell() @@ -190,7 +191,16 @@ def test_combine_appropriately(self): s2.add('3') result = s2.combine(s1) - self.assertEqual(result.data, set(('1', '2', '3'))) + self.assertEqual(result.data, StringSetData({'1', '2', '3'})) + + def test_add_size_tracked_correctly(self): + s = StringSetCell() + s.add('1') + s.add('2') + self.assertEqual(s.data.string_size, 2) + s.add('2') + s.add('3') + self.assertEqual(s.data.string_size, 3) if __name__ == '__main__': diff --git 
a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index 37007add9163..fa70d3a4d9c0 100644 --- a/sdks/python/apache_beam/metrics/execution.py +++ b/sdks/python/apache_beam/metrics/execution.py @@ -47,6 +47,7 @@ from apache_beam.metrics.cells import DistributionCell from apache_beam.metrics.cells import GaugeCell from apache_beam.metrics.cells import StringSetCell +from apache_beam.metrics.cells import StringSetData from apache_beam.runners.worker import statesampler from apache_beam.runners.worker.statesampler import get_current_tracker @@ -356,7 +357,7 @@ def __init__( counters=None, # type: Optional[Dict[MetricKey, int]] distributions=None, # type: Optional[Dict[MetricKey, DistributionData]] gauges=None, # type: Optional[Dict[MetricKey, GaugeData]] - string_sets=None, # type: Optional[Dict[MetricKey, set]] + string_sets=None, # type: Optional[Dict[MetricKey, StringSetData]] ): # type: (...) -> None diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py index b157aeb20e9e..38e27f1f3d0c 100644 --- a/sdks/python/apache_beam/metrics/execution_test.py +++ b/sdks/python/apache_beam/metrics/execution_test.py @@ -110,11 +110,12 @@ def test_get_cumulative_or_updates(self): self.assertEqual( set(all_values), {v.value for _, v in cumulative.gauges.items()}) - self.assertEqual({str(i % 7) - for i in all_values}, - functools.reduce( - set.union, - (v for _, v in cumulative.string_sets.items()))) + self.assertEqual( + {str(i % 7) + for i in all_values}, + functools.reduce( + set.union, + (v.string_set for _, v in cumulative.string_sets.items()))) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index f402c0acab2f..3e665dd805ea 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -140,7 +140,7 @@ class DelegatingCounter(Counter): def __init__( self, 
metric_name: MetricName, process_wide: bool = False) -> None: super().__init__(metric_name) - self.inc = MetricUpdater( # type: ignore[assignment] + self.inc = MetricUpdater( # type: ignore[method-assign] cells.CounterCell, metric_name, default_value=1, @@ -150,19 +150,19 @@ class DelegatingDistribution(Distribution): """Metrics Distribution Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[assignment] + self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[method-assign] class DelegatingGauge(Gauge): """Metrics Gauge that Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[assignment] + self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[method-assign] class DelegatingStringSet(StringSet): """Metrics StringSet that Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[assignment] + self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[method-assign] class MetricResults(object): diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py index a9540f2846ad..5227a4c9872b 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos.py @@ -31,6 +31,7 @@ from apache_beam.metrics.cells import DistributionResult from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import GaugeResult +from apache_beam.metrics.cells import StringSetData from apache_beam.portability import common_urns from 
apache_beam.portability.api import metrics_pb2 @@ -181,9 +182,8 @@ def create_labels(ptransform=None, namespace=None, name=None, pcollection=None): return labels -def int64_user_counter(namespace, name, metric, ptransform=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_user_counter( + namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the counter monitoring info for the specifed URN, metric and labels. Args: @@ -198,9 +198,12 @@ def int64_user_counter(namespace, name, metric, ptransform=None): USER_COUNTER_URN, SUM_INT64_TYPE, metric, labels) -def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_counter( + urn, + metric, + ptransform=None, + pcollection=None, + labels=None) -> metrics_pb2.MonitoringInfo: """Return the counter monitoring info for the specifed URN, metric and labels. Args: @@ -216,9 +219,8 @@ def int64_counter(urn, metric, ptransform=None, pcollection=None, labels=None): return create_monitoring_info(urn, SUM_INT64_TYPE, metric, labels) -def int64_user_distribution(namespace, name, metric, ptransform=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_user_distribution( + namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the distribution monitoring info for the URN, metric and labels. Args: @@ -233,9 +235,11 @@ def int64_user_distribution(namespace, name, metric, ptransform=None): USER_DISTRIBUTION_URN, DISTRIBUTION_INT64_TYPE, payload, labels) -def int64_distribution(urn, metric, ptransform=None, pcollection=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_distribution( + urn, + metric, + ptransform=None, + pcollection=None) -> metrics_pb2.MonitoringInfo: """Return a distribution monitoring info for the URN, metric and labels. 
Args: @@ -250,9 +254,8 @@ def int64_distribution(urn, metric, ptransform=None, pcollection=None): return create_monitoring_info(urn, DISTRIBUTION_INT64_TYPE, payload, labels) -def int64_user_gauge(namespace, name, metric, ptransform=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_user_gauge( + namespace, name, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the gauge monitoring info for the URN, metric and labels. Args: @@ -275,9 +278,7 @@ def int64_user_gauge(namespace, name, metric, ptransform=None): USER_GAUGE_URN, LATEST_INT64_TYPE, payload, labels) -def int64_gauge(urn, metric, ptransform=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def int64_gauge(urn, metric, ptransform=None) -> metrics_pb2.MonitoringInfo: """Return the gauge monitoring info for the URN, metric and labels. Args: @@ -305,10 +306,12 @@ def user_set_string(namespace, name, metric, ptransform=None): Args: namespace: User-defined namespace of StringSet. name: Name of StringSet. - metric: The set representing the metrics. + metric: The StringSetData representing the metrics. ptransform: The ptransform id used as a label. """ labels = create_labels(ptransform=ptransform, namespace=namespace, name=name) + if isinstance(metric, StringSetData): + metric = metric.string_set if isinstance(metric, set): metric = list(metric) if isinstance(metric, list): @@ -317,9 +320,8 @@ def user_set_string(namespace, name, metric, ptransform=None): USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels) -def create_monitoring_info(urn, type_urn, payload, labels=None): - # type: (...) -> metrics_pb2.MonitoringInfo - +def create_monitoring_info( + urn, type_urn, payload, labels=None) -> metrics_pb2.MonitoringInfo: """Return the gauge monitoring info for the URN, type, metric and labels. 
Args: @@ -363,9 +365,9 @@ def is_user_monitoring_info(monitoring_info_proto): return monitoring_info_proto.urn in USER_METRIC_URNS -def extract_metric_result_map_value(monitoring_info_proto): - # type: (...) -> Union[None, int, DistributionResult, GaugeResult, set] - +def extract_metric_result_map_value( + monitoring_info_proto +) -> Union[None, int, DistributionResult, GaugeResult, set]: """Returns the relevant GaugeResult, DistributionResult or int value for counter metric, set for StringSet metric. @@ -405,14 +407,13 @@ def get_step_name(monitoring_info_proto): return monitoring_info_proto.labels.get(PTRANSFORM_LABEL) -def to_key(monitoring_info_proto): - # type: (metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable] - +def to_key( + monitoring_info_proto: metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable]: """Returns a key based on the URN and labels. This is useful in maps to prevent reporting the same MonitoringInfo twice. """ - key_items = list(monitoring_info_proto.labels.items()) # type: List[Hashable] + key_items: List[Hashable] = list(monitoring_info_proto.labels.items()) key_items.append(monitoring_info_proto.urn) return frozenset(key_items) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index e7af114ad431..4ac856456748 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -116,7 +116,15 @@ def load_model(self) -> ort.InferenceSession: # when path is remote, we should first load into memory then deserialize f = FileSystems.open(self._model_uri, "rb") model_proto = onnx.load(f) - model_proto_bytes = onnx._serialize(model_proto) + model_proto_bytes = model_proto + if not isinstance(model_proto, bytes): + if (hasattr(model_proto, "SerializeToString") and + callable(model_proto.SerializeToString)): + model_proto_bytes = model_proto.SerializeToString() + else: + raise TypeError( + "No SerializeToString 
method is detected on loaded model. " + + f"Type of model: {type(model_proto)}") ort_session = ort.InferenceSession( model_proto_bytes, sess_options=self._session_options, diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py index 52123516de1a..dc35aa016013 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py @@ -21,6 +21,7 @@ import shutil import tempfile import unittest +import uuid from typing import Any from typing import Dict from typing import Iterable @@ -127,7 +128,7 @@ def test_predict_tensor(self): def test_predict_tensor_with_batch_size(self): model = _create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as pipeline: @@ -173,7 +174,7 @@ def fake_batching_inference_fn( def test_predict_tensor_with_large_model(self): model = _create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as pipeline: @@ -220,7 +221,7 @@ def fake_batching_inference_fn( def test_predict_numpy_with_batch_size(self): model = _create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2_numpy.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as pipeline: @@ -263,7 +264,7 @@ def fake_batching_inference_fn( def test_predict_numpy_with_large_model(self): model = _create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2_numpy.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as 
pipeline: diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index b38947b494c2..9563aa05232a 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -125,7 +125,7 @@ def __init__(self, engine: trt.ICudaEngine): # TODO(https://github.com/NVIDIA/TensorRT/issues/2557): # Clean up when fixed upstream. try: - _ = np.bool # type: ignore + _ = np.bool except AttributeError: # numpy >= 1.24.0 np.bool = np.bool_ # type: ignore @@ -258,7 +258,7 @@ def __init__( model_copies: The exact number of models that you would like loaded onto your machine. This can be useful if you exactly know your CPU or GPU capacity and want to maximize resource utilization. - max_batch_duration_secs: the maximum amount of time to buffer + max_batch_duration_secs: the maximum amount of time to buffer a batch before emitting; used in streaming contexts. kwargs: Additional arguments like 'engine_path' and 'onnx_path' are currently supported. 'env_vars' can be used to set environment variables diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile index 5abbffdc5a2a..f27abbfd0051 100644 --- a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -40,7 +40,7 @@ RUN pip install openai vllm RUN apt install libcairo2-dev pkg-config python3-dev -y RUN pip install pycairo -# Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. +# Copy the Apache Beam worker dependencies from the Beam Python 3.12 SDK image. COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam # Set the entrypoint to Apache Beam SDK worker launcher. 
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 28890083d93e..799083d16ceb 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -17,9 +17,11 @@ # pytype: skip-file +import asyncio import logging import os import subprocess +import sys import threading import time import uuid @@ -35,6 +37,7 @@ from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.utils import subprocess_server +from openai import AsyncOpenAI from openai import OpenAI try: @@ -94,6 +97,15 @@ def getVLLMClient(port) -> OpenAI: ) +def getAsyncVLLMClient(port) -> AsyncOpenAI: + openai_api_key = "EMPTY" + openai_api_base = f"http://localhost:{port}/v1" + return AsyncOpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + class _VLLMModelServer(): def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]): self._model_name = model_name @@ -107,7 +119,7 @@ def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]): def start_server(self, retries=3): if not self._server_started: server_cmd = [ - 'python', + sys.executable, '-m', 'vllm.entrypoints.openai.api_server', '--model', @@ -120,7 +132,7 @@ def start_server(self, retries=3): server_cmd.append(v) self._server_process, self._server_port = start_process(server_cmd) - self.check_connectivity() + self.check_connectivity(retries) def get_server_port(self) -> int: if not self._server_started: @@ -184,6 +196,34 @@ def __init__( def load_model(self) -> _VLLMModelServer: return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) + async def _async_run_inference( + self, + batch: Sequence[str], + model: _VLLMModelServer, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + client = getAsyncVLLMClient(model.get_server_port()) + inference_args = 
inference_args or {} + async_predictions = [] + for prompt in batch: + try: + completion = client.completions.create( + model=self._model_name, prompt=prompt, **inference_args) + async_predictions.append(completion) + except Exception as e: + model.check_connectivity() + raise e + + predictions = [] + for p in async_predictions: + try: + predictions.append(await p) + except Exception as e: + model.check_connectivity() + raise e + + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + def run_inference( self, batch: Sequence[str], @@ -200,22 +240,7 @@ def run_inference( Returns: An Iterable of type PredictionResult. """ - client = getVLLMClient(model.get_server_port()) - inference_args = inference_args or {} - predictions = [] - # TODO(https://github.com/apache/beam/issues/32528): We should add support - # for taking in batches and doing a bunch of async calls. That will end up - # being more efficient when we can do in bundle batching. - for prompt in batch: - try: - completion = client.completions.create( - model=self._model_name, prompt=prompt, **inference_args) - predictions.append(completion) - except Exception as e: - model.check_connectivity() - raise e - - return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + return asyncio.run(self._async_run_inference(batch, model, inference_args)) def share_model_across_processes(self) -> bool: return True @@ -272,28 +297,15 @@ def load_model(self) -> _VLLMModelServer: return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) - def run_inference( + async def _async_run_inference( self, batch: Sequence[Sequence[OpenAIChatMessage]], model: _VLLMModelServer, inference_args: Optional[Dict[str, Any]] = None ) -> Iterable[PredictionResult]: - """Runs inferences on a batch of text strings. - - Args: - batch: A sequence of examples as OpenAI messages. - model: A _VLLMModelServer for connecting to the spun up server. - inference_args: Any additional arguments for an inference. 
- - Returns: - An Iterable of type PredictionResult. - """ - client = getVLLMClient(model.get_server_port()) + client = getAsyncVLLMClient(model.get_server_port()) inference_args = inference_args or {} - predictions = [] - # TODO(https://github.com/apache/beam/issues/32528): We should add support - # for taking in batches and doing a bunch of async calls. That will end up - # being more efficient when we can do in bundle batching. + async_predictions = [] for messages in batch: formatted = [] for message in messages: @@ -301,12 +313,38 @@ def run_inference( try: completion = client.chat.completions.create( model=self._model_name, messages=formatted, **inference_args) - predictions.append(completion) + async_predictions.append(completion) + except Exception as e: + model.check_connectivity() + raise e + + predictions = [] + for p in async_predictions: + try: + predictions.append(await p) except Exception as e: model.check_connectivity() raise e return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + def run_inference( + self, + batch: Sequence[Sequence[OpenAIChatMessage]], + model: _VLLMModelServer, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Runs inferences on a batch of text strings. + + Args: + batch: A sequence of examples as OpenAI messages. + model: A _VLLMModelServer for connecting to the spun up server. + inference_args: Any additional arguments for an inference. + + Returns: + An Iterable of type PredictionResult. 
+ """ + return asyncio.run(self._async_run_inference(batch, model, inference_args)) + def share_model_across_processes(self) -> bool: return True diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py index eab547b1c17b..e09f116dfb38 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py @@ -53,7 +53,7 @@ def _compare_prediction_result(a: PredictionResult, b: PredictionResult): example_equal = numpy.array_equal(a.example.todense(), b.example.todense()) else: - example_equal = numpy.array_equal(a.example, b.example) + example_equal = numpy.array_equal(a.example, b.example) # type: ignore[arg-type] if isinstance(a.inference, dict): return all( x == y for x, y in zip(a.inference.values(), diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 678ab0882d24..a963f602a06d 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -20,14 +20,11 @@ import os import tempfile import uuid +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any -from typing import Dict from typing import Generic -from typing import List -from typing import Mapping from typing import Optional -from typing import Sequence -from typing import Tuple from typing import TypeVar from typing import Union @@ -67,7 +64,7 @@ def _convert_list_of_dicts_to_dict_of_lists( - list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]: + list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]: keys_to_element_list = collections.defaultdict(list) input_keys = list_of_dicts[0].keys() for d in list_of_dicts: @@ -83,9 +80,9 @@ def _convert_list_of_dicts_to_dict_of_lists( def _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists: Dict[str, List[Any]]) -> List[Dict[str, 
Any]]: + dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: batch_length = len(next(iter(dict_of_lists.values()))) - result: List[Dict[str, Any]] = [{} for _ in range(batch_length)] + result: list[dict[str, Any]] = [{} for _ in range(batch_length)] # all the values in the dict_of_lists should have same length for key, values in dict_of_lists.items(): assert len(values) == batch_length, ( @@ -140,7 +137,7 @@ def get_counter(self): class BaseOperation(Generic[OperationInputT, OperationOutputT], MLTransformProvider, abc.ABC): - def __init__(self, columns: List[str]) -> None: + def __init__(self, columns: list[str]) -> None: """ Base Opertation class data processing transformations. Args: @@ -150,7 +147,7 @@ def __init__(self, columns: List[str]) -> None: @abc.abstractmethod def apply_transform(self, data: OperationInputT, - output_column_name: str) -> Dict[str, OperationOutputT]: + output_column_name: str) -> dict[str, OperationOutputT]: """ Define any processing logic in the apply_transform() method. processing logics are applied on inputs and returns a transformed @@ -160,7 +157,7 @@ def apply_transform(self, data: OperationInputT, """ def __call__(self, data: OperationInputT, - output_column_name: str) -> Dict[str, OperationOutputT]: + output_column_name: str) -> dict[str, OperationOutputT]: """ This method is called when the instance of the class is called. This method will invoke the apply() method of the class. @@ -172,7 +169,7 @@ def __call__(self, data: OperationInputT, class ProcessHandler( beam.PTransform[beam.PCollection[ExampleT], Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]], abc.ABC): """ @@ -190,10 +187,10 @@ def append_transform(self, transform: BaseOperation): class EmbeddingsManager(MLTransformProvider): def __init__( self, - columns: List[str], + columns: list[str], *, # common args for all ModelHandlers. 
- load_model_args: Optional[Dict[str, Any]] = None, + load_model_args: Optional[dict[str, Any]] = None, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, large_model: bool = False, @@ -222,7 +219,7 @@ def get_columns_to_apply(self): class MLTransform( beam.PTransform[beam.PCollection[ExampleT], Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]], Generic[ExampleT, MLTransformOutputT]): def __init__( @@ -230,7 +227,7 @@ def __init__( *, write_artifact_location: Optional[str] = None, read_artifact_location: Optional[str] = None, - transforms: Optional[List[MLTransformProvider]] = None): + transforms: Optional[list[MLTransformProvider]] = None): """ MLTransform is a Beam PTransform that can be used to apply transformations to the data. MLTransform is used to wrap the @@ -304,12 +301,12 @@ def __init__( self._counter = Metrics.counter( MLTransform, f'BeamML_{self.__class__.__name__}') self._with_exception_handling = False - self._exception_handling_args: Dict[str, Any] = {} + self._exception_handling_args: dict[str, Any] = {} def expand( self, pcoll: beam.PCollection[ExampleT] ) -> Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]: """ This is the entrypoint for the MLTransform. 
This method will @@ -533,7 +530,7 @@ class _MLTransformToPTransformMapper: """ def __init__( self, - transforms: List[MLTransformProvider], + transforms: list[MLTransformProvider], artifact_location: str, artifact_mode: str = ArtifactMode.PRODUCE, pipeline_options: Optional[PipelineOptions] = None, @@ -595,7 +592,7 @@ class _EmbeddingHandler(ModelHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. _EmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -619,7 +616,7 @@ def load_model(self): def _validate_column_data(self, batch): pass - def _validate_batch(self, batch: Sequence[Dict[str, Any]]): + def _validate_batch(self, batch: Sequence[dict[str, Any]]): if not batch or not isinstance(batch[0], dict): raise TypeError( 'Expected data to be dicts, got ' @@ -627,10 +624,10 @@ def _validate_batch(self, batch: Sequence[Dict[str, Any]]): def _process_batch( self, - dict_batch: Dict[str, List[Any]], + dict_batch: dict[str, list[Any]], model: ModelT, - inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]: - result: Dict[str, List[Any]] = collections.defaultdict(list) + inference_args: Optional[dict[str, Any]]) -> dict[str, list[Any]]: + result: dict[str, list[Any]] = collections.defaultdict(list) input_keys = dict_batch.keys() missing_columns_in_data = set(self.columns) - set(input_keys) if missing_columns_in_data: @@ -653,10 +650,10 @@ def _process_batch( def run_inference( self, - batch: Sequence[Dict[str, List[str]]], + batch: Sequence[dict[str, list[str]]], model: ModelT, - inference_args: Optional[Dict[str, Any]] = None, - ) -> List[Dict[str, Union[List[float], List[str]]]]: + inference_args: Optional[dict[str, Any]] = None, + ) -> list[dict[str, 
Union[list[float], list[str]]]]: """ Runs inference on a batch of text inputs. The inputs are expected to be a list of dicts. Each dict should have the same keys, and the shape @@ -696,7 +693,7 @@ class _TextEmbeddingHandler(_EmbeddingHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. _TextEmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -713,8 +710,8 @@ class _TextEmbeddingHandler(_EmbeddingHandler): def _validate_column_data(self, batch): if not isinstance(batch[0], (str, bytes)): raise TypeError( - 'Embeddings can only be generated on Dict[str, str].' - f'Got Dict[str, {type(batch[0])}] instead.') + 'Embeddings can only be generated on dict[str, str].' + f'Got dict[str, {type(batch[0])}] instead.') def get_metrics_namespace(self) -> str: return ( @@ -730,7 +727,7 @@ class _ImageEmbeddingHandler(_EmbeddingHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. _ImageEmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -750,8 +747,8 @@ def _validate_column_data(self, batch): # here, so just catch columns of primatives for now. if isinstance(batch[0], (int, str, float, bool)): raise TypeError( - 'Embeddings can only be generated on Dict[str, Image].' - f'Got Dict[str, {type(batch[0])}] instead.') + 'Embeddings can only be generated on dict[str, Image].' 
+ f'Got dict[str, {type(batch[0])}] instead.') def get_metrics_namespace(self) -> str: return ( diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 743c3683ce8e..3db5a63b9542 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -17,15 +17,14 @@ # pytype: skip-file import os +import secrets import shutil import tempfile -import typing +import time import unittest +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import List from typing import Optional -from typing import Sequence import numpy as np from parameterized import param @@ -140,8 +139,8 @@ def test_ml_transform_on_list_dict(self): 'x': int, 'y': float }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), param( @@ -153,8 +152,8 @@ def test_ml_transform_on_list_dict(self): 'x': np.int32, 'y': np.float32 }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), param( @@ -162,11 +161,11 @@ def test_ml_transform_on_list_dict(self): 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] }], input_types={ - 'x': List[int], 'y': List[float] + 'x': list[int], 'y': list[float] }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), param( @@ -174,12 +173,12 @@ def test_ml_transform_on_list_dict(self): 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] }], input_types={ - 'x': typing.Sequence[int], - 'y': typing.Sequence[float], + 'x': Sequence[int], + 'y': Sequence[float], }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), ]) @@ -320,7 +319,7 @@ def 
test_read_mode_with_transforms(self): class FakeModel: - def __call__(self, example: List[str]) -> List[str]: + def __call__(self, example: list[str]) -> list[str]: for i in range(len(example)): if not isinstance(example[i], str): raise TypeError('Input must be a string') @@ -333,7 +332,7 @@ def run_inference( self, batch: Sequence[str], model: Any, - inference_args: Optional[Dict[str, Any]] = None): + inference_args: Optional[dict[str, Any]] = None): return model(batch) def load_model(self): @@ -345,7 +344,7 @@ def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) def get_model_handler(self) -> ModelHandler: - FakeModelHandler.__repr__ = lambda x: 'FakeEmbeddingsManager' # type: ignore[assignment] + FakeModelHandler.__repr__ = lambda x: 'FakeEmbeddingsManager' # type: ignore[method-assign] return FakeModelHandler() def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: @@ -508,7 +507,7 @@ def test_handler_with_inconsistent_keys(self): class FakeImageModel: - def __call__(self, example: List[PIL_Image]) -> List[PIL_Image]: + def __call__(self, example: list[PIL_Image]) -> list[PIL_Image]: for i in range(len(example)): if not isinstance(example[i], PIL_Image): raise TypeError('Input must be an Image') @@ -520,7 +519,7 @@ def run_inference( self, batch: Sequence[PIL_Image], model: Any, - inference_args: Optional[Dict[str, Any]] = None): + inference_args: Optional[dict[str, Any]] = None): return model(batch) def load_model(self): @@ -532,7 +531,7 @@ def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) def get_model_handler(self) -> ModelHandler: - FakeModelHandler.__repr__ = lambda x: 'FakeImageEmbeddingsManager' # type: ignore[assignment] + FakeModelHandler.__repr__ = lambda x: 'FakeImageEmbeddingsManager' # type: ignore[method-assign] return FakeImageModelHandler() def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: @@ -701,7 +700,7 @@ def 
test_mltransform_to_ptransform_wrapper(self): @unittest.skipIf(apiclient is None, 'apache_beam[gcp] is not installed.') def test_with_gcs_location_with_none_options(self): - path = 'gs://fake_path' + path = f"gs://fake_path_{secrets.token_hex(3)}_{int(time.time())}" with self.assertRaises(RuntimeError): self.attribute_manager.save_attributes( ptransform_list=[], artifact_location=path, options=None) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py index 46b4ef9cf7d6..2162ed050c42 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -18,13 +18,11 @@ import logging import os +from collections.abc import Callable +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping from typing import Optional -from typing import Sequence import requests @@ -80,7 +78,7 @@ def run_inference( self, batch: Sequence[str], model: SentenceTransformer, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ): inference_args = inference_args or {} return model.encode(batch, **inference_args) @@ -113,7 +111,7 @@ class SentenceTransformerEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], max_seq_length: Optional[int] = None, image_model: bool = False, **kwargs): @@ -216,7 +214,7 @@ class InferenceAPIEmbeddings(EmbeddingsManager): def __init__( self, hf_token: Optional[str], - columns: List[str], + columns: list[str], model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long api_url: Optional[str] = None, **kwargs, diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py 
b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py index f78ddf3ff04a..c14904df7c2c 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -14,8 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable -from typing import List +from collections.abc import Iterable from typing import Optional import apache_beam as beam @@ -90,7 +89,7 @@ def run_inference(self, batch, model, inference_args, model_id=None): class TensorflowHubTextEmbeddings(EmbeddingsManager): def __init__( self, - columns: List[str], + columns: list[str], hub_url: str, preprocessing_url: Optional[str] = None, **kwargs): @@ -136,7 +135,7 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: class TensorflowHubImageEmbeddings(EmbeddingsManager): - def __init__(self, columns: List[str], hub_url: str, **kwargs): + def __init__(self, columns: list[str], hub_url: str, **kwargs): """ Embedding config for tensorflow hub models. This config can be used with MLTransform to embed image data. Models are loaded using the RunInference diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index fbefeec231f1..6df505508ae9 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -19,22 +19,27 @@ # Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long # to install Vertex AI Python SDK. 
+import logging +import time +from collections.abc import Iterable +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional -from typing import Sequence +from google.api_core.exceptions import ServerError +from google.api_core.exceptions import TooManyRequests from google.auth.credentials import Credentials import apache_beam as beam import vertexai +from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler +from apache_beam.metrics.metric import Metrics from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _ImageEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.utils import retry from vertexai.language_models import TextEmbeddingInput from vertexai.language_models import TextEmbeddingModel from vertexai.vision_models import Image @@ -53,6 +58,26 @@ "CLUSTERING" ] _BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time. +_MSEC_TO_SEC = 1000 + +LOGGER = logging.getLogger("VertexAIEmbeddings") + + +def _retry_on_appropriate_gcp_error(exception): + """ + Retry filter that returns True if a returned HTTP error code is 5xx or 429. + This is used to retry remote requests that fail, most notably 429 + (TooManyRequests.) + + Args: + exception: the returned exception encountered during the request/response + loop. + + Returns: + boolean indication whether or not the exception is a Server Error (5xx) or + a TooManyRequests (429) error. + """ + return isinstance(exception, (TooManyRequests, ServerError)) class _VertexAITextEmbeddingHandler(ModelHandler): @@ -76,11 +101,46 @@ def __init__( self.task_type = task_type self.title = title + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. 
+ # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing + # for more details. + self.throttled_secs = Metrics.counter( + VertexAIImageEmbeddings, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=1, bucket_ms=1, overload_ratio=2) + + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + text_batch: Sequence[TextEmbeddingInput], + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(text_batch) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[str], model: Any, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable: embeddings = [] batch_size = _BATCH_SIZE @@ -91,7 +151,8 @@ def run_inference( text=text, title=self.title, task_type=self.task_type) for text in text_batch ] - embeddings_batch = model.get_embeddings(text_batch) + embeddings_batch = self.get_request( + text_batch=text_batch, model=model, throttle_delay_secs=5) embeddings.extend([el.values for el in embeddings_batch]) return embeddings @@ -110,7 +171,7 @@ class VertexAITextEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], title: Optional[str] = None, task_type: str = DEFAULT_TASK_TYPE, project: Optional[str] = None, @@ 
-175,17 +236,51 @@ def __init__( self.model_name = model_name self.dimension = dimension + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. + # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing + # for more details. + self.throttled_secs = Metrics.counter( + VertexAIImageEmbeddings, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=1, bucket_ms=1, overload_ratio=2) + + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + img: Image, + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(image=img, dimension=self.dimension) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[Image], model: MultiModalEmbeddingModel, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable: embeddings = [] # Maximum request size for muli-model embedding models is 1. 
for img in batch: - embedding_response = model.get_embeddings( - image=img, dimension=self.dimension) + embedding_response = self.get_request(img, model, throttle_delay_secs=5) embeddings.append(embedding_response.image_embedding) return embeddings @@ -204,7 +299,7 @@ class VertexAIImageEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], dimension: Optional[int], project: Optional[str] = None, location: Optional[str] = None, diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py index 7a912f2d88ea..1e752049f6e5 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers.py +++ b/sdks/python/apache_beam/ml/transforms/handlers.py @@ -20,11 +20,10 @@ import copy import os import typing +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import List +from typing import NamedTuple from typing import Optional -from typing import Sequence from typing import Union import numpy as np @@ -71,18 +70,18 @@ np.str_: tf.string, } _primitive_types_to_typing_container_type = { - int: List[int], float: List[float], str: List[str], bytes: List[bytes] + int: list[int], float: list[float], str: list[str], bytes: list[bytes] } -tft_process_handler_input_type = typing.Union[typing.NamedTuple, - beam.Row, - Dict[str, - typing.Union[str, - float, - int, - bytes, - np.ndarray]]] -tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]] +tft_process_handler_input_type = Union[NamedTuple, + beam.Row, + dict[str, + Union[str, + float, + int, + bytes, + np.ndarray]]] +tft_process_handler_output_type = Union[beam.Row, dict[str, np.ndarray]] class _DataCoder: @@ -131,15 +130,15 @@ def process( class _ConvertNamedTupleToDict( - beam.PTransform[beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]], - beam.PCollection[Dict[str, + beam.PTransform[beam.PCollection[Union[beam.Row, NamedTuple]], + 
beam.PCollection[dict[str, common_types.InstanceDictType]]]): """ A PTransform that converts a collection of NamedTuples or Rows into a collection of dictionaries. """ def expand( - self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]] + self, pcoll: beam.PCollection[Union[beam.Row, NamedTuple]] ) -> beam.PCollection[common_types.InstanceDictType]: """ Args: @@ -163,7 +162,7 @@ def __init__( operations. """ self.transforms = transforms if transforms else [] - self.transformed_schema: Dict[str, type] = {} + self.transformed_schema: dict[str, type] = {} self.artifact_location = artifact_location self.artifact_mode = artifact_mode if artifact_mode not in ['produce', 'consume']: @@ -217,7 +216,7 @@ def _map_column_names_to_types_from_transforms(self): return column_type_mapping def get_raw_data_feature_spec( - self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]: + self, input_types: dict[str, type]) -> dict[str, tf.io.VarLenFeature]: """ Return a DatasetMetadata object to be used with tft_beam.AnalyzeAndTransformDataset. @@ -265,7 +264,7 @@ def _get_raw_data_feature_spec_per_column( f"Union type is not supported for column: {col_name}. " f"Please pass a PCollection with valid schema for column " f"{col_name} by passing a single type " - "in container. For example, List[int].") + "in container. 
For example, list[int].") elif issubclass(typ, np.generic) or typ in _default_type_to_tensor_type_map: dtype = typ else: @@ -276,7 +275,7 @@ def _get_raw_data_feature_spec_per_column( return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) def get_raw_data_metadata( - self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + self, input_types: dict[str, type]) -> dataset_metadata.DatasetMetadata: raw_data_feature_spec = self.get_raw_data_feature_spec(input_types) raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string) return self.convert_raw_data_feature_spec_to_dataset_metadata( @@ -305,8 +304,8 @@ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection): "to convert your PCollection to GlobalWindow.") def process_data_fn( - self, inputs: Dict[str, common_types.ConsistentTensorType] - ) -> Dict[str, common_types.ConsistentTensorType]: + self, inputs: dict[str, common_types.ConsistentTensorType] + ) -> dict[str, common_types.ConsistentTensorType]: """ This method is used in the AnalyzeAndTransformDataset step. 
It applies the transforms to the `inputs` in sequential order on the columns @@ -335,11 +334,11 @@ def _get_transformed_data_schema( name = feature.name feature_type = feature.type if feature_type == schema_pb2.FeatureType.FLOAT: - transformed_types[name] = typing.Sequence[np.float32] + transformed_types[name] = Sequence[np.float32] elif feature_type == schema_pb2.FeatureType.INT: - transformed_types[name] = typing.Sequence[np.int64] # type: ignore[assignment] + transformed_types[name] = Sequence[np.int64] # type: ignore[assignment] else: - transformed_types[name] = typing.Sequence[bytes] # type: ignore[assignment] + transformed_types[name] = Sequence[bytes] # type: ignore[assignment] return transformed_types def expand( @@ -372,7 +371,7 @@ def expand( raw_data = ( raw_data | _ConvertNamedTupleToDict().with_output_types( - Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore + dict[str, Union[tuple(column_type_mapping.values())]])) # type: ignore # AnalyzeAndTransformDataset raise type hint since this is # schema'd PCollection and the current output type would be a # custom type(NamedTuple) or a beam.Row type. 
@@ -408,7 +407,7 @@ def expand( raw_data = ( raw_data | _ConvertNamedTupleToDict().with_output_types( - Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore + dict[str, Union[tuple(column_type_mapping.values())]])) # type: ignore feature_set = [feature.name for feature in raw_data_metadata.schema.feature] diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index 1331f1308087..bb5f9b5f0f70 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -20,10 +20,9 @@ import shutil import sys import tempfile -import typing import unittest import uuid -from typing import List +from collections.abc import Sequence from typing import NamedTuple from typing import Union @@ -65,7 +64,7 @@ class IntType(NamedTuple): class ListIntType(NamedTuple): - x: List[int] + x: list[int] class NumpyType(NamedTuple): @@ -111,7 +110,7 @@ def test_input_type_from_schema_named_tuple_pcoll(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) @@ -126,7 +125,7 @@ def test_input_type_from_schema_named_tuple_pcoll_list(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_row_type_pcoll(self): @@ -140,7 +139,7 @@ def test_input_type_from_row_type_pcoll(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) 
self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_row_type_pcoll_list(self): @@ -149,14 +148,14 @@ def test_input_type_from_row_type_pcoll_list(self): data = ( p | beam.Create(data) | beam.Map(lambda ele: beam.Row(x=list(ele['x']))).with_output_types( - beam.row_type.RowTypeConstraint.from_fields([('x', List[int])]))) + beam.row_type.RowTypeConstraint.from_fields([('x', list[int])]))) element_type = data.element_type process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_named_tuple_pcoll_numpy(self): @@ -190,8 +189,8 @@ def test_tensorflow_raw_data_metadata_primitive_types(self): self.assertIsInstance(feature_spec, tf.io.VarLenFeature) def test_tensorflow_raw_data_metadata_primitive_types_in_containers(self): - input_types = dict([("x", List[int]), ("y", List[float]), - ("k", List[bytes]), ("l", List[str])]) + input_types = dict([("x", list[int]), ("y", list[float]), + ("k", list[bytes]), ("l", list[str])]) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) for col_name, typ in input_types.items(): @@ -211,7 +210,7 @@ def test_tensorflow_raw_data_metadata_primitive_native_container_types(self): self.assertIsInstance(feature_spec, tf.io.VarLenFeature) def test_tensorflow_raw_data_metadata_numpy_types(self): - input_types = dict(x=np.int64, y=np.float32, z=List[np.int64]) + input_types = dict(x=np.int64, y=np.float32, z=list[np.int64]) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) for col_name, typ in input_types.items(): @@ -277,9 +276,9 @@ def test_tft_process_handler_transformed_data_schema(self): schema_utils.schema_from_feature_spec(raw_data_feature_spec)) 
expected_transformed_data_schema = { - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], - 'z': typing.Sequence[bytes] + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], + 'z': Sequence[bytes] } actual_transformed_data_schema = ( diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index 6903cca89419..bfe23757642b 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -34,12 +34,9 @@ # pytype: skip-file import logging +from collections.abc import Iterable from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional -from typing import Tuple from typing import Union import apache_beam as beam @@ -67,7 +64,7 @@ # Register the expected input types for each operation # this will be used to determine schema for the tft.AnalyzeDataset -_EXPECTED_TYPES: Dict[str, Union[int, str, float]] = {} +_EXPECTED_TYPES: dict[str, Union[int, str, float]] = {} _LOGGER = logging.getLogger(__name__) @@ -84,7 +81,7 @@ def wrapper(fn): # Add support for outputting artifacts to a text file in human readable form. class TFTOperation(BaseOperation[common_types.TensorType, common_types.TensorType]): - def __init__(self, columns: List[str]) -> None: + def __init__(self, columns: list[str]) -> None: """ Base Operation class for TFT data processing transformations. 
Processing logic for the transformation is defined in the @@ -150,7 +147,7 @@ def _split_string_with_delimiter(self, data, delimiter): class ComputeAndApplyVocabulary(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, default_value: Any = -1, @@ -193,7 +190,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: if self.split_string_by_delimiter: data = self._split_string_with_delimiter( @@ -218,7 +215,7 @@ def apply_transform( class ScaleToZScore(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], *, elementwise: bool = False, name: Optional[str] = None): @@ -247,7 +244,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_column_name: tft.scale_to_z_score( x=data, elementwise=self.elementwise, name=self.name) @@ -259,7 +256,7 @@ def apply_transform( class ScaleTo01(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], elementwise: bool = False, name: Optional[str] = None): """ @@ -287,7 +284,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = tft.scale_to_0_1( x=data, elementwise=self.elementwise, name=self.name) @@ -299,7 +296,7 @@ def apply_transform( class ScaleToGaussian(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], elementwise: bool = False, name: Optional[str] = None): """ @@ -324,7 +321,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, 
common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_column_name: tft.scale_to_gaussian( x=data, elementwise=self.elementwise, name=self.name) @@ -336,7 +333,7 @@ def apply_transform( class ApplyBuckets(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], bucket_boundaries: Iterable[Union[int, float]], name: Optional[str] = None): """ @@ -359,7 +356,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.apply_buckets( x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) @@ -371,7 +368,7 @@ def apply_transform( class ApplyBucketsWithInterpolation(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], bucket_boundaries: Iterable[Union[int, float]], name: Optional[str] = None): """Interpolates values within the provided buckets and then normalizes to @@ -398,7 +395,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.apply_buckets_with_interpolation( x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) @@ -410,7 +407,7 @@ def apply_transform( class Bucketize(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], num_buckets: int, *, epsilon: Optional[float] = None, @@ -443,7 +440,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.bucketize( x=data, @@ -459,7 +456,7 @@ def apply_transform( class TFIDF(TFTOperation): def __init__( self, - columns: 
List[str], + columns: list[str], vocab_size: Optional[int] = None, smooth: bool = True, name: Optional[str] = None, @@ -530,7 +527,7 @@ def apply_transform( class ScaleByMinMax(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], min_value: float = 0.0, max_value: float = 1.0, name: Optional[str] = None): @@ -566,10 +563,10 @@ def apply_transform( class NGrams(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, - ngram_range: Tuple[int, int] = (1, 1), + ngram_range: tuple[int, int] = (1, 1), ngrams_separator: Optional[str] = None, name: Optional[str] = None): """ @@ -599,7 +596,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: if self.split_string_by_delimiter: data = self._split_string_with_delimiter( data, self.split_string_by_delimiter) @@ -611,10 +608,10 @@ def apply_transform( class BagOfWords(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, - ngram_range: Tuple[int, int] = (1, 1), + ngram_range: tuple[int, int] = (1, 1), ngrams_separator: Optional[str] = None, compute_word_count: bool = False, key_vocab_filename: Optional[str] = None, @@ -686,9 +683,9 @@ def count_unique_words( class HashStrings(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], hash_buckets: int, - key: Optional[Tuple[int, int]] = None, + key: Optional[tuple[int, int]] = None, name: Optional[str] = None): '''Hashes strings into the provided number of buckets. 
@@ -715,7 +712,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_col_name: str) -> Dict[str, common_types.TensorType]: + output_col_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_col_name: tft.hash_strings( strings=data, @@ -728,7 +725,7 @@ def apply_transform( @register_input_dtype(str) class DeduplicateTensorPerRow(TFTOperation): - def __init__(self, columns: List[str], name: Optional[str] = None): + def __init__(self, columns: list[str], name: Optional[str] = None): """ Deduplicates each row (0th dimension) of the provided tensor. Args: @@ -740,7 +737,7 @@ def __init__(self, columns: List[str], name: Optional[str] = None): def apply_transform( self, data: common_types.TensorType, - output_col_name: str) -> Dict[str, common_types.TensorType]: + output_col_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_col_name: tft.deduplicate_tensor_per_row( input_tensor=data, name=self.name) diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py index abf4c48fc642..023657895686 100644 --- a/sdks/python/apache_beam/ml/transforms/utils.py +++ b/sdks/python/apache_beam/ml/transforms/utils.py @@ -19,7 +19,6 @@ import os import tempfile -import typing from google.cloud.storage import Client from google.cloud.storage import transfer_manager @@ -72,7 +71,7 @@ def __init__(self, artifact_location: str): self._artifact_location = os.path.join(artifact_location, files[0]) self.transform_output = tft.TFTransformOutput(self._artifact_location) - def get_vocab_list(self, vocab_filename: str) -> typing.List[bytes]: + def get_vocab_list(self, vocab_filename: str) -> list[bytes]: """ Returns list of vocabulary terms created during MLTransform. 
""" diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 4497ab0993a4..8eba11d9ea34 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -267,8 +267,7 @@ def __getstate__(self): return self.__dict__ @classmethod - def _add_argparse_args(cls, parser): - # type: (_BeamArgumentParser) -> None + def _add_argparse_args(cls, parser: _BeamArgumentParser) -> None: # Override this in subclasses to provide options. pass @@ -317,11 +316,8 @@ def from_dictionary(cls, options): def get_all_options( self, drop_default=False, - add_extra_args_fn=None, # type: Optional[Callable[[_BeamArgumentParser], None]] - retain_unknown_options=False - ): - # type: (...) -> Dict[str, Any] - + add_extra_args_fn: Optional[Callable[[_BeamArgumentParser], None]] = None, + retain_unknown_options=False) -> Dict[str, Any]: """Returns a dictionary of all defined arguments. Returns a dictionary of all defined arguments (arguments that are defined in @@ -446,9 +442,7 @@ def from_urn(key): def display_data(self): return self.get_all_options(drop_default=True, retain_unknown_options=True) - def view_as(self, cls): - # type: (Type[PipelineOptionsT]) -> PipelineOptionsT - + def view_as(self, cls: Type[PipelineOptionsT]) -> PipelineOptionsT: """Returns a view of current object as provided PipelineOption subclass. 
Example Usage:: @@ -487,13 +481,11 @@ def view_as(self, cls): view._all_options = self._all_options return view - def _visible_option_list(self): - # type: () -> List[str] + def _visible_option_list(self) -> List[str]: return sorted( option for option in dir(self._visible_options) if option[0] != '_') - def __dir__(self): - # type: () -> List[str] + def __dir__(self) -> List[str]: return sorted( dir(type(self)) + list(self.__dict__) + self._visible_option_list()) @@ -643,9 +635,9 @@ def additional_option_ptransform_fn(): # Optional type checks that aren't enabled by default. -additional_type_checks = { +additional_type_checks: Dict[str, Callable[[], None]] = { 'ptransform_fn': additional_option_ptransform_fn, -} # type: Dict[str, Callable[[], None]] +} def enable_all_additional_type_checks(): @@ -1674,12 +1666,16 @@ def _add_argparse_args(cls, parser): action='append', default=[], help='JVM properties to pass to a Java job server.') + parser.add_argument( + '--jar_cache_dir', + default=None, + help='The location to store jar cache for job server.') class FlinkRunnerOptions(PipelineOptions): # These should stay in sync with gradle.properties. - PUBLISHED_FLINK_VERSIONS = ['1.15', '1.16', '1.17', '1.18'] + PUBLISHED_FLINK_VERSIONS = ['1.17', '1.18', '1.19'] @classmethod def _add_argparse_args(cls, parser): @@ -1836,7 +1832,7 @@ class OptionsContext(object): Can also be used as a decorator. 
""" - overrides = [] # type: List[Dict[str, Any]] + overrides: List[Dict[str, Any]] = [] def __init__(self, **options): self.options = options diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index c0616bc6451c..66acfe654791 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -31,6 +31,7 @@ from apache_beam.options.pipeline_options import CrossLanguageOptions from apache_beam.options.pipeline_options import DebugOptions from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.options.pipeline_options import JobServerOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import ProfilingOptions from apache_beam.options.pipeline_options import TypeOptions @@ -645,6 +646,11 @@ def test_transform_name_mapping(self): mapping = options.view_as(GoogleCloudOptions).transform_name_mapping self.assertEqual(mapping['from'], 'to') + def test_jar_cache_dir(self): + options = PipelineOptions(['--jar_cache_dir=/path/to/jar_cache_dir']) + jar_cache_dir = options.view_as(JobServerOptions).jar_cache_dir + self.assertEqual(jar_cache_dir, '/path/to/jar_cache_dir') + def test_dataflow_service_options(self): options = PipelineOptions([ '--dataflow_service_option', diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 61aac350280f..1c11f953c58d 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -1053,7 +1053,7 @@ def expand(self, p): self.p = p return p | beam.Create([None]) - def display_data(self): # type: () -> dict + def display_data(self) -> dict: parent_dd = super().display_data() parent_dd['p_dd_string'] = DisplayDataItem( 'p_dd_string_value', label='p_dd_string_label') @@ -1067,7 +1067,7 @@ def expand(self, p): self.p = p return p 
| beam.Create([None]) - def display_data(self): # type: () -> dict + def display_data(self) -> dict: parent_dd = super().display_data() parent_dd['dd_string'] = DisplayDataItem( 'dd_string_value', label='dd_string_label') @@ -1183,7 +1183,7 @@ class UseMaxValueHint(ResourceHint): @classmethod def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) ResourceHint.register_resource_hint('foo_hint', FooHint) @@ -1312,7 +1312,7 @@ class UseMaxValueHint(ResourceHint): @classmethod def get_merged_value( - cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes + cls, outer_value: bytes, inner_value: bytes) -> bytes: return ResourceHint._use_max(outer_value, inner_value) ResourceHint.register_resource_hint('foo_hint', FooHint) diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py index 4effc91c3d40..74d9a39bb052 100644 --- a/sdks/python/apache_beam/portability/common_urns.py +++ b/sdks/python/apache_beam/portability/common_urns.py @@ -38,6 +38,7 @@ StandardSideInputTypes = beam_runner_api_pb2_urns.StandardSideInputTypes StandardUserStateTypes = beam_runner_api_pb2_urns.StandardUserStateTypes ExpansionMethods = external_transforms_pb2_urns.ExpansionMethods +ManagedTransforms = external_transforms_pb2_urns.ManagedTransforms MonitoringInfo = metrics_pb2_urns.MonitoringInfo MonitoringInfoSpecs = metrics_pb2_urns.MonitoringInfoSpecs MonitoringInfoTypeUrns = metrics_pb2_urns.MonitoringInfoTypeUrns diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 8a4f26c18e88..c43870d55ebb 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -1504,8 +1504,7 @@ def process(self, windowed_value): return [] def _maybe_sample_exception( - self, exn: BaseException, - windowed_value: 
Optional[WindowedValue]) -> None: + self, exc_info: Tuple, windowed_value: Optional[WindowedValue]) -> None: if self.execution_context is None: return @@ -1516,7 +1515,7 @@ def _maybe_sample_exception( output_sampler.sample_exception( windowed_value, - exn, + exc_info, self.transform_id, self.execution_context.instruction_id) diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py index 109c4379b45d..0f2317074cea 100644 --- a/sdks/python/apache_beam/runners/dask/dask_runner.py +++ b/sdks/python/apache_beam/runners/dask/dask_runner.py @@ -31,12 +31,22 @@ from apache_beam.pipeline import PipelineVisitor from apache_beam.runners.dask.overrides import dask_overrides from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS +from apache_beam.runners.dask.transform_evaluator import DaskBagWindowedIterator +from apache_beam.runners.dask.transform_evaluator import Flatten from apache_beam.runners.dask.transform_evaluator import NoOp from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineState +from apache_beam.transforms.sideinputs import SideInputMap from apache_beam.utils.interactive_utils import is_in_notebook +try: + # Added to try to prevent threading related issues, see + # https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456 + import dask.distributed as ddist +except ImportError: + ddist = {} + class DaskOptions(PipelineOptions): @staticmethod @@ -86,10 +96,9 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None: @dataclasses.dataclass class DaskRunnerResult(PipelineResult): - from dask import distributed - client: distributed.Client - futures: t.Sequence[distributed.Future] + client: ddist.Client + futures: t.Sequence[ddist.Future] def __post_init__(self): super().__init__(PipelineState.RUNNING) @@ -99,8 +108,16 @@ def 
wait_until_finish(self, duration=None) -> str: if duration is not None: # Convert milliseconds to seconds duration /= 1000 - self.client.wait_for_workers(timeout=duration) - self.client.gather(self.futures, errors='raise') + for _ in ddist.as_completed(self.futures, + timeout=duration, + with_results=True): + # without gathering results, worker errors are not raised on the client: + # https://distributed.dask.org/en/stable/resilience.html#user-code-failures + # so we want to gather results to raise errors client-side, but we do + # not actually need to use the results here, so we just pass. to gather, + # we use the iterative `as_completed(..., with_results=True)`, instead + # of aggregate `client.gather`, to minimize memory footprint of results. + pass self._state = PipelineState.DONE except: # pylint: disable=broad-except self._state = PipelineState.FAILED @@ -133,6 +150,7 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None: op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp) op = op_class(transform_node) + op_kws = {"input_bag": None, "side_inputs": None} inputs = list(transform_node.inputs) if inputs: bag_inputs = [] @@ -144,13 +162,28 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None: if prev_op in self.bags: bag_inputs.append(self.bags[prev_op]) - if len(bag_inputs) == 1: - self.bags[transform_node] = op.apply(bag_inputs[0]) + # Input to `Flatten` could be of length 1, e.g. a single-element + # tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as + # an iterable, because `Flatten.apply` always takes an iterable. 
+ if len(bag_inputs) == 1 and not isinstance(op, Flatten): + op_kws["input_bag"] = bag_inputs[0] else: - self.bags[transform_node] = op.apply(bag_inputs) + op_kws["input_bag"] = bag_inputs + + side_inputs = list(transform_node.side_inputs) + if side_inputs: + bag_side_inputs = [] + for si in side_inputs: + si_asbag = self.bags.get(si.pvalue.producer) + bag_side_inputs.append( + SideInputMap( + type(si), + si._view_options(), + DaskBagWindowedIterator(si_asbag, si._window_mapping_fn))) + + op_kws["side_inputs"] = bag_side_inputs - else: - self.bags[transform_node] = op.apply(None) + self.bags[transform_node] = op.apply(**op_kws) return DaskBagVisitor() @@ -159,6 +192,8 @@ def is_fnapi_compatible(): return False def run_pipeline(self, pipeline, options): + import dask + # TODO(alxr): Create interactive notebook support. if is_in_notebook(): raise NotImplementedError('interactive support will come later!') @@ -177,6 +212,6 @@ def run_pipeline(self, pipeline, options): dask_visitor = self.to_dask_bag_visitor() pipeline.visit(dask_visitor) - - futures = client.compute(list(dask_visitor.bags.values())) + opt_graph = dask.optimize(*list(dask_visitor.bags.values())) + futures = client.compute(opt_graph) return DaskRunnerResult(client, futures) diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py index d8b3e17d8a56..66dda4a984f4 100644 --- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py +++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import datetime import inspect +import typing as t import unittest import apache_beam as beam @@ -22,12 +24,14 @@ from apache_beam.testing import test_pipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms import window try: - from apache_beam.runners.dask.dask_runner import DaskOptions - from apache_beam.runners.dask.dask_runner import DaskRunner import dask import dask.distributed as ddist + + from apache_beam.runners.dask.dask_runner import DaskOptions # pylint: disable=ungrouped-imports + from apache_beam.runners.dask.dask_runner import DaskRunner # pylint: disable=ungrouped-imports except (ImportError, ModuleNotFoundError): raise unittest.SkipTest('Dask must be installed to run tests.') @@ -73,6 +77,11 @@ def test_create(self): pcoll = p | beam.Create([1]) assert_that(pcoll, equal_to([1])) + def test_create_multiple(self): + with self.pipeline as p: + pcoll = p | beam.Create([1, 2, 3, 4]) + assert_that(pcoll, equal_to([1, 2, 3, 4])) + def test_create_and_map(self): def double(x): return x * 2 @@ -81,6 +90,22 @@ def double(x): pcoll = p | beam.Create([1]) | beam.Map(double) assert_that(pcoll, equal_to([2])) + def test_create_and_map_multiple(self): + def double(x): + return x * 2 + + with self.pipeline as p: + pcoll = p | beam.Create([1, 2]) | beam.Map(double) + assert_that(pcoll, equal_to([2, 4])) + + def test_create_and_map_many(self): + def double(x): + return x * 2 + + with self.pipeline as p: + pcoll = p | beam.Create(list(range(1, 11))) | beam.Map(double) + assert_that(pcoll, equal_to(list(range(2, 21, 2)))) + def test_create_map_and_groupby(self): def double(x): return x * 2, x @@ -89,6 +114,288 @@ def double(x): pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey() assert_that(pcoll, equal_to([(2, [1])])) + def test_create_map_and_groupby_multiple(self): + def double(x): + return x * 2, x + + with self.pipeline as p: + pcoll = ( + p + | beam.Create([1, 2, 1, 
2, 3]) + | beam.Map(double) + | beam.GroupByKey()) + assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])])) + + def test_map_with_positional_side_input(self): + def mult_by(x, y): + return x * y + + with self.pipeline as p: + side = p | "side" >> beam.Create([3]) + pcoll = ( + p + | "main" >> beam.Create([1]) + | beam.Map(mult_by, beam.pvalue.AsSingleton(side))) + assert_that(pcoll, equal_to([3])) + + def test_map_with_keyword_side_input(self): + def mult_by(x, y): + return x * y + + with self.pipeline as p: + side = p | "side" >> beam.Create([3]) + pcoll = ( + p + | "main" >> beam.Create([1]) + | beam.Map(mult_by, y=beam.pvalue.AsSingleton(side))) + assert_that(pcoll, equal_to([3])) + + def test_pardo_side_inputs(self): + def cross_product(elem, sides): + for side in sides: + yield elem, side + + with self.pipeline as p: + main = p | "main" >> beam.Create(["a", "b", "c"]) + side = p | "side" >> beam.Create(["x", "y"]) + assert_that( + main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)), + equal_to([ + ("a", "x"), + ("b", "x"), + ("c", "x"), + ("a", "y"), + ("b", "y"), + ("c", "y"), + ]), + ) + + def test_pardo_side_input_dependencies(self): + with self.pipeline as p: + inputs = [p | beam.Create([None])] + for k in range(1, 10): + inputs.append( + inputs[0] + | beam.ParDo( + ExpectingSideInputsFn(f"Do{k}"), + *[beam.pvalue.AsList(inputs[s]) for s in range(1, k)], + )) + + def test_pardo_side_input_sparse_dependencies(self): + with self.pipeline as p: + inputs = [] + + def choose_input(s): + return inputs[(389 + s * 5077) % len(inputs)] + + for k in range(20): + num_inputs = int((k * k % 16)**0.5) + if num_inputs == 0: + inputs.append(p | f"Create{k}" >> beam.Create([f"Create{k}"])) + else: + inputs.append( + choose_input(0) + | beam.ParDo( + ExpectingSideInputsFn(f"Do{k}"), + *[ + beam.pvalue.AsList(choose_input(s)) + for s in range(1, num_inputs) + ], + )) + + @unittest.expectedFailure + def test_pardo_windowed_side_inputs(self): + with 
self.pipeline as p: + # Now with some windowing. + pcoll = ( + p + | beam.Create(list(range(10))) + | beam.Map(lambda t: window.TimestampedValue(t, t))) + # Intentionally choosing non-aligned windows to highlight the transition. + main = pcoll | "WindowMain" >> beam.WindowInto(window.FixedWindows(5)) + side = pcoll | "WindowSide" >> beam.WindowInto(window.FixedWindows(7)) + res = main | beam.Map( + lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side)) + assert_that( + res, + equal_to([ + # The window [0, 5) maps to the window [0, 7). + (0, list(range(7))), + (1, list(range(7))), + (2, list(range(7))), + (3, list(range(7))), + (4, list(range(7))), + # The window [5, 10) maps to the window [7, 14). + (5, list(range(7, 10))), + (6, list(range(7, 10))), + (7, list(range(7, 10))), + (8, list(range(7, 10))), + (9, list(range(7, 10))), + ]), + label="windowed", + ) + + def test_flattened_side_input(self, with_transcoding=True): + with self.pipeline as p: + main = p | "main" >> beam.Create([None]) + side1 = p | "side1" >> beam.Create([("a", 1)]) + side2 = p | "side2" >> beam.Create([("b", 2)]) + if with_transcoding: + # Also test non-matching coder types (transcoding required) + third_element = [("another_type")] + else: + third_element = [("b", 3)] + side3 = p | "side3" >> beam.Create(third_element) + side = (side1, side2) | beam.Flatten() + assert_that( + main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)), + equal_to([(None, { + "a": 1, "b": 2 + })]), + label="CheckFlattenAsSideInput", + ) + assert_that( + (side, side3) | "FlattenAfter" >> beam.Flatten(), + equal_to([("a", 1), ("b", 2)] + third_element), + label="CheckFlattenOfSideInput", + ) + + def test_gbk_side_input(self): + with self.pipeline as p: + main = p | "main" >> beam.Create([None]) + side = p | "side" >> beam.Create([("a", 1)]) | beam.GroupByKey() + assert_that( + main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)), + equal_to([(None, { + "a": [1] + })]), + ) + + def 
test_multimap_side_input(self): + with self.pipeline as p: + main = p | "main" >> beam.Create(["a", "b"]) + side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)]) + assert_that( + main + | beam.Map( + lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)), + equal_to([("a", [1, 3]), ("b", [2])]), + ) + + def test_multimap_multiside_input(self): + # A test where two transforms in the same stage consume the same PCollection + # twice as side input. + with self.pipeline as p: + main = p | "main" >> beam.Create(["a", "b"]) + side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)]) + assert_that( + main + | "first map" >> beam.Map( + lambda k, + d, + l: (k, sorted(d[k]), sorted([e[1] for e in l])), + beam.pvalue.AsMultiMap(side), + beam.pvalue.AsList(side), + ) + | "second map" >> beam.Map( + lambda k, + d, + l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), + beam.pvalue.AsMultiMap(side), + beam.pvalue.AsList(side), + ), + equal_to([("a", [1, 3], [1, 2, 3]), ("b", [2], [1, 2, 3])]), + ) + + def test_multimap_side_input_type_coercion(self): + with self.pipeline as p: + main = p | "main" >> beam.Create(["a", "b"]) + # The type of this side-input is forced to Any (overriding type + # inference). Without type coercion to Tuple[Any, Any], the usage of this + # side-input in AsMultiMap() below should fail. 
+ side = p | "side" >> beam.Create([("a", 1), ("b", 2), + ("a", 3)]).with_output_types(t.Any) + assert_that( + main + | beam.Map( + lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)), + equal_to([("a", [1, 3]), ("b", [2])]), + ) + + def test_pardo_unfusable_side_inputs__one(self): + def cross_product(elem, sides): + for side in sides: + yield elem, side + + with self.pipeline as p: + pcoll = p | "Create1" >> beam.Create(["a", "b"]) + assert_that( + pcoll | + "FlatMap1" >> beam.FlatMap(cross_product, beam.pvalue.AsList(pcoll)), + equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]), + label="assert_that1", + ) + + def test_pardo_unfusable_side_inputs__two(self): + def cross_product(elem, sides): + for side in sides: + yield elem, side + + with self.pipeline as p: + pcoll = p | "Create2" >> beam.Create(["a", "b"]) + + derived = ((pcoll, ) + | beam.Flatten() + | beam.Map(lambda x: (x, x)) + | beam.GroupByKey() + | "Unkey" >> beam.Map(lambda kv: kv[0])) + assert_that( + pcoll | "FlatMap2" >> beam.FlatMap( + cross_product, beam.pvalue.AsList(derived)), + equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]), + label="assert_that2", + ) + + def test_groupby_with_fixed_windows(self): + def double(x): + return x * 2, x + + def add_timestamp(pair): + delta = datetime.timedelta(seconds=pair[1] * 60) + now = (datetime.datetime.now() + delta).timestamp() + return window.TimestampedValue(pair, now) + + with self.pipeline as p: + pcoll = ( + p + | beam.Create([1, 2, 1, 2, 3]) + | beam.Map(double) + | beam.WindowInto(window.FixedWindows(60)) + | beam.Map(add_timestamp) + | beam.GroupByKey()) + assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])])) + + def test_groupby_string_keys(self): + with self.pipeline as p: + pcoll = ( + p + | beam.Create([('a', 1), ('a', 2), ('b', 3), ('b', 4)]) + | beam.GroupByKey()) + assert_that(pcoll, equal_to([('a', [1, 2]), ('b', [3, 4])])) + + +class ExpectingSideInputsFn(beam.DoFn): + def __init__(self, name): + 
self._name = name + + def default_label(self): + return self._name + + def process(self, element, *side_inputs): + if not all(list(s) for s in side_inputs): + raise ValueError(f"Missing data in side input {side_inputs}") + yield self._name + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/dask/overrides.py b/sdks/python/apache_beam/runners/dask/overrides.py index d07c7cd518af..b952834f12d7 100644 --- a/sdks/python/apache_beam/runners/dask/overrides.py +++ b/sdks/python/apache_beam/runners/dask/overrides.py @@ -73,7 +73,6 @@ def infer_output_type(self, input_type): @typehints.with_input_types(t.Tuple[K, t.Iterable[V]]) @typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) class _GroupAlsoByWindow(beam.ParDo): - """Not used yet...""" def __init__(self, windowing): super().__init__(_GroupAlsoByWindowDoFn(windowing)) self.windowing = windowing @@ -86,12 +85,23 @@ def expand(self, input_or_inputs): @typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) class _GroupByKey(beam.PTransform): def expand(self, input_or_inputs): - return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly() + return ( + input_or_inputs + | "ReifyWindows" >> beam.ParDo(beam.GroupByKey.ReifyWindows()) + | "GroupByKey" >> _GroupByKeyOnly() + | "GroupByWindow" >> _GroupAlsoByWindow(input_or_inputs.windowing)) class _Flatten(beam.PTransform): def expand(self, input_or_inputs): - is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs) + if isinstance(input_or_inputs, beam.PCollection): + # NOTE(cisaacstern): I needed this to avoid + # `TypeError: 'PCollection' object is not iterable` + # being raised by `all(...)` call below for single-element flattens, i.e., + # `(pcoll, ) | beam.Flatten() | ...` + is_bounded = input_or_inputs.is_bounded + else: + is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs) return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded) diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py 
b/sdks/python/apache_beam/runners/dask/transform_evaluator.py index d4d58879b7fe..e3bd5fd87763 100644 --- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py @@ -26,19 +26,110 @@ import dataclasses import math import typing as t +from dataclasses import field import apache_beam import dask.bag as db +from apache_beam import DoFn +from apache_beam import TaggedOutput from apache_beam.pipeline import AppliedPTransform +from apache_beam.runners.common import DoFnContext +from apache_beam.runners.common import DoFnInvoker +from apache_beam.runners.common import DoFnSignature +from apache_beam.runners.common import Receiver +from apache_beam.runners.common import _OutputHandler from apache_beam.runners.dask.overrides import _Create from apache_beam.runners.dask.overrides import _Flatten from apache_beam.runners.dask.overrides import _GroupByKeyOnly +from apache_beam.transforms.sideinputs import SideInputMap +from apache_beam.transforms.window import GlobalWindow +from apache_beam.transforms.window import TimestampedValue +from apache_beam.transforms.window import WindowFn +from apache_beam.utils.windowed_value import WindowedValue +# Inputs to DaskOps. OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None] +OpSide = t.Optional[t.Sequence[SideInputMap]] + +# Value types for PCollections (possibly Windowed Values). 
+PCollVal = t.Union[WindowedValue, t.Any] + + +def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue: + """Wraps a value (item) inside a Window.""" + if isinstance(item, TaggedOutput): + item = item.value + + if isinstance(item, WindowedValue): + windowed_value = item + elif isinstance(item, TimestampedValue): + assign_context = WindowFn.AssignContext(item.timestamp, item.value) + windowed_value = WindowedValue( + item.value, item.timestamp, tuple(window_fn.assign(assign_context))) + else: + windowed_value = WindowedValue(item, 0, (GlobalWindow(), )) + + return windowed_value + + +def defenestrate(x): + """Extracts the underlying item from a Window.""" + if isinstance(x, WindowedValue): + return x.value + return x + + +@dataclasses.dataclass +class DaskBagWindowedIterator: + """Iterator for `apache_beam.transforms.sideinputs.SideInputMap`""" + + bag: db.Bag + window_fn: WindowFn + + def __iter__(self): + # FIXME(cisaacstern): list() is likely inefficient, since it presumably + # materializes the full result before iterating over it. doing this for + # now as a proof-of-concept. can we can generate results incrementally? + for result in list(self.bag): + yield get_windowed_value(result, self.window_fn) + + +@dataclasses.dataclass +class TaggingReceiver(Receiver): + """A Receiver that handles tagged `WindowValue`s.""" + tag: str + values: t.List[PCollVal] + + def receive(self, windowed_value: WindowedValue): + if self.tag: + output = TaggedOutput(self.tag, windowed_value) + else: + output = windowed_value + self.values.append(output) + + +@dataclasses.dataclass +class OneReceiver(dict): + """A Receiver that tags value via dictionary lookup key.""" + values: t.List[PCollVal] = field(default_factory=list) + + def __missing__(self, key): + if key not in self: + self[key] = TaggingReceiver(key, self.values) + return self[key] @dataclasses.dataclass class DaskBagOp(abc.ABC): + """Abstract Base Class for all Dask-supported Operations. 
+ + All DaskBagOps must support an `apply()` operation, which invokes the dask + bag upon the previous op's input. + + Attributes + applied: The underlying `AppliedPTransform` which holds the code for the + target operation. + """ applied: AppliedPTransform @property @@ -46,17 +137,19 @@ def transform(self): return self.applied.transform @abc.abstractmethod - def apply(self, input_bag: OpInput) -> db.Bag: + def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag: pass class NoOp(DaskBagOp): - def apply(self, input_bag: OpInput) -> db.Bag: + """An identity on a dask bag: returns the input as-is.""" + def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag: return input_bag class Create(DaskBagOp): - def apply(self, input_bag: OpInput) -> db.Bag: + """The beginning of a Beam pipeline; the input must be `None`.""" + def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag: assert input_bag is None, 'Create expects no input!' original_transform = t.cast(_Create, self.transform) items = original_transform.values @@ -66,42 +159,95 @@ def apply(self, input_bag: OpInput) -> db.Bag: 1, math.ceil(math.sqrt(len(items)) / math.sqrt(100)))) +def apply_dofn_to_bundle( + items, do_fn_invoker_args, do_fn_invoker_kwargs, tagged_receivers): + """Invokes a DoFn within a bundle, implemented as a Dask partition.""" + + do_fn_invoker = DoFnInvoker.create_invoker( + *do_fn_invoker_args, **do_fn_invoker_kwargs) + + do_fn_invoker.invoke_setup() + do_fn_invoker.invoke_start_bundle() + + for it in items: + do_fn_invoker.invoke_process(it) + + results = [v.value for v in tagged_receivers.values] + + do_fn_invoker.invoke_finish_bundle() + do_fn_invoker.invoke_teardown() + + return results + + class ParDo(DaskBagOp): - def apply(self, input_bag: db.Bag) -> db.Bag: - transform = t.cast(apache_beam.ParDo, self.transform) - return input_bag.map( - transform.fn.process, *transform.args, **transform.kwargs).flatten() + """Apply a pure function in 
an embarrassingly-parallel way. + This consumes a sequence of items and returns a sequence of items. + """ + def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag: + transform = t.cast(apache_beam.ParDo, self.transform) -class Map(DaskBagOp): - def apply(self, input_bag: db.Bag) -> db.Bag: - transform = t.cast(apache_beam.Map, self.transform) - return input_bag.map( - transform.fn.process, *transform.args, **transform.kwargs) + args, kwargs = transform.raw_side_inputs + args = list(args) + main_input = next(iter(self.applied.main_inputs.values())) + window_fn = main_input.windowing.windowfn if hasattr( + main_input, "windowing") else None + + tagged_receivers = OneReceiver() + + do_fn_invoker_args = [ + DoFnSignature(transform.fn), + _OutputHandler( + window_fn=window_fn, + main_receivers=tagged_receivers[None], + tagged_receivers=tagged_receivers, + per_element_output_counter=None, + output_batch_converter=None, + process_yields_batches=False, + process_batch_yields_elements=False), + ] + do_fn_invoker_kwargs = dict( + context=DoFnContext(transform.label, state=None), + side_inputs=side_inputs, + input_args=args, + input_kwargs=kwargs, + user_state_context=None, + bundle_finalizer_param=DoFn.BundleFinalizerParam(), + ) + + return input_bag.map(get_windowed_value, window_fn).map_partitions( + apply_dofn_to_bundle, + do_fn_invoker_args, + do_fn_invoker_kwargs, + tagged_receivers, + ) class GroupByKey(DaskBagOp): - def apply(self, input_bag: db.Bag) -> db.Bag: + """Group a PCollection into a mapping of keys to elements.""" + def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag: def key(item): return item[0] def value(item): k, v = item - return k, [elm[1] for elm in v] + return k, [defenestrate(elm[1]) for elm in v] return input_bag.groupby(key).map(value) class Flatten(DaskBagOp): - def apply(self, input_bag: OpInput) -> db.Bag: - assert type(input_bag) is list, 'Must take a sequence of bags!' 
+ """Produces a flattened bag from a collection of bags.""" + def apply( + self, input_bag: t.List[db.Bag], side_inputs: OpSide = None) -> db.Bag: + assert isinstance(input_bag, list), 'Must take a sequence of bags!' return db.concat(input_bag) TRANSLATIONS = { _Create: Create, apache_beam.ParDo: ParDo, - apache_beam.Map: Map, _GroupByKeyOnly: GroupByKey, _Flatten: Flatten, } diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index 20cae582f320..97996bd6cbb2 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -82,7 +82,7 @@ _LOGGER = logging.getLogger(__name__) -_PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW = ['3.8', '3.9', '3.10', '3.11', '3.12'] +_PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW = ['3.9', '3.10', '3.11', '3.12'] class Environment(object): diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 022136aae9a2..6587e619a500 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -1003,7 +1003,21 @@ def test_interpreter_version_check_passes_with_experiment(self): 'apache_beam.runners.dataflow.internal.apiclient.' 'beam_version.__version__', '2.2.0') - def test_interpreter_version_check_passes_py38(self): + def test_interpreter_version_check_fails_py38(self): + pipeline_options = PipelineOptions([]) + self.assertRaises( + Exception, + apiclient._verify_interpreter_version_is_supported, + pipeline_options) + + @mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', + (3, 9, 6)) + @mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.' 
+ 'beam_version.__version__', + '2.2.0') + def test_interpreter_version_check_passes_py39(self): pipeline_options = PipelineOptions([]) apiclient._verify_interpreter_version_is_supported(pipeline_options) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py index 239ee8c700a2..c0d20c3ec8f9 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py @@ -30,4 +30,4 @@ pass # pylint: enable=wrong-import-order, wrong-import-position -__path__ = pkgutil.extend_path(__path__, __name__) # type: ignore +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py index 044e144c65a0..ac575e82717e 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/names.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py @@ -34,6 +34,6 @@ # Unreleased sdks use container image tag specified below. # Update this tag whenever there is a change that # requires changes to SDK harness container or SDK harness launcher. 
-BEAM_DEV_SDK_CONTAINER_TAG = 'beam-master-20240918' +BEAM_DEV_SDK_CONTAINER_TAG = 'beam-master-20241118' DATAFLOW_CONTAINER_IMAGE_REPOSITORY = 'gcr.io/cloud-dataflow/v1beta3' diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py index f715ce3bf521..d20849d769af 100644 --- a/sdks/python/apache_beam/runners/direct/direct_metrics.py +++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py @@ -24,23 +24,84 @@ import threading from collections import defaultdict +from typing import Any +from typing import SupportsInt -from apache_beam.metrics.cells import CounterAggregator -from apache_beam.metrics.cells import DistributionAggregator -from apache_beam.metrics.cells import GaugeAggregator -from apache_beam.metrics.cells import StringSetAggregator +from apache_beam.metrics.cells import DistributionData +from apache_beam.metrics.cells import GaugeData +from apache_beam.metrics.cells import StringSetData from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.metric import MetricResults +class MetricAggregator(object): + """For internal use only; no backwards-compatibility guarantees. + + Base interface for aggregating metric data during pipeline execution.""" + def identity_element(self): + # type: () -> Any + + """Returns the identical element of an Aggregation. + + For the identity element, it must hold that + Aggregator.combine(any_element, identity_element) == any_element. + """ + raise NotImplementedError + + def combine(self, x, y): + # type: (Any, Any) -> Any + raise NotImplementedError + + def result(self, x): + # type: (Any) -> Any + raise NotImplementedError + + +class CounterAggregator(MetricAggregator): + """For internal use only; no backwards-compatibility guarantees. + + Aggregator for Counter metric data during pipeline execution. + + Values aggregated should be ``int`` objects. 
+ """ + @staticmethod + def identity_element(): + # type: () -> int + return 0 + + def combine(self, x, y): + # type: (SupportsInt, SupportsInt) -> int + return int(x) + int(y) + + def result(self, x): + # type: (SupportsInt) -> int + return int(x) + + +class GenericAggregator(MetricAggregator): + def __init__(self, data_class): + self._data_class = data_class + + def identity_element(self): + return self._data_class.identity_element() + + def combine(self, x, y): + return x.combine(y) + + def result(self, x): + return x.get_result() + + class DirectMetrics(MetricResults): def __init__(self): self._counters = defaultdict(lambda: DirectMetric(CounterAggregator())) self._distributions = defaultdict( - lambda: DirectMetric(DistributionAggregator())) - self._gauges = defaultdict(lambda: DirectMetric(GaugeAggregator())) - self._string_sets = defaultdict(lambda: DirectMetric(StringSetAggregator())) + lambda: DirectMetric(GenericAggregator(DistributionData))) + self._gauges = defaultdict( + lambda: DirectMetric(GenericAggregator(GaugeData))) + self._string_sets = defaultdict( + lambda: DirectMetric(GenericAggregator(StringSetData))) def _apply_operation(self, bundle, updates, op): for k, v in updates.counters.items(): diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 49b6622816ce..8b8937653688 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -110,6 +110,38 @@ def visit_transform(self, applied_ptransform): if timer.time_domain == TimeDomain.REAL_TIME: self.supported_by_fnapi_runner = False + class _PrismRunnerSupportVisitor(PipelineVisitor): + """Visitor determining if a Pipeline can be run on the PrismRunner.""" + def accept(self, pipeline): + self.supported_by_prism_runner = True + pipeline.visit(self) + return self.supported_by_prism_runner + + def visit_transform(self, applied_ptransform): + transform = 
applied_ptransform.transform + # Python SDK assumes the direct runner TestStream implementation is + # being used. + if isinstance(transform, TestStream): + self.supported_by_prism_runner = False + if isinstance(transform, beam.ParDo): + dofn = transform.dofn + # It's uncertain if the Prism Runner supports execution of CombineFns + # with deferred side inputs. + if isinstance(dofn, CombineValuesDoFn): + args, kwargs = transform.raw_side_inputs + args_to_check = itertools.chain(args, kwargs.values()) + if any(isinstance(arg, ArgumentPlaceholder) + for arg in args_to_check): + self.supported_by_prism_runner = False + if userstate.is_stateful_dofn(dofn): + # https://github.com/apache/beam/issues/32786 - + # Remove once Real time clock is used. + _, timer_specs = userstate.get_dofn_specs(dofn) + for timer in timer_specs: + if timer.time_domain == TimeDomain.REAL_TIME: + self.supported_by_prism_runner = False + + tryingPrism = False # Check whether all transforms used in the pipeline are supported by the # FnApiRunner, and the pipeline was not meant to be run as streaming. if _FnApiRunnerSupportVisitor().accept(pipeline): @@ -122,9 +154,33 @@ def visit_transform(self, applied_ptransform): beam_provision_api_pb2.ProvisionInfo( pipeline_options=encoded_options)) runner = fn_runner.FnApiRunner(provision_info=provision_info) + elif _PrismRunnerSupportVisitor().accept(pipeline): + _LOGGER.info('Running pipeline with PrismRunner.') + from apache_beam.runners.portability import prism_runner + runner = prism_runner.PrismRunner() + tryingPrism = True else: runner = BundleBasedDirectRunner() + if tryingPrism: + try: + pr = runner.run_pipeline(pipeline, options) + # This is non-blocking, so if the state is *already* finished, something + # probably failed on job submission. 
+ if pr.state.is_terminal() and pr.state != PipelineState.DONE: + _LOGGER.info( + 'Pipeline failed on PrismRunner, falling back toDirectRunner.') + runner = BundleBasedDirectRunner() + else: + return pr + except Exception as e: + # If prism fails in Preparing the portable job, then the PortableRunner + # code raises an exception. Catch it, log it, and use the Direct runner + # instead. + _LOGGER.info('Exception with PrismRunner:\n %s\n' % (e)) + _LOGGER.info('Falling back to DirectRunner') + runner = BundleBasedDirectRunner() + return runner.run_pipeline(pipeline, options) diff --git a/sdks/python/apache_beam/runners/interactive/caching/reify.py b/sdks/python/apache_beam/runners/interactive/caching/reify.py index ce82785b2585..c82033dc1b9b 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/reify.py +++ b/sdks/python/apache_beam/runners/interactive/caching/reify.py @@ -28,7 +28,6 @@ import apache_beam as beam from apache_beam.runners.interactive import cache_manager as cache from apache_beam.testing import test_stream -from apache_beam.transforms.window import WindowedValue READ_CACHE = 'ReadCache_' WRITE_CACHE = 'WriteCache_' @@ -40,13 +39,8 @@ class Reify(beam.DoFn): Internally used to capture window info with each element into cache for replayability. 
""" - def process( - self, - e, - w=beam.DoFn.WindowParam, - p=beam.DoFn.PaneInfoParam, - t=beam.DoFn.TimestampParam): - yield test_stream.WindowedValueHolder(WindowedValue(e, t, [w], p)) + def process(self, e, wv=beam.DoFn.WindowedValueParam): + yield test_stream.WindowedValueHolder(wv) class Unreify(beam.DoFn): diff --git a/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py b/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py index 2b39279f43e9..4d260d4a6a56 100644 --- a/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py +++ b/sdks/python/apache_beam/runners/interactive/dataproc/dataproc_cluster_manager.py @@ -169,6 +169,7 @@ def create_cluster(self, cluster: dict) -> None: def create_flink_cluster(self) -> None: """Calls _create_cluster with a configuration that enables FlinkRunner.""" init_action_path = self.stage_init_action() + # https://cloud.google.com/php/docs/reference/cloud-dataproc/latest/V1.Cluster cluster = { 'project_id': self.cluster_metadata.project_id, 'cluster_name': self.cluster_metadata.cluster_name, @@ -194,7 +195,8 @@ def create_flink_cluster(self) -> None: }, 'service_account_scopes': [ 'https://www.googleapis.com/auth/cloud-platform' - ] + ], + 'internal_ip_only': False }, 'master_config': { # There must be 1 and only 1 instance of master. 
diff --git a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py index 693abb2aeeee..d767a15a345d 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py +++ b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py @@ -26,6 +26,7 @@ import datetime import html import logging +import warnings from datetime import timedelta from typing import Optional @@ -350,7 +351,12 @@ def display(self, updating_pv=None): ] # String-ify the dictionaries for display because elements of type dict # cannot be ordered. - data = data.applymap(lambda x: str(x) if isinstance(x, dict) else x) + with warnings.catch_warnings(): + # TODO(yathu) switch to use DataFrame.map when dropped pandas<2.1 support + warnings.filterwarnings( + "ignore", message="DataFrame.applymap has been deprecated") + data = data.applymap(lambda x: str(x) if isinstance(x, dict) else x) + if updating_pv: # Only updates when data is not empty. Otherwise, consider it a bad # iteration and noop since there is nothing to be updated. diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py index 9554abf3a47a..e3dc8b8968ad 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py @@ -273,9 +273,9 @@ class Recordings(): from all defined unbounded sources for that PCollection's pipeline. The following methods allow for introspection into that background recording job. """ - def describe(self, pipeline=None): - # type: (Optional[beam.Pipeline]) -> dict[str, Any] # noqa: F821 - + def describe( + self, + pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821 """Returns a description of all the recordings for the given pipeline. 
If no pipeline is given then this returns a dictionary of descriptions for @@ -292,9 +292,7 @@ def describe(self, pipeline=None): return description[pipeline] return description - def clear(self, pipeline): - # type: (beam.Pipeline) -> bool - + def clear(self, pipeline: beam.Pipeline) -> bool: """Clears all recordings of the given pipeline. Returns True if cleared.""" description = self.describe(pipeline) @@ -308,18 +306,14 @@ def clear(self, pipeline): ie.current_env().cleanup(pipeline) return True - def stop(self, pipeline): - # type: (beam.Pipeline) -> None - + def stop(self, pipeline: beam.Pipeline) -> None: """Stops the background source recording of the given pipeline.""" recording_manager = ie.current_env().get_recording_manager( pipeline, create_if_absent=True) recording_manager.cancel() - def record(self, pipeline): - # type: (beam.Pipeline) -> bool - + def record(self, pipeline: beam.Pipeline) -> bool: """Starts a background source recording job for the given pipeline. Returns True if the recording job was started. """ @@ -408,10 +402,11 @@ class Clusters: To configure a pipeline to run on a local FlinkRunner, explicitly set the default cluster metadata to None: ib.clusters.set_default_cluster(None). """ - # Explicitly set the Flink version here to ensure compatibility with 2.1 + # Explicitly set the Flink version here to ensure compatibility with 2.2 # Dataproc images: - # https://cloud.google.com/dataproc/docs/concepts/versioning/dataproc-release-2.1 - DATAPROC_FLINK_VERSION = '1.15' + # https://cloud.google.com/dataproc/docs/concepts/versioning/dataproc-release-2.2 + # you can manually override this by importing Clusters + DATAPROC_FLINK_VERSION = '1.17' # The minimum worker number to create a Dataproc cluster. 
DATAPROC_MINIMUM_WORKER_NUM = 2 diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py index b7d56ce90acb..828f23a467c2 100644 --- a/sdks/python/apache_beam/runners/interactive/utils.py +++ b/sdks/python/apache_beam/runners/interactive/utils.py @@ -24,12 +24,17 @@ import json import logging from typing import Any +from typing import Callable from typing import Dict +from typing import Iterator +from typing import List from typing import Tuple +from typing import Union import pandas as pd import apache_beam as beam +from apache_beam.coders import Coder from apache_beam.dataframe.convert import to_pcollection from apache_beam.dataframe.frame_base import DeferredBase from apache_beam.options.pipeline_options import PipelineOptions @@ -55,14 +60,13 @@ def to_element_list( - reader, # type: Generator[Union[beam_runner_api_pb2.TestStreamPayload.Event, WindowedValueHolder]] # noqa: F821 - coder, # type: Coder # noqa: F821 - include_window_info, # type: bool - n=None, # type: int - include_time_events=False, # type: bool -): - # type: (...) -> List[WindowedValue] # noqa: F821 - + reader: Iterator[Union[beam_runner_api_pb2.TestStreamPayload.Event, + WindowedValueHolder]], + coder: Coder, + include_window_info: bool, + n: int = None, + include_time_events: bool = False, +) -> List[WindowedValue]: """Returns an iterator that properly decodes the elements from the reader. """ @@ -102,9 +106,10 @@ def elements(): count += 1 -def elements_to_df(elements, include_window_info=False, element_type=None): - # type: (List[WindowedValue], bool, Any) -> DataFrame # noqa: F821 - +def elements_to_df( + elements: List[WindowedValue], + include_window_info: bool = False, + element_type: Any = None) -> 'DataFrame': # noqa: F821 """Parses the given elements into a Dataframe. 
If the elements are a list of WindowedValues, then it will break out the @@ -143,9 +148,7 @@ def elements_to_df(elements, include_window_info=False, element_type=None): return final_df -def register_ipython_log_handler(): - # type: () -> None - +def register_ipython_log_handler() -> None: """Adds the IPython handler to a dummy parent logger (named 'apache_beam.runners.interactive') of all interactive modules' loggers so that if is_in_notebook, logging displays the logs as HTML in frontends. @@ -200,9 +203,7 @@ def emit(self, record): pass # NOOP when dependencies are not available. -def obfuscate(*inputs): - # type: (*Any) -> str - +def obfuscate(*inputs: Any) -> str: """Obfuscates any inputs into a hexadecimal string.""" str_inputs = [str(input) for input in inputs] merged_inputs = '_'.join(str_inputs) @@ -223,8 +224,7 @@ class ProgressIndicator(object): spinner_removal_template = """ $("#{id}").remove();""" - def __init__(self, enter_text, exit_text): - # type: (str, str) -> None + def __init__(self, enter_text: str, exit_text: str) -> None: self._id = 'progress_indicator_{}'.format(obfuscate(id(self))) self._enter_text = enter_text @@ -267,9 +267,7 @@ def __exit__(self, exc_type, exc_value, traceback): 'or notebook environment: %s' % e) -def progress_indicated(func): - # type: (Callable[..., Any]) -> Callable[..., Any] # noqa: F821 - +def progress_indicated(func: Callable[..., Any]) -> Callable[..., Any]: """A decorator using a unique progress indicator as a context manager to execute the given function within.""" @functools.wraps(func) @@ -280,9 +278,7 @@ def run_within_progress_indicator(*args, **kwargs): return run_within_progress_indicator -def as_json(func): - # type: (Callable[..., Any]) -> Callable[..., str] # noqa: F821 - +def as_json(func: Callable[..., Any]) -> Callable[..., str]: """A decorator convert python objects returned by callables to json string. 
@@ -440,9 +436,7 @@ def create_var_in_main(name: str, return name, value -def assert_bucket_exists(bucket_name): - # type: (str) -> None - +def assert_bucket_exists(bucket_name: str) -> None: """Asserts whether the specified GCS bucket with the name bucket_name exists. diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 0a03c96bc19b..13ab665c1eb1 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -306,7 +306,7 @@ def get_or_create_environment_with_resource_hints( # "Message"; expected "Environment" [arg-type] # Here, Environment is a subclass of Message but mypy still # throws an error. - cloned_env.CopyFrom(template_env) # type: ignore[arg-type] + cloned_env.CopyFrom(template_env) cloned_env.resource_hints.clear() cloned_env.resource_hints.update(resource_hints) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index 4a737feaf288..1309e7c74abc 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -776,7 +776,8 @@ def process( state.clear() yield buffer else: - timer.set(ts + 1) + # Set the timer to fire within its window. + timer.set(ts + (1 - timestamp.Duration(micros=1000))) @userstate.on_timer(timer_spec) def process_timer(self, state=beam.DoFn.StateParam(state_spec)): @@ -790,8 +791,10 @@ def is_buffered_correctly(actual): # Acutal should be a grouping of the inputs into batches of size # at most buffer_size, but the actual batching is nondeterministic # based on ordering and trigger firing timing.
- self.assertEqual(sorted(sum((list(b) for b in actual), [])), elements) - self.assertEqual(max(len(list(buffer)) for buffer in actual), buffer_size) + self.assertEqual( + sorted(sum((list(b) for b in actual), [])), elements, actual) + self.assertEqual( + max(len(list(buffer)) for buffer in actual), buffer_size, actual) if windowed: # Elements were assigned to windows based on their parity. # Assert that each grouping consists of elements belonging to the diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py index 07af1c958cfd..c1c7f649f77a 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py @@ -256,7 +256,7 @@ def has_as_main_input(self, pcoll): transform.spec.payload, beam_runner_api_pb2.ParDoPayload) local_side_inputs = payload.side_inputs else: - local_side_inputs = {} # type: ignore[assignment] + local_side_inputs = {} for local_id, pipeline_id in transform.inputs.items(): if pcoll == pipeline_id and local_id not in local_side_inputs: return True diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py index bcfa965c0469..d798e96d3aa3 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py @@ -48,7 +48,9 @@ from typing import overload import grpc +from sortedcontainers import SortedSet +from apache_beam import coders from apache_beam.io import filesystems from apache_beam.io.filesystems import CompressionTypes from apache_beam.portability import common_urns @@ -959,7 +961,8 @@ class StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer, 'multimap_keys_values_side_input', 'iterable_side_input', 'bag_user_state', - 
'multimap_user_state' + 'multimap_user_state', + 'ordered_list_user_state' ]) class CopyOnWriteState(object): @@ -1021,6 +1024,8 @@ def __init__(self): self._checkpoint = None # type: Optional[StateServicer.StateType] self._use_continuation_tokens = False self._continuations = {} # type: Dict[bytes, Tuple[bytes, ...]] + self._ordered_list_keys = collections.defaultdict( + SortedSet) # type: DefaultDict[bytes, SortedSet] def checkpoint(self): # type: () -> None @@ -1050,6 +1055,14 @@ def process_instruction_id(self, unused_instruction_id): # type: (Any) -> Iterator yield + def _get_one_interval_key(self, state_key, start): + # type: (beam_fn_api_pb2.StateKey, int) -> bytes + state_key_copy = beam_fn_api_pb2.StateKey() + state_key_copy.CopyFrom(state_key) + state_key_copy.ordered_list_user_state.range.start = start + state_key_copy.ordered_list_user_state.range.end = start + 1 + return self._to_key(state_key_copy) + def get_raw(self, state_key, # type: beam_fn_api_pb2.StateKey continuation_token=None # type: Optional[bytes] @@ -1058,10 +1071,33 @@ def get_raw(self, if state_key.WhichOneof('type') not in self._SUPPORTED_STATE_TYPES: raise NotImplementedError( - 'Unknown state type: ' + state_key.WhichOneof('type')) + 'Unknown state type: ' + state_key.WhichOneof('type')) # type: ignore[operator] with self._lock: - full_state = self._state[self._to_key(state_key)] + if not continuation_token: + # Compute full_state only when no continuation token is provided. + # If there is continuation token, full_state is already in + # continuation cache. No need to recompute. 
+ full_state = [] # type: List[bytes] + if state_key.WhichOneof('type') == 'ordered_list_user_state': + maybe_start = state_key.ordered_list_user_state.range.start + maybe_end = state_key.ordered_list_user_state.range.end + persistent_state_key = beam_fn_api_pb2.StateKey() + persistent_state_key.CopyFrom(state_key) + persistent_state_key.ordered_list_user_state.ClearField("range") + + available_keys = self._ordered_list_keys[self._to_key( + persistent_state_key)] + + for i in available_keys.irange(maybe_start, + maybe_end, + inclusive=(True, False)): + entries = self._state[self._get_one_interval_key( + persistent_state_key, i)] + full_state.extend(entries) + else: + full_state.extend(self._state[self._to_key(state_key)]) + if self._use_continuation_tokens: # The token is "nonce:index". if not continuation_token: @@ -1087,14 +1123,40 @@ def append_raw( ): # type: (...) -> _Future with self._lock: - self._state[self._to_key(state_key)].append(data) + if state_key.WhichOneof('type') == 'ordered_list_user_state': + coder = coders.TupleCoder([ + coders.VarIntCoder(), + coders.coders.LengthPrefixCoder(coders.BytesCoder()) + ]).get_impl() + + for key, value in coder.decode_all(data): + self._state[self._get_one_interval_key(state_key, key)].append( + coder.encode((key, value))) + self._ordered_list_keys[self._to_key(state_key)].add(key) + else: + self._state[self._to_key(state_key)].append(data) return _Future.done() def clear(self, state_key): # type: (beam_fn_api_pb2.StateKey) -> _Future with self._lock: try: - del self._state[self._to_key(state_key)] + if state_key.WhichOneof('type') == 'ordered_list_user_state': + start = state_key.ordered_list_user_state.range.start + end = state_key.ordered_list_user_state.range.end + persistent_state_key = beam_fn_api_pb2.StateKey() + persistent_state_key.CopyFrom(state_key) + persistent_state_key.ordered_list_user_state.ClearField("range") + available_keys = self._ordered_list_keys[self._to_key( + persistent_state_key)] + + for i 
in list(available_keys.irange(start, + end, + inclusive=(True, False))): + del self._state[self._get_one_interval_key(persistent_state_key, i)] + available_keys.remove(i) + else: + del self._state[self._to_key(state_key)] except KeyError: # This may happen with the caching layer across bundles. Caching may # skip this storage layer for a blocking_get(key) request. Without diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index e44d8ab0ae93..eee75f66a277 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -127,6 +127,7 @@ def __init__(self, options): self._artifacts_dir = options.artifacts_dir self._java_launcher = options.job_server_java_launcher self._jvm_properties = options.job_server_jvm_properties + self._jar_cache_dir = options.jar_cache_dir def java_arguments( self, job_port, artifact_port, expansion_port, artifacts_dir): @@ -141,11 +142,11 @@ def path_to_beam_jar(gradle_target, artifact_id=None): gradle_target, artifact_id=artifact_id) @staticmethod - def local_jar(url): - return subprocess_server.JavaJarServer.local_jar(url) + def local_jar(url, jar_cache_dir=None): + return subprocess_server.JavaJarServer.local_jar(url, jar_cache_dir) def subprocess_cmd_and_endpoint(self): - jar_path = self.local_jar(self.path_to_jar()) + jar_path = self.local_jar(self.path_to_jar(), self._jar_cache_dir) artifacts_dir = ( self._artifacts_dir if self._artifacts_dir else self.local_temp_dir( prefix='artifacts')) diff --git a/sdks/python/apache_beam/runners/portability/job_server_test.py b/sdks/python/apache_beam/runners/portability/job_server_test.py index 1e2ede281c9d..13b3629b24bf 100644 --- a/sdks/python/apache_beam/runners/portability/job_server_test.py +++ b/sdks/python/apache_beam/runners/portability/job_server_test.py @@ -41,7 +41,8 @@ def path_to_jar(self): return '/path/to/jar' @staticmethod - def 
local_jar(url): + def local_jar(url, jar_cache_dir=None): + logging.debug("url({%s}), jar_cache_dir({%s})", url, jar_cache_dir) return url diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 869f013d0d26..a2b4e5e7f939 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -35,7 +35,7 @@ import grpc from google.protobuf import json_format from google.protobuf import struct_pb2 -from google.protobuf import text_format # type: ignore # not in typeshed +from google.protobuf import text_format from apache_beam import pipeline from apache_beam.metrics import monitoring_infos diff --git a/sdks/python/apache_beam/runners/portability/prism_runner.py b/sdks/python/apache_beam/runners/portability/prism_runner.py index eeccaf5748ce..77dc8a214e8e 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner.py @@ -27,6 +27,7 @@ import platform import shutil import stat +import subprocess import typing import urllib import zipfile @@ -167,38 +168,93 @@ def construct_download_url(self, root_tag: str, sys: str, mach: str) -> str: def path_to_binary(self) -> str: if self._path is not None: - if not os.path.exists(self._path): - url = urllib.parse.urlparse(self._path) - if not url.scheme: - raise ValueError( - 'Unable to parse binary URL "%s". If using a full URL, make ' - 'sure the scheme is specified. If using a local file xpath, ' - 'make sure the file exists; you may have to first build prism ' - 'using `go build `.' % (self._path)) - - # We have a URL, see if we need to construct a valid file name. - if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): - # If this URL starts with the download prefix, let it through. - return self._path - # The only other valid option is a github release page. 
- if not self._path.startswith(GITHUB_TAG_PREFIX): - raise ValueError( - 'Provided --prism_location URL is not an Apache Beam Github ' - 'Release page URL or download URL: %s' % (self._path)) - # Get the root tag for this URL - root_tag = os.path.basename(os.path.normpath(self._path)) - return self.construct_download_url( - root_tag, platform.system(), platform.machine()) - return self._path - else: - if '.dev' in self._version: + # The path is overridden, check various cases. + if os.path.exists(self._path): + # The path is local and exists, use directly. + return self._path + + # Check if the path is a URL. + url = urllib.parse.urlparse(self._path) + if not url.scheme: + raise ValueError( + 'Unable to parse binary URL "%s". If using a full URL, make ' + 'sure the scheme is specified. If using a local file xpath, ' + 'make sure the file exists; you may have to first build prism ' + 'using `go build `.' % (self._path)) + + # We have a URL, see if we need to construct a valid file name. + if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): + # If this URL starts with the download prefix, let it through. + return self._path + # The only other valid option is a github release page. + if not self._path.startswith(GITHUB_TAG_PREFIX): raise ValueError( - 'Unable to derive URL for dev versions "%s". Please provide an ' - 'alternate version to derive the release URL with the ' - '--prism_beam_version_override flag.' % (self._version)) + 'Provided --prism_location URL is not an Apache Beam Github ' + 'Release page URL or download URL: %s' % (self._path)) + # Get the root tag for this URL + root_tag = os.path.basename(os.path.normpath(self._path)) + return self.construct_download_url( + root_tag, platform.system(), platform.machine()) + + if '.dev' not in self._version: + # Not a development version, so construct the production download URL return self.construct_download_url( self._version, platform.system(), platform.machine()) + # This is a development version!
Assume Go is installed. + # Set the install directory to the cache location. + envdict = {**os.environ, "GOBIN": self.BIN_CACHE} + PRISMPKG = "github.com/apache/beam/sdks/v2/go/cmd/prism" + + process = subprocess.run(["go", "install", PRISMPKG], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + if process.returncode == 0: + # Successfully installed + return '%s/prism' % (self.BIN_CACHE) + + # We failed to build for some reason. + output = process.stdout.decode("utf-8") + if ("not in a module" not in output) and ( + "no required module provides" not in output): + # This branch handles two classes of failures: + # 1. Go isn't installed, so it needs to be installed by the Beam SDK + # developer. + # 2. Go is installed, and they are building in a local version of Prism, + # but there was a compile error that the developer should address. + # Either way, the @latest fallback either would fail, or hide the error, + # so fail now. + _LOGGER.info(output) + raise ValueError( + 'Unable to install a local of Prism: "%s";\n' + 'Likely Go is not installed, or a local change to Prism did not ' + 'compile.\nPlease install Go (see https://go.dev/doc/install) to ' + 'enable automatic local builds.\n' + 'Alternatively provide a binary with the --prism_location flag.' + '\nCaptured output:\n %s' % (self._version, output)) + + # Go is installed and claims we're not in a Go module that has access to + # the Prism package. + + # Fallback to using the @latest version of prism, which works everywhere. + process = subprocess.run(["go", "install", PRISMPKG + "@latest"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + + if process.returncode == 0: + return '%s/prism' % (self.BIN_CACHE) + + output = process.stdout.decode("utf-8") + raise ValueError( + 'We were unable to execute the subprocess "%s" to automatically ' + 'build prism.\nAlternatively provide an alternate binary with the ' + '--prism_location flag.' 
+ '\nCaptured output:\n %s' % (process.args, output)) + def subprocess_cmd_and_endpoint( self) -> typing.Tuple[typing.List[typing.Any], str]: bin_path = self.local_bin( diff --git a/sdks/python/apache_beam/runners/portability/prism_runner_test.py b/sdks/python/apache_beam/runners/portability/prism_runner_test.py index 04a2dbd4faed..bc72d551f966 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner_test.py @@ -213,6 +213,24 @@ def test_expand_kafka_write(self): def test_sql(self): raise unittest.SkipTest("Requires an expansion service to execute.") + # The following tests require additional implementation in Prism. + + def test_custom_merging_window(self): + raise unittest.SkipTest( + "Requires Prism to support Custom Window " + + "Coders, and Merging Custom Windows. " + + "https://github.com/apache/beam/issues/31921") + + def test_custom_window_type(self): + raise unittest.SkipTest( + "Requires Prism to support Custom Window Coders." + + " https://github.com/apache/beam/issues/31921") + + def test_pack_combiners(self): + raise unittest.SkipTest( + "Requires Prism to support coder:" + + " 'beam:coder:tuple:v1'. https://github.com/apache/beam/issues/32636") + # Inherits all other tests. 
diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index 98c0e3176f75..c7142bfddcaf 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -107,9 +107,9 @@ class Stager(object): """ _DEFAULT_CHUNK_SIZE = 2 << 20 - def stage_artifact(self, local_path_to_artifact, artifact_name, sha256): - # type: (str, str, str) -> None - + def stage_artifact( + self, local_path_to_artifact: str, artifact_name: str, + sha256: str) -> None: """ Stages the artifact to Stager._staging_location and adds artifact_name to the manifest of artifacts that have been staged.""" raise NotImplementedError @@ -159,14 +159,16 @@ def extract_staging_tuple_iter( raise RuntimeError("unknown artifact type: %s" % artifact.type_urn) @staticmethod - def create_job_resources(options, # type: PipelineOptions - temp_dir, # type: str - build_setup_args=None, # type: Optional[List[str]] - pypi_requirements=None, # type: Optional[List[str]] - populate_requirements_cache=None, # type: Optional[Callable[[str, str, bool], None]] - skip_prestaged_dependencies=False, # type: Optional[bool] - log_submission_env_dependencies=True, # type: Optional[bool] - ): + def create_job_resources( + options: PipelineOptions, + temp_dir: str, + build_setup_args: Optional[List[str]] = None, + pypi_requirements: Optional[List[str]] = None, + populate_requirements_cache: Optional[Callable[[str, str, bool], + None]] = None, + skip_prestaged_dependencies: Optional[bool] = False, + log_submission_env_dependencies: Optional[bool] = True, + ): """For internal use only; no backwards-compatibility guarantees. Creates (if needed) a list of job resources. @@ -198,7 +200,7 @@ def create_job_resources(options, # type: PipelineOptions while trying to create the resources (e.g., build a setup package). 
""" - resources = [] # type: List[beam_runner_api_pb2.ArtifactInformation] + resources: List[beam_runner_api_pb2.ArtifactInformation] = [] setup_options = options.view_as(SetupOptions) use_beam_default_container = options.view_as( @@ -381,10 +383,10 @@ def create_job_resources(options, # type: PipelineOptions return resources - def stage_job_resources(self, - resources, # type: List[Tuple[str, str, str]] - staging_location=None # type: Optional[str] - ): + def stage_job_resources( + self, + resources: List[Tuple[str, str, str]], + staging_location: Optional[str] = None): """For internal use only; no backwards-compatibility guarantees. Stages job resources to staging_location. @@ -416,13 +418,13 @@ def stage_job_resources(self, def create_and_stage_job_resources( self, - options, # type: PipelineOptions - build_setup_args=None, # type: Optional[List[str]] - temp_dir=None, # type: Optional[str] - pypi_requirements=None, # type: Optional[List[str]] - populate_requirements_cache=None, # type: Optional[Callable[[str, str, bool], None]] - staging_location=None # type: Optional[str] - ): + options: PipelineOptions, + build_setup_args: Optional[List[str]] = None, + temp_dir: Optional[str] = None, + pypi_requirements: Optional[List[str]] = None, + populate_requirements_cache: Optional[Callable[[str, str, bool], + None]] = None, + staging_location: Optional[str] = None): """For internal use only; no backwards-compatibility guarantees. Creates (if needed) and stages job resources to staging_location. @@ -523,9 +525,8 @@ def _is_remote_path(path): return path.find('://') != -1 @staticmethod - def _create_jar_packages(jar_packages, temp_dir): - # type: (...) -> List[beam_runner_api_pb2.ArtifactInformation] - + def _create_jar_packages( + jar_packages, temp_dir) -> List[beam_runner_api_pb2.ArtifactInformation]: """Creates a list of local jar packages for Java SDK Harness. 
:param jar_packages: Ordered list of local paths to jar packages to be @@ -538,9 +539,9 @@ def _create_jar_packages(jar_packages, temp_dir): RuntimeError: If files specified are not found or do not have expected name patterns. """ - resources = [] # type: List[beam_runner_api_pb2.ArtifactInformation] + resources: List[beam_runner_api_pb2.ArtifactInformation] = [] staging_temp_dir = tempfile.mkdtemp(dir=temp_dir) - local_packages = [] # type: List[str] + local_packages: List[str] = [] for package in jar_packages: if not os.path.basename(package).endswith('.jar'): raise RuntimeError( @@ -574,9 +575,9 @@ def _create_jar_packages(jar_packages, temp_dir): return resources @staticmethod - def _create_extra_packages(extra_packages, temp_dir): - # type: (...) -> List[beam_runner_api_pb2.ArtifactInformation] - + def _create_extra_packages( + extra_packages, + temp_dir) -> List[beam_runner_api_pb2.ArtifactInformation]: """Creates a list of local extra packages. Args: @@ -595,9 +596,9 @@ def _create_extra_packages(extra_packages, temp_dir): RuntimeError: If files specified are not found or do not have expected name patterns. 
""" - resources = [] # type: List[beam_runner_api_pb2.ArtifactInformation] + resources: List[beam_runner_api_pb2.ArtifactInformation] = [] staging_temp_dir = tempfile.mkdtemp(dir=temp_dir) - local_packages = [] # type: List[str] + local_packages: List[str] = [] for package in extra_packages: if not (os.path.basename(package).endswith('.tar') or os.path.basename(package).endswith('.tar.gz') or @@ -665,9 +666,7 @@ def _get_python_executable(): @staticmethod def _remove_dependency_from_requirements( - requirements_file, # type: str - dependency_to_remove, # type: str - temp_directory_path): + requirements_file: str, dependency_to_remove: str, temp_directory_path): """Function to remove dependencies from a given requirements file.""" # read all the dependency names with open(requirements_file, 'r') as f: @@ -776,11 +775,10 @@ def _populate_requirements_cache( processes.check_output(cmd_args, stderr=processes.STDOUT) @staticmethod - def _build_setup_package(setup_file, # type: str - temp_dir, # type: str - build_setup_args=None # type: Optional[List[str]] - ): - # type: (...) -> str + def _build_setup_package( + setup_file: str, + temp_dir: str, + build_setup_args: Optional[List[str]] = None) -> str: saved_current_directory = os.getcwd() try: os.chdir(os.path.dirname(setup_file)) @@ -819,9 +817,7 @@ def _build_setup_package(setup_file, # type: str os.chdir(saved_current_directory) @staticmethod - def _desired_sdk_filename_in_staging_location(sdk_location): - # type: (...) -> str - + def _desired_sdk_filename_in_staging_location(sdk_location) -> str: """Returns the name that SDK file should have in the staging location. Args: sdk_location: Full path to SDK file. @@ -836,9 +832,9 @@ def _desired_sdk_filename_in_staging_location(sdk_location): return names.STAGED_SDK_SOURCES_FILENAME @staticmethod - def _create_beam_sdk(sdk_remote_location, temp_dir): - # type: (...) 
-> List[beam_runner_api_pb2.ArtifactInformation] - + def _create_beam_sdk( + sdk_remote_location, + temp_dir) -> List[beam_runner_api_pb2.ArtifactInformation]: """Creates a Beam SDK file with the appropriate version. Args: diff --git a/sdks/python/apache_beam/runners/render.py b/sdks/python/apache_beam/runners/render.py index fccfa8aacd61..45e66e1ba06a 100644 --- a/sdks/python/apache_beam/runners/render.py +++ b/sdks/python/apache_beam/runners/render.py @@ -64,7 +64,7 @@ import urllib.parse from google.protobuf import json_format -from google.protobuf import text_format # type: ignore +from google.protobuf import text_format from apache_beam.options import pipeline_options from apache_beam.portability.api import beam_runner_api_pb2 diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index fdb13a03bb94..89c137fe4366 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -19,16 +19,21 @@ # pytype: skip-file +from __future__ import annotations + import base64 import bisect import collections import copy +import heapq +import itertools import json import logging import random import threading from dataclasses import dataclass from dataclasses import field +from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -40,6 +45,7 @@ from typing import Iterator from typing import List from typing import Mapping +from typing import MutableMapping from typing import Optional from typing import Set from typing import Tuple @@ -50,6 +56,8 @@ from google.protobuf import duration_pb2 from google.protobuf import timestamp_pb2 +from sortedcontainers import SortedDict +from sortedcontainers import SortedList import apache_beam as beam from apache_beam import coders @@ -104,7 +112,8 @@ FnApiUserRuntimeStateTypes = Union['ReadModifyWriteRuntimeState', 
'CombiningValueRuntimeState', 'SynchronousSetRuntimeState', - 'SynchronousBagRuntimeState'] + 'SynchronousBagRuntimeState', + 'SynchronousOrderedListRuntimeState'] DATA_INPUT_URN = 'beam:runner:source:v1' DATA_OUTPUT_URN = 'beam:runner:sink:v1' @@ -122,18 +131,16 @@ class RunnerIOOperation(operations.Operation): """Common baseclass for runner harness IO operations.""" - - def __init__(self, - name_context, # type: common.NameContext - step_name, # type: Any - consumers, # type: Mapping[Any, Iterable[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, # type: str - data_channel # type: data_plane.DataChannel - ): - # type: (...) -> None + def __init__( + self, + name_context: common.NameContext, + step_name: Any, + consumers: Mapping[Any, Iterable[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id: str, + data_channel: data_plane.DataChannel) -> None: super().__init__(name_context, None, counter_factory, state_sampler) self.windowed_coder = windowed_coder self.windowed_coder_impl = windowed_coder.get_impl() @@ -149,36 +156,32 @@ def __init__(self, class DataOutputOperation(RunnerIOOperation): """A sink-like operation that gathers outputs to be sent back to the runner. 
""" - def set_output_stream(self, output_stream): - # type: (data_plane.ClosableOutputStream) -> None + def set_output_stream( + self, output_stream: data_plane.ClosableOutputStream) -> None: self.output_stream = output_stream - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.windowed_coder_impl.encode_to_stream( windowed_value, self.output_stream, True) self.output_stream.maybe_flush() - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() self.output_stream.close() class DataInputOperation(RunnerIOOperation): """A source-like operation that gathers input from the runner.""" - - def __init__(self, - operation_name, # type: common.NameContext - step_name, - consumers, # type: Mapping[Any, List[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, - data_channel # type: data_plane.GrpcClientDataChannel - ): - # type: (...) 
-> None + def __init__( + self, + operation_name: common.NameContext, + step_name, + consumers: Mapping[Any, List[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id, + data_channel: data_plane.GrpcClientDataChannel) -> None: super().__init__( operation_name, step_name, @@ -209,18 +212,15 @@ def setup(self, data_sampler=None): producer_batch_converter=self.get_output_batch_converter()) ] - def start(self): - # type: () -> None + def start(self) -> None: super().start() with self.splitting_lock: self.started = True - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.output(windowed_value) - def process_encoded(self, encoded_windowed_values): - # type: (bytes) -> None + def process_encoded(self, encoded_windowed_values: bytes) -> None: input_stream = coder_impl.create_InputStream(encoded_windowed_values) while input_stream.size() > 0: with self.splitting_lock: @@ -236,8 +236,9 @@ def process_encoded(self, encoded_windowed_values): str(self.windowed_coder)) from exn self.output(decoded_value) - def monitoring_infos(self, transform_id, tag_to_pcollection_id): - # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo] + def monitoring_infos( + self, transform_id: str, tag_to_pcollection_id: Dict[str, str] + ) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]: all_monitoring_infos = super().monitoring_infos( transform_id, tag_to_pcollection_id) read_progress_info = monitoring_infos.int64_counter( @@ -251,8 +252,13 @@ def monitoring_infos(self, transform_id, tag_to_pcollection_id): # TODO(https://github.com/apache/beam/issues/19737): typing not compatible # with super type def try_split( # type: ignore[override] - self, fraction_of_remainder, total_buffer_size, allowed_split_points): - # type: (...) 
-> Optional[Tuple[int, Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual], int]] + self, fraction_of_remainder, total_buffer_size, allowed_split_points + ) -> Optional[ + Tuple[ + int, + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual], + int]]: with self.splitting_lock: if not self.started: return None @@ -306,9 +312,10 @@ def is_valid_split_point(index): # try splitting at the current element. if (keep_of_element_remainder < 1 and is_valid_split_point(index) and is_valid_split_point(index + 1)): - split = try_split( - keep_of_element_remainder - ) # type: Optional[Tuple[Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual]]] + split: Optional[Tuple[ + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual]]] = try_split( + keep_of_element_remainder) if split: element_primaries, element_residuals = split return index - 1, element_primaries, element_residuals, index + 1 @@ -335,15 +342,13 @@ def is_valid_split_point(index): else: return None - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() with self.splitting_lock: self.index += 1 self.started = False - def reset(self): - # type: () -> None + def reset(self) -> None: with self.splitting_lock: self.index = -1 self.stop = float('inf') @@ -351,12 +356,12 @@ def reset(self): class _StateBackedIterable(object): - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - coder_or_impl, # type: Union[coders.Coder, coder_impl.CoderImpl] - ): - # type: (...) 
-> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + coder_or_impl: Union[coders.Coder, coder_impl.CoderImpl], + ) -> None: self._state_handler = state_handler self._state_key = state_key if isinstance(coder_or_impl, coders.Coder): @@ -364,8 +369,7 @@ def __init__(self, else: self._coder_impl = coder_or_impl - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: return iter( self._state_handler.blocking_get(self._state_key, self._coder_impl)) @@ -383,15 +387,15 @@ class StateBackedSideInputMap(object): _BULK_READ_FULLY = "fully" _BULK_READ_PARTIALLY = "partially" - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - tag, # type: Optional[str] - side_input_data, # type: pvalue.SideInputData - coder, # type: WindowedValueCoder - use_bulk_read = False, # type: bool - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + tag: Optional[str], + side_input_data: pvalue.SideInputData, + coder: WindowedValueCoder, + use_bulk_read: bool = False, + ) -> None: self._state_handler = state_handler self._transform_id = transform_id self._tag = tag @@ -399,7 +403,7 @@ def __init__(self, self._element_coder = coder.wrapped_value_coder self._target_window_coder = coder.window_coder # TODO(robertwb): Limit the cache size. 
- self._cache = {} # type: Dict[BoundedWindow, Any] + self._cache: Dict[BoundedWindow, Any] = {} self._use_bulk_read = use_bulk_read def __getitem__(self, window): @@ -495,14 +499,12 @@ def __reduce__(self): self._cache[target_window] = self._side_input_data.view_fn(raw_view) return self._cache[target_window] - def is_globally_windowed(self): - # type: () -> bool + def is_globally_windowed(self) -> bool: return ( self._side_input_data.window_mapping_fn == sideinputs._global_window_mapping_fn) - def reset(self): - # type: () -> None + def reset(self) -> None: # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens. self._cache = {} @@ -511,26 +513,28 @@ class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState): def __init__(self, underlying_bag_state): self._underlying_bag_state = underlying_bag_state - def read(self): # type: () -> Any + def read(self) -> Any: values = list(self._underlying_bag_state.read()) if not values: return None return values[0] - def write(self, value): # type: (Any) -> None + def write(self, value: Any) -> None: self.clear() self._underlying_bag_state.add(value) - def clear(self): # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() - def commit(self): # type: () -> None + def commit(self) -> None: self._underlying_bag_state.commit() class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState): - def __init__(self, underlying_bag_state, combinefn): - # type: (userstate.AccumulatingRuntimeState, core.CombineFn) -> None + def __init__( + self, + underlying_bag_state: userstate.AccumulatingRuntimeState, + combinefn: core.CombineFn) -> None: self._combinefn = combinefn self._combinefn.setup() self._underlying_bag_state = underlying_bag_state @@ -544,12 +548,10 @@ def _read_accumulator(self, rewrite=True): self._underlying_bag_state.add(merged_accumulator) return merged_accumulator - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return 
self._combinefn.extract_output(self._read_accumulator()) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: # Prefer blind writes, but don't let them grow unboundedly. # This should be tuned to be much lower, but for now exercise # both paths well. @@ -561,8 +563,7 @@ def add(self, value): self._underlying_bag_state.add( self._combinefn.add_input(accumulator, value)) - def clear(self): - # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() def commit(self): @@ -579,13 +580,11 @@ class _ConcatIterable(object): Unlike itertools.chain, this allows reiteration. """ - def __init__(self, first, second): - # type: (Iterable[Any], Iterable[Any]) -> None + def __init__(self, first: Iterable[Any], second: Iterable[Any]) -> None: self.first = first self.second = second - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: for elem in self.first: yield elem for elem in self.second: @@ -596,38 +595,32 @@ def __iter__(self): class SynchronousBagRuntimeState(userstate.BagRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) 
-> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = [] # type: List[Any] + self._added_elements: List[Any] = [] - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return _ConcatIterable([] if self._cleared else cast( 'Iterable[Any]', _StateBackedIterable( self._state_handler, self._state_key, self._value_coder)), self._added_elements) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: self._added_elements.append(value) - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = [] - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -640,18 +633,16 @@ def commit(self): class SynchronousSetRuntimeState(userstate.SetRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) 
-> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = set() # type: Set[Any] + self._added_elements: Set[Any] = set() def _compact_data(self, rewrite=True): accumulator = set( @@ -671,12 +662,10 @@ def _compact_data(self, rewrite=True): return accumulator - def read(self): - # type: () -> Set[Any] + def read(self) -> Set[Any]: return self._compact_data(rewrite=False) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: if self._cleared: # This is a good time explicitly clear. self._state_handler.clear(self._state_key) @@ -686,13 +675,11 @@ def add(self, value): if random.random() > 0.5: self._compact_data() - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = set() - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -704,17 +691,191 @@ def commit(self): to_await.get() +class RangeSet: + """For Internal Use only. A simple range set for ranges of [x,y).""" + def __init__(self) -> None: + # The start points and end points are stored separately in order. 
+ self._sorted_starts = SortedList() + self._sorted_ends = SortedList() + + def add(self, start: int, end: int) -> None: + if start >= end: + return + + # ranges[:min_idx] and ranges[max_idx:] is unaffected by this insertion + # the first range whose end point >= the start of the new range + min_idx = self._sorted_ends.bisect_left(start) + # the first range whose start point > the end point of the new range + max_idx = self._sorted_starts.bisect_right(end) + + if min_idx >= len(self._sorted_starts) or max_idx <= 0: + # the new range is beyond any current ranges + new_start = start + new_end = end + else: + # the new range overlaps with ranges[min_idx:max_idx] + new_start = min(start, self._sorted_starts[min_idx]) + new_end = max(end, self._sorted_ends[max_idx - 1]) + + del self._sorted_starts[min_idx:max_idx] + del self._sorted_ends[min_idx:max_idx] + + self._sorted_starts.add(new_start) + self._sorted_ends.add(new_end) + + def __contains__(self, key: int) -> bool: + idx = self._sorted_starts.bisect_left(key) + return (idx < len(self._sorted_starts) and self._sorted_starts[idx] == key + ) or (idx > 0 and self._sorted_ends[idx - 1] > key) + + def __len__(self) -> int: + assert len(self._sorted_starts) == len(self._sorted_ends) + return len(self._sorted_starts) + + def __iter__(self) -> Iterator[Tuple[int, int]]: + return zip(self._sorted_starts, self._sorted_ends) + + def __str__(self) -> str: + return str(list(zip(self._sorted_starts, self._sorted_ends))) + + +class SynchronousOrderedListRuntimeState(userstate.OrderedListRuntimeState): + RANGE_MIN = -(1 << 63) + RANGE_MAX = (1 << 63) - 1 + TIMESTAMP_RANGE_MIN = timestamp.Timestamp(micros=RANGE_MIN) + TIMESTAMP_RANGE_MAX = timestamp.Timestamp(micros=RANGE_MAX) + + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: + self._state_handler = state_handler + self._state_key = state_key + self._elem_coder = 
beam.coders.TupleCoder( + [coders.VarIntCoder(), coders.coders.LengthPrefixCoder(value_coder)]) + self._cleared = False + self._pending_adds = SortedDict() + self._pending_removes = RangeSet() + + def add(self, elem: Tuple[timestamp.Timestamp, Any]) -> None: + assert len(elem) == 2 + key_ts, value = elem + key = key_ts.micros + + if key >= self.RANGE_MAX or key < self.RANGE_MIN: + raise ValueError("key value %d is out of range" % key) + self._pending_adds.setdefault(key, []).append(value) + + def read(self) -> Iterable[Tuple[timestamp.Timestamp, Any]]: + return self.read_range(self.TIMESTAMP_RANGE_MIN, self.TIMESTAMP_RANGE_MAX) + + def read_range( + self, + min_timestamp: timestamp.Timestamp, + limit_timestamp: timestamp.Timestamp + ) -> Iterable[Tuple[timestamp.Timestamp, Any]]: + # convert timestamp to int, as sort keys are stored as int internally. + min_key = min_timestamp.micros + limit_key = limit_timestamp.micros + + keys_to_add = self._pending_adds.irange( + min_key, limit_key, inclusive=(True, False)) + + # use list interpretation here to construct the actual list + # of iterators of the selected range. 
+ local_items = chain.from_iterable([ + itertools.islice( + zip(itertools.cycle([ + k, + ]), self._pending_adds[k]), + len(self._pending_adds[k])) for k in keys_to_add + ]) + + if not self._cleared: + range_query_state_key = beam_fn_api_pb2.StateKey() + range_query_state_key.CopyFrom(self._state_key) + range_query_state_key.ordered_list_user_state.range.start = min_key + range_query_state_key.ordered_list_user_state.range.end = limit_key + + # make a deep copy here because there could be other operations occur in + # the middle of an iteration and change pending_removes + pending_removes_snapshot = copy.deepcopy(self._pending_removes) + persistent_items = filter( + lambda kv: kv[0] not in pending_removes_snapshot, + _StateBackedIterable( + self._state_handler, range_query_state_key, self._elem_coder)) + + return map( + lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), + heapq.merge(persistent_items, local_items)) + + return map(lambda x: (timestamp.Timestamp(micros=x[0]), x[1]), local_items) + + def clear(self) -> None: + self._cleared = True + self._pending_adds = SortedDict() + self._pending_removes = RangeSet() + self._pending_removes.add(self.RANGE_MIN, self.RANGE_MAX) + + def clear_range( + self, + min_timestamp: timestamp.Timestamp, + limit_timestamp: timestamp.Timestamp) -> None: + min_key = min_timestamp.micros + limit_key = limit_timestamp.micros + + # materialize the keys to remove before the actual removal + keys_to_remove = list( + self._pending_adds.irange(min_key, limit_key, inclusive=(True, False))) + for k in keys_to_remove: + del self._pending_adds[k] + + if not self._cleared: + self._pending_removes.add(min_key, limit_key) + + def commit(self) -> None: + futures = [] + if self._pending_removes: + for start, end in self._pending_removes: + range_query_state_key = beam_fn_api_pb2.StateKey() + range_query_state_key.CopyFrom(self._state_key) + range_query_state_key.ordered_list_user_state.range.start = start + 
range_query_state_key.ordered_list_user_state.range.end = end + futures.append(self._state_handler.clear(range_query_state_key)) + + self._pending_removes = RangeSet() + + if self._pending_adds: + items_to_add = [] + for k in self._pending_adds: + items_to_add.extend(zip(itertools.cycle([ + k, + ]), self._pending_adds[k])) + futures.append( + self._state_handler.extend( + self._state_key, self._elem_coder.get_impl(), items_to_add)) + self._pending_adds = SortedDict() + + if len(futures): + # To commit, we need to wait on every state request futures to complete. + for to_await in futures: + to_await.get() + + self._cleared = False + + class OutputTimer(userstate.BaseTimer): - def __init__(self, - key, - window, # type: BoundedWindow - timestamp, # type: timestamp.Timestamp - paneinfo, # type: windowed_value.PaneInfo - time_domain, # type: str - timer_family_id, # type: str - timer_coder_impl, # type: coder_impl.TimerCoderImpl - output_stream # type: data_plane.ClosableOutputStream - ): + def __init__( + self, + key, + window: BoundedWindow, + timestamp: timestamp.Timestamp, + paneinfo: windowed_value.PaneInfo, + time_domain: str, + timer_family_id: str, + timer_coder_impl: coder_impl.TimerCoderImpl, + output_stream: data_plane.ClosableOutputStream): self._key = key self._window = window self._input_timestamp = timestamp @@ -760,15 +921,13 @@ def __init__(self, timer_coder_impl, output_stream=None): class FnApiUserStateContext(userstate.UserStateContext): """Interface for state and timers from SDK to Fn API servicer of state..""" - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - key_coder, # type: coders.Coder - window_coder, # type: coders.Coder - ): - # type: (...) -> None - + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + key_coder: coders.Coder, + window_coder: coders.Coder, + ) -> None: """Initialize a ``FnApiUserStateContext``. 
Args: @@ -782,11 +941,10 @@ def __init__(self, self._key_coder = key_coder self._window_coder = window_coder # A mapping of {timer_family_id: TimerInfo} - self._timers_info = {} # type: Dict[str, TimerInfo] - self._all_states = {} # type: Dict[tuple, FnApiUserRuntimeStateTypes] + self._timers_info: Dict[str, TimerInfo] = {} + self._all_states: Dict[tuple, FnApiUserRuntimeStateTypes] = {} - def add_timer_info(self, timer_family_id, timer_info): - # type: (str, TimerInfo) -> None + def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo) -> None: self._timers_info[timer_family_id] = timer_info def get_timer( @@ -805,19 +963,15 @@ def get_timer( timer_coder_impl, output_stream) - def get_state(self, *args): - # type: (*Any) -> FnApiUserRuntimeStateTypes + def get_state(self, *args: Any) -> FnApiUserRuntimeStateTypes: state_handle = self._all_states.get(args) if state_handle is None: state_handle = self._all_states[args] = self._create_state(*args) return state_handle - def _create_state(self, - state_spec, # type: userstate.StateSpec - key, - window # type: BoundedWindow - ): - # type: (...) -> FnApiUserRuntimeStateTypes + def _create_state( + self, state_spec: userstate.StateSpec, key, + window: BoundedWindow) -> FnApiUserRuntimeStateTypes: if isinstance(state_spec, (userstate.BagStateSpec, userstate.CombiningValueStateSpec, @@ -850,16 +1004,25 @@ def _create_state(self, # State keys are expected in nested encoding format key=self._key_coder.encode_nested(key))), value_coder=state_spec.coder) + elif isinstance(state_spec, userstate.OrderedListStateSpec): + return SynchronousOrderedListRuntimeState( + self._state_handler, + state_key=beam_fn_api_pb2.StateKey( + ordered_list_user_state=beam_fn_api_pb2.StateKey. 
+ OrderedListUserState( + transform_id=self._transform_id, + user_state_id=state_spec.name, + window=self._window_coder.encode(window), + key=self._key_coder.encode_nested(key))), + value_coder=state_spec.coder) else: raise NotImplementedError(state_spec) - def commit(self): - # type: () -> None + def commit(self) -> None: for state in self._all_states.values(): state.commit() - def reset(self): - # type: () -> None + def reset(self) -> None: for state in self._all_states.values(): state.finalize() self._all_states = {} @@ -878,14 +1041,12 @@ def wrapper(*args): return wrapper -def only_element(iterable): - # type: (Iterable[T]) -> T +def only_element(iterable: Iterable[T]) -> T: element, = iterable return element -def _environments_compatible(submission, runtime): - # type: (str, str) -> bool +def _environments_compatible(submission: str, runtime: str) -> bool: if submission == runtime: return True if 'rc' in submission and runtime in submission: @@ -895,8 +1056,8 @@ def _environments_compatible(submission, runtime): return False -def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): - # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None +def _verify_descriptor_created_in_a_compatible_env( + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor) -> None: runtime_sdk = environments.sdk_base_version_capability() for t in process_bundle_descriptor.transforms.values(): @@ -918,16 +1079,14 @@ def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): class BundleProcessor(object): """ A class for processing bundles of elements. """ - - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - state_handler, # type: sdk_worker.CachingStateHandler - data_channel_factory, # type: data_plane.DataChannelFactory - data_sampler=None, # type: Optional[data_sampler.DataSampler] - ): - # type: (...) 
-> None - + def __init__( + self, + runner_capabilities: FrozenSet[str], + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + state_handler: sdk_worker.CachingStateHandler, + data_channel_factory: data_plane.DataChannelFactory, + data_sampler: Optional[data_sampler.DataSampler] = None, + ) -> None: """Initialize a bundle processor. Args: @@ -943,7 +1102,7 @@ def __init__(self, self.state_handler = state_handler self.data_channel_factory = data_channel_factory self.data_sampler = data_sampler - self.current_instruction_id = None # type: Optional[str] + self.current_instruction_id: Optional[str] = None # Represents whether the SDK is consuming received data. self.consuming_received_data = False @@ -962,7 +1121,7 @@ def __init__(self, # {(transform_id, timer_family_id): TimerInfo} # The mapping is empty when there is no timer_family_specs in the # ProcessBundleDescriptor. - self.timers_info = {} # type: Dict[Tuple[str, str], TimerInfo] + self.timers_info: Dict[Tuple[str, str], TimerInfo] = {} # TODO(robertwb): Figure out the correct prefix to use for output counters # from StateSampler. @@ -977,10 +1136,8 @@ def __init__(self, self.splitting_lock = threading.Lock() def create_execution_tree( - self, - descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor - ): - # type: (...) 
-> collections.OrderedDict[str, operations.DoOperation] + self, descriptor: beam_fn_api_pb2.ProcessBundleDescriptor + ) -> collections.OrderedDict[str, operations.DoOperation]: transform_factory = BeamTransformFactory( self.runner_capabilities, descriptor, @@ -999,16 +1156,14 @@ def is_side_input(transform_proto, tag): transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload).side_inputs - pcoll_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[str]] + pcoll_consumers: DefaultDict[str, List[str]] = collections.defaultdict(list) for transform_id, transform_proto in descriptor.transforms.items(): for tag, pcoll_id in transform_proto.inputs.items(): if not is_side_input(transform_proto, tag): pcoll_consumers[pcoll_id].append(transform_id) @memoize - def get_operation(transform_id): - # type: (str) -> operations.Operation + def get_operation(transform_id: str) -> operations.Operation: transform_consumers = { tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] for tag, @@ -1025,8 +1180,7 @@ def get_operation(transform_id): # Operations must be started (hence returned) in order. @memoize - def topological_height(transform_id): - # type: (str) -> int + def topological_height(transform_id: str) -> int: return 1 + max([0] + [ topological_height(consumer) for pcoll in descriptor.transforms[transform_id].outputs.values() @@ -1039,18 +1193,18 @@ def topological_height(transform_id): get_operation(transform_id))) for transform_id in sorted( descriptor.transforms, key=topological_height, reverse=True)]) - def reset(self): - # type: () -> None + def reset(self) -> None: self.counter_factory.reset() self.state_sampler.reset() # Side input caches. 
for op in self.ops.values(): op.reset() - def process_bundle(self, instruction_id): - # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + def process_bundle( + self, instruction_id: str + ) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]: - expected_input_ops = [] # type: List[DataInputOperation] + expected_input_ops: List[DataInputOperation] = [] for op in self.ops.values(): if isinstance(op, DataOutputOperation): @@ -1076,9 +1230,10 @@ def process_bundle(self, instruction_id): # both data input and timer input. The data input is identied by # transform_id. The data input is identified by # (transform_id, timer_family_id). - data_channels = collections.defaultdict( - list - ) # type: DefaultDict[data_plane.DataChannel, List[Union[str, Tuple[str, str]]]] + data_channels: DefaultDict[data_plane.DataChannel, + List[Union[str, Tuple[ + str, + str]]]] = collections.defaultdict(list) # Add expected data inputs for each data channel. input_op_by_transform_id = {} @@ -1144,18 +1299,17 @@ def process_bundle(self, instruction_id): self.current_instruction_id = None self.state_sampler.stop_if_still_running() - def finalize_bundle(self): - # type: () -> beam_fn_api_pb2.FinalizeBundleResponse + def finalize_bundle(self) -> beam_fn_api_pb2.FinalizeBundleResponse: for op in self.ops.values(): op.finalize_bundle() return beam_fn_api_pb2.FinalizeBundleResponse() - def requires_finalization(self): - # type: () -> bool + def requires_finalization(self) -> bool: return any(op.needs_finalization() for op in self.ops.values()) - def try_split(self, bundle_split_request): - # type: (beam_fn_api_pb2.ProcessBundleSplitRequest) -> beam_fn_api_pb2.ProcessBundleSplitResponse + def try_split( + self, bundle_split_request: beam_fn_api_pb2.ProcessBundleSplitRequest + ) -> beam_fn_api_pb2.ProcessBundleSplitResponse: split_response = beam_fn_api_pb2.ProcessBundleSplitResponse() with self.splitting_lock: if bundle_split_request.instruction_id != 
self.current_instruction_id: @@ -1193,20 +1347,18 @@ def try_split(self, bundle_split_request): return split_response - def delayed_bundle_application(self, - op, # type: operations.DoOperation - deferred_remainder # type: SplitResultResidual - ): - # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication + def delayed_bundle_application( + self, op: operations.DoOperation, deferred_remainder: SplitResultResidual + ) -> beam_fn_api_pb2.DelayedBundleApplication: assert op.input_info is not None # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. (element_and_restriction, current_watermark, deferred_timestamp) = ( deferred_remainder) if deferred_timestamp: assert isinstance(deferred_timestamp, timestamp.Duration) - proto_deferred_watermark = proto_utils.from_micros( - duration_pb2.Duration, - deferred_timestamp.micros) # type: Optional[duration_pb2.Duration] + proto_deferred_watermark: Optional[ + duration_pb2.Duration] = proto_utils.from_micros( + duration_pb2.Duration, deferred_timestamp.micros) else: proto_deferred_watermark = None return beam_fn_api_pb2.DelayedBundleApplication( @@ -1214,29 +1366,26 @@ def delayed_bundle_application(self, application=self.construct_bundle_application( op.input_info, current_watermark, element_and_restriction)) - def bundle_application(self, - op, # type: operations.DoOperation - primary # type: SplitResultPrimary - ): - # type: (...) -> beam_fn_api_pb2.BundleApplication + def bundle_application( + self, op: operations.DoOperation, + primary: SplitResultPrimary) -> beam_fn_api_pb2.BundleApplication: assert op.input_info is not None return self.construct_bundle_application( op.input_info, None, primary.primary_value) - def construct_bundle_application(self, - op_input_info, # type: operations.OpInputInfo - output_watermark, # type: Optional[timestamp.Timestamp] - element - ): - # type: (...) 
-> beam_fn_api_pb2.BundleApplication + def construct_bundle_application( + self, + op_input_info: operations.OpInputInfo, + output_watermark: Optional[timestamp.Timestamp], + element) -> beam_fn_api_pb2.BundleApplication: transform_id, main_input_tag, main_input_coder, outputs = op_input_info if output_watermark: proto_output_watermark = proto_utils.from_micros( timestamp_pb2.Timestamp, output_watermark.micros) - output_watermarks = { + output_watermarks: Optional[Dict[str, timestamp_pb2.Timestamp]] = { output: proto_output_watermark for output in outputs - } # type: Optional[Dict[str, timestamp_pb2.Timestamp]] + } else: output_watermarks = None return beam_fn_api_pb2.BundleApplication( @@ -1245,9 +1394,7 @@ def construct_bundle_application(self, output_watermarks=output_watermarks, element=main_input_coder.get_impl().encode_nested(element)) - def monitoring_infos(self): - # type: () -> List[metrics_pb2.MonitoringInfo] - + def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: """Returns the list of MonitoringInfos collected processing this bundle.""" # Construct a new dict first to remove duplicates. 
all_monitoring_infos_dict = {} @@ -1259,8 +1406,7 @@ def monitoring_infos(self): return list(all_monitoring_infos_dict.values()) - def shutdown(self): - # type: () -> None + def shutdown(self) -> None: for op in self.ops.values(): op.teardown() @@ -1281,15 +1427,16 @@ class ExecutionContext: class BeamTransformFactory(object): """Factory for turning transform_protos into executable operations.""" - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - data_channel_factory, # type: data_plane.DataChannelFactory - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - state_handler, # type: sdk_worker.CachingStateHandler - data_sampler, # type: Optional[data_sampler.DataSampler] - ): + def __init__( + self, + runner_capabilities: FrozenSet[str], + descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + data_channel_factory: data_plane.DataChannelFactory, + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + state_handler: sdk_worker.CachingStateHandler, + data_sampler: Optional[data_sampler.DataSampler], + ): self.runner_capabilities = runner_capabilities self.descriptor = descriptor self.data_channel_factory = data_channel_factory @@ -1306,27 +1453,41 @@ def __init__(self, element_coder_impl)) self.data_sampler = data_sampler - _known_urns = { - } # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]] + _known_urns: Dict[str, + Tuple[ConstructorFn, + Union[Type[message.Message], Type[bytes], + None]]] = {} @classmethod def register_urn( - cls, - urn, # type: str - parameter_type # type: Optional[Type[T]] - ): - # type: (...) 
-> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]] + cls, urn: str, parameter_type: Optional[Type[T]] + ) -> Callable[[ + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation] + ], + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation]]: def wrapper(func): cls._known_urns[urn] = func, parameter_type return func return wrapper - def create_operation(self, - transform_id, # type: str - consumers # type: Dict[str, List[operations.Operation]] - ): - # type: (...) -> operations.Operation + def create_operation( + self, transform_id: str, + consumers: Dict[str, List[operations.Operation]]) -> operations.Operation: transform_proto = self.descriptor.transforms[transform_id] if not transform_proto.unique_name: _LOGGER.debug("No unique name set for transform %s" % transform_id) @@ -1336,8 +1497,7 @@ def create_operation(self, transform_proto.spec.payload, parameter_type) return creator(self, transform_id, transform_proto, payload, consumers) - def extract_timers_info(self): - # type: () -> Dict[Tuple[str, str], TimerInfo] + def extract_timers_info(self) -> Dict[Tuple[str, str], TimerInfo]: timers_info = {} for transform_id, transform_proto in self.descriptor.transforms.items(): if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: @@ -1352,8 +1512,7 @@ def extract_timers_info(self): timer_coder_impl=timer_coder_impl) return timers_info - def get_coder(self, coder_id): - # type: (str) -> coders.Coder + def get_coder(self, coder_id: str) -> coders.Coder: if coder_id not in self.descriptor.coders: raise KeyError("No such coder: %s" % coder_id) 
coder_proto = self.descriptor.coders[coder_id] @@ -1364,8 +1523,7 @@ def get_coder(self, coder_id): return operation_specs.get_coder_from_spec( json.loads(coder_proto.spec.payload.decode('utf-8'))) - def get_windowed_coder(self, pcoll_id): - # type: (str) -> WindowedValueCoder + def get_windowed_coder(self, pcoll_id: str) -> WindowedValueCoder: coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) # TODO(robertwb): Remove this condition once all runners are consistent. if not isinstance(coder, WindowedValueCoder): @@ -1376,32 +1534,34 @@ def get_windowed_coder(self, pcoll_id): else: return coder - def get_output_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.Coder] + def get_output_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.Coder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.outputs.items() } - def get_only_output_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_output_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(self.get_output_coders(transform_proto).values()) - def get_input_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.WindowedValueCoder] + def get_input_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.WindowedValueCoder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.inputs.items() } - def get_only_input_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_input_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(list(self.get_input_coders(transform_proto).values())) - def get_input_windowing(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Windowing + def 
get_input_windowing( + self, transform_proto: beam_runner_api_pb2.PTransform) -> Windowing: pcoll_id = only_element(transform_proto.inputs.values()) windowing_strategy_id = self.descriptor.pcollections[ pcoll_id].windowing_strategy_id @@ -1410,12 +1570,10 @@ def get_input_windowing(self, transform_proto): # TODO(robertwb): Update all operations to take these in the constructor. @staticmethod def augment_oldstyle_op( - op, # type: OperationT - step_name, # type: str - consumers, # type: Mapping[str, Iterable[operations.Operation]] - tag_list=None # type: Optional[List[str]] - ): - # type: (...) -> OperationT + op: OperationT, + step_name: str, + consumers: Mapping[str, Iterable[operations.Operation]], + tag_list: Optional[List[str]] = None) -> OperationT: op.step_name = step_name for tag, op_consumers in consumers.items(): for consumer in op_consumers: @@ -1426,13 +1584,11 @@ def augment_oldstyle_op( @BeamTransformFactory.register_urn( DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_source_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> DataInputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataInputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataInputOperation( @@ -1449,13 +1605,11 @@ def create_source_runner( @BeamTransformFactory.register_urn( DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_sink_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> DataOutputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataOutputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataOutputOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -1470,13 +1624,12 @@ def create_sink_runner( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None) def create_source_java( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: # The Dataflow runner harness strips the base64 encoding. 
source = pickler.loads(base64.b64encode(parameter)) spec = operation_specs.WorkerRead( @@ -1495,13 +1648,12 @@ def create_source_java( @BeamTransformFactory.register_urn( common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload) def create_deprecated_read( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: source = iobase.BoundedSource.from_runner_api( parameter.source, factory.context) spec = operation_specs.WorkerRead( @@ -1520,13 +1672,12 @@ def create_deprecated_read( @BeamTransformFactory.register_urn( python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) def create_read_from_impulse_python( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> operations.ImpulseReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, List[operations.Operation]] +) -> operations.ImpulseReadOperation: return operations.ImpulseReadOperation( common.NameContext(transform_proto.unique_name, transform_id), factory.counter_factory, @@ -1538,12 +1689,11 @@ def create_read_from_impulse_python( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None) def create_dofn_javasdk( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, serialized_fn, - consumers # type: Dict[str, List[operations.Operation]] -): + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -1627,12 +1777,11 @@ def process(self, element_restriction, *args, **kwargs): common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn, beam_runner_api_pb2.ParDoPayload) def create_process_sized_elements_and_restrictions( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, @@ -1674,13 +1823,11 @@ def _create_sdf_operation( @BeamTransformFactory.register_urn( common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) def create_par_do( - factory, # type: BeamTransformFactory - 
transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.DoOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]) -> operations.DoOperation: return _create_pardo_operation( factory, transform_id, @@ -1692,14 +1839,13 @@ def create_par_do( def _create_pardo_operation( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, consumers, serialized_fn, - pardo_proto=None, # type: Optional[beam_runner_api_pb2.ParDoPayload] - operation_cls=operations.DoOperation -): + pardo_proto: Optional[beam_runner_api_pb2.ParDoPayload] = None, + operation_cls=operations.DoOperation): if pardo_proto and pardo_proto.side_inputs: input_tags_to_coders = factory.get_input_coders(transform_proto) @@ -1731,9 +1877,8 @@ def _create_pardo_operation( if not dofn_data[-1]: # Windowing not set. 
if pardo_proto: - other_input_tags = set.union( - set(pardo_proto.side_inputs), - set(pardo_proto.timer_family_specs)) # type: Container[str] + other_input_tags: Container[str] = set.union( + set(pardo_proto.side_inputs), set(pardo_proto.timer_family_specs)) else: other_input_tags = () pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items() @@ -1757,12 +1902,12 @@ def _create_pardo_operation( main_input_coder = found_input_coder if pardo_proto.timer_family_specs or pardo_proto.state_specs: - user_state_context = FnApiUserStateContext( - factory.state_handler, - transform_id, - main_input_coder.key_coder(), - main_input_coder.window_coder - ) # type: Optional[FnApiUserStateContext] + user_state_context: Optional[ + FnApiUserStateContext] = FnApiUserStateContext( + factory.state_handler, + transform_id, + main_input_coder.key_coder(), + main_input_coder.window_coder) else: user_state_context = None else: @@ -1796,12 +1941,13 @@ def _create_pardo_operation( return result -def _create_simple_pardo_operation(factory, # type: BeamTransformFactory - transform_id, - transform_proto, - consumers, - dofn, # type: beam.DoFn - ): +def _create_simple_pardo_operation( + factory: BeamTransformFactory, + transform_id, + transform_proto, + consumers, + dofn: beam.DoFn, +): serialized_fn = pickler.dumps((dofn, (), {}, [], None)) return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -1811,12 +1957,11 @@ def _create_simple_pardo_operation(factory, # type: BeamTransformFactory common_urns.primitives.ASSIGN_WINDOWS.urn, beam_runner_api_pb2.WindowingStrategy) def create_assign_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.WindowingStrategy - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: 
beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.WindowingStrategy, + consumers: Dict[str, List[operations.Operation]]): class WindowIntoDoFn(beam.DoFn): def __init__(self, windowing): self.windowing = windowing @@ -1843,13 +1988,12 @@ def process( @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) def create_identity_dofn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -1865,13 +2009,12 @@ def create_identity_dofn( common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_precombine( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> operations.PGBKCVOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.PGBKCVOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -1892,12 +2035,11 @@ def create_combine_per_key_precombine( common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combbine_per_key_merge_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'merge') @@ -1906,12 +2048,11 @@ def create_combbine_per_key_merge_accumulators( common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_extract_outputs( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'extract') @@ -1920,12 +2061,11 @@ def 
create_combine_per_key_extract_outputs( common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_convert_to_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'convert') @@ -1934,19 +2074,18 @@ def create_combine_per_key_convert_to_accumulators( common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, beam_runner_api_pb2.CombinePayload) def create_combine_grouped_values( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'all') def _create_combine_phase_operation( - factory, transform_id, transform_proto, payload, consumers, phase): - # type: (...) 
-> operations.CombineOperation + factory, transform_id, transform_proto, payload, consumers, + phase) -> operations.CombineOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -1965,13 +2104,12 @@ def _create_combine_phase_operation( @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None) def create_flatten( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, payload, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -1986,12 +2124,11 @@ def create_flatten( @BeamTransformFactory.register_urn( common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_map_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN window_mapping_fn = pickler.loads(mapping_fn_spec.payload) @@ -2007,12 +2144,11 @@ def process(self, element): @BeamTransformFactory.register_urn( common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_merge_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - 
transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN window_fn = pickler.loads(mapping_fn_spec.payload) @@ -2020,24 +2156,25 @@ class MergeWindows(beam.DoFn): def process(self, element): nonce, windows = element - original_windows = set(windows) # type: Set[window.BoundedWindow] - merged_windows = collections.defaultdict( - set - ) # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]] # noqa: F821 + original_windows: Set[window.BoundedWindow] = set(windows) + merged_windows: MutableMapping[ + window.BoundedWindow, + Set[window.BoundedWindow]] = collections.defaultdict( + set) # noqa: F821 class RecordingMergeContext(window.WindowFn.MergeContext): def merge( self, - to_be_merged, # type: Iterable[window.BoundedWindow] - merge_result, # type: window.BoundedWindow - ): + to_be_merged: Iterable[window.BoundedWindow], + merge_result: window.BoundedWindow, + ): originals = merged_windows[merge_result] - for window in to_be_merged: - if window in original_windows: - originals.add(window) - original_windows.remove(window) + for w in to_be_merged: + if w in original_windows: + originals.add(w) + original_windows.remove(w) else: - originals.update(merged_windows.pop(window)) + originals.update(merged_windows.pop(w)) window_fn.merge(RecordingMergeContext(windows)) yield nonce, (original_windows, merged_windows.items()) @@ -2048,12 +2185,11 @@ def merge( @BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None) def create_to_string_fn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: 
beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): class ToString(beam.DoFn): def process(self, element): key, value = element diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py index dafb4dbd4bf0..0eb4dd9485fd 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py @@ -18,24 +18,31 @@ """Unit tests for bundle processing.""" # pytype: skip-file +import random import unittest import apache_beam as beam +from apache_beam.coders import StrUtf8Coder from apache_beam.coders.coders import FastPrimitivesCoder from apache_beam.portability import common_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners import common +from apache_beam.runners.portability.fn_api_runner.worker_handlers import StateServicer from apache_beam.runners.worker import bundle_processor from apache_beam.runners.worker import operations from apache_beam.runners.worker.bundle_processor import BeamTransformFactory from apache_beam.runners.worker.bundle_processor import BundleProcessor from apache_beam.runners.worker.bundle_processor import DataInputOperation from apache_beam.runners.worker.bundle_processor import FnApiUserStateContext +from apache_beam.runners.worker.bundle_processor import SynchronousOrderedListRuntimeState from apache_beam.runners.worker.bundle_processor import TimerInfo from apache_beam.runners.worker.data_plane import SizeBasedBufferingClosableOutputStream from apache_beam.runners.worker.data_sampler import DataSampler +from apache_beam.runners.worker.sdk_worker 
import GlobalCachingStateHandler +from apache_beam.runners.worker.statecache import StateCache from apache_beam.transforms import userstate from apache_beam.transforms.window import GlobalWindow +from apache_beam.utils import timestamp from apache_beam.utils.windowed_value import WindowedValue @@ -422,5 +429,312 @@ def test_user_modified_sdks_need_to_be_installed_in_runtime_env(self): "beam:version:sdk_base:apache/beam_python3.5_sdk:2.1.0-custom")) +class OrderedListStateTest(unittest.TestCase): + class NoStateCache(StateCache): + def __init__(self): + super().__init__(max_weight=0) + + @staticmethod + def _create_state(window=b"my_window", key=b"my_key", coder=StrUtf8Coder()): + state_handler = GlobalCachingStateHandler( + OrderedListStateTest.NoStateCache(), StateServicer()) + state_key = beam_fn_api_pb2.StateKey( + ordered_list_user_state=beam_fn_api_pb2.StateKey.OrderedListUserState( + window=window, key=key)) + return SynchronousOrderedListRuntimeState(state_handler, state_key, coder) + + def setUp(self): + self.state = self._create_state() + + def test_read_range(self): + T0 = timestamp.Timestamp.of(0) + T1 = timestamp.Timestamp.of(1) + T2 = timestamp.Timestamp.of(2) + T3 = timestamp.Timestamp.of(3) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + T9 = timestamp.Timestamp.of(9) + A1, B1, A4 = [(T1, "a1"), (T1, "b1"), (T4, "a4")] + self.assertEqual([], list(self.state.read_range(T0, T5))) + + self.state.add(A1) + self.assertEqual([A1], list(self.state.read_range(T0, T5))) + + self.state.add(B1) + self.assertEqual([A1, B1], list(self.state.read_range(T0, T5))) + + self.state.add(A4) + self.assertEqual([A1, B1, A4], list(self.state.read_range(T0, T5))) + + self.assertEqual([], list(self.state.read_range(T0, T1))) + self.assertEqual([], list(self.state.read_range(T5, T9))) + self.assertEqual([A1, B1], list(self.state.read_range(T1, T2))) + self.assertEqual([], list(self.state.read_range(T2, T3))) + self.assertEqual([], 
list(self.state.read_range(T2, T4))) + self.assertEqual([A4], list(self.state.read_range(T4, T5))) + + def test_read(self): + T1 = timestamp.Timestamp.of(1) + T4 = timestamp.Timestamp.of(4) + A1, B1, A4 = [(T1, "a1"), (T1, "b1"), (T4, "a4")] + self.assertEqual([], list(self.state.read())) + + self.state.add(A1) + self.assertEqual([A1], list(self.state.read())) + + self.state.add(A1) + self.assertEqual([A1, A1], list(self.state.read())) + + self.state.add(B1) + self.assertEqual([A1, A1, B1], list(self.state.read())) + + self.state.add(A4) + self.assertEqual([A1, A1, B1, A4], list(self.state.read())) + + def test_clear_range(self): + T0 = timestamp.Timestamp.of(0) + T1 = timestamp.Timestamp.of(1) + T2 = timestamp.Timestamp.of(2) + T3 = timestamp.Timestamp.of(3) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + A1, B1, A4, A5 = [(T1, "a1"), (T1, "b1"), (T4, "a4"), (T5, "a5")] + self.state.clear_range(T0, T1) + self.assertEqual([], list(self.state.read())) + + self.state.add(A1) + self.state.add(B1) + self.state.add(A4) + self.state.add(A5) + self.assertEqual([A1, B1, A4, A5], list(self.state.read())) + + self.state.clear_range(T0, T1) + self.assertEqual([A1, B1, A4, A5], list(self.state.read())) + + self.state.clear_range(T1, T2) + self.assertEqual([A4, A5], list(self.state.read())) + + # no side effect on clearing the same range twice + self.state.clear_range(T1, T2) + self.assertEqual([A4, A5], list(self.state.read())) + + self.state.clear_range(T3, T4) + self.assertEqual([A4, A5], list(self.state.read())) + + self.state.clear_range(T3, T5) + self.assertEqual([A5], list(self.state.read())) + + def test_add_and_clear_range_after_commit(self): + T1 = timestamp.Timestamp.of(1) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + T6 = timestamp.Timestamp.of(6) + A1, B1, C1, A4, A5, A6 = [(T1, "a1"), (T1, "b1"), (T1, "c1"), + (T4, "a4"), (T5, "a5"), (T6, "a6")] + self.state.add(A1) + self.state.add(B1) + self.state.add(A4) + 
self.state.add(A5) + self.state.clear_range(T4, T5) + self.assertEqual([A1, B1, A5], list(self.state.read())) + + self.state.commit() + self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 0) + self.assertEqual([A1, B1, A5], list(self.state.read())) + + self.state.add(C1) + self.state.add(A6) + self.assertEqual([A1, B1, C1, A5, A6], list(self.state.read())) + + self.state.clear_range(T5, T6) + self.assertEqual([A1, B1, C1, A6], list(self.state.read())) + + self.state.commit() + self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 0) + self.assertEqual([A1, B1, C1, A6], list(self.state.read())) + + def test_clear(self): + T1 = timestamp.Timestamp.of(1) + T4 = timestamp.Timestamp.of(4) + T5 = timestamp.Timestamp.of(5) + T9 = timestamp.Timestamp.of(9) + A1, B1, C1, A4, A5, B5 = [(T1, "a1"), (T1, "b1"), (T1, "c1"), + (T4, "a4"), (T5, "a5"), (T5, "b5")] + self.state.add(A1) + self.state.add(B1) + self.state.add(A4) + self.state.add(A5) + self.state.clear_range(T4, T5) + self.assertEqual([A1, B1, A5], list(self.state.read())) + self.state.commit() + + self.state.add(C1) + self.state.clear_range(T5, T9) + self.assertEqual([A1, B1, C1], list(self.state.read())) + self.state.clear() + self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 1) + + self.state.add(B5) + self.assertEqual([B5], list(self.state.read())) + self.state.commit() + + self.assertEqual(len(self.state._pending_adds), 0) + self.assertEqual(len(self.state._pending_removes), 0) + + self.assertEqual([B5], list(self.state.read())) + + def test_multiple_iterators(self): + T1 = timestamp.Timestamp.of(1) + T3 = timestamp.Timestamp.of(3) + T9 = timestamp.Timestamp.of(9) + A1, B1, A3, B3 = [(T1, "a1"), (T1, "b1"), (T3, "a3"), (T3, "b3")] + self.state.add(A1) + self.state.add(A3) + self.state.commit() + + iter_before_b1 = iter(self.state.read()) + self.assertEqual(A1, 
next(iter_before_b1)) + + self.state.add(B1) + self.assertEqual(A3, next(iter_before_b1)) + self.assertRaises(StopIteration, lambda: next(iter_before_b1)) + + self.state.add(B3) + iter_before_clear_range = iter(self.state.read()) + self.assertEqual(A1, next(iter_before_clear_range)) + self.state.clear_range(T3, T9) + self.assertEqual(B1, next(iter_before_clear_range)) + self.assertEqual(A3, next(iter_before_clear_range)) + self.assertEqual(B3, next(iter_before_clear_range)) + self.assertRaises(StopIteration, lambda: next(iter_before_clear_range)) + self.assertEqual([A1, B1], list(self.state.read())) + + iter_before_clear = iter(self.state.read()) + self.assertEqual(A1, next(iter_before_clear)) + self.state.clear() + self.assertEqual(B1, next(iter_before_clear)) + self.assertRaises(StopIteration, lambda: next(iter_before_clear)) + + self.assertEqual([], list(self.state.read())) + + def fuzz_test_helper(self, seed=0, lower=0, upper=20): + class NaiveState: + def __init__(self): + self._data = [[] for i in range((upper - lower + 1))] + self._logs = [] + + def add(self, elem): + k, v = elem + k = k.micros + self._data[k - lower].append(v) + self._logs.append("add(%d, %s)" % (k, v)) + + def clear_range(self, lo, hi): + lo = lo.micros + hi = hi.micros + for i in range(lo, hi): + self._data[i - lower] = [] + self._logs.append("clear_range(%d, %d)" % (lo, hi)) + + def clear(self): + for i in range(len(self._data)): + self._data[i] = [] + self._logs.append("clear()") + + def read(self): + self._logs.append("read()") + for i in range(len(self._data)): + for v in self._data[i]: + yield (timestamp.Timestamp(micros=(i + lower)), v) + + random.seed(seed) + + state = self._create_state() + bench_state = NaiveState() + + steps = random.randint(20, 50) + for i in range(steps): + op = random.randint(1, 100) + if 1 <= op < 70: + num = random.randint(lower, upper) + state.add((timestamp.Timestamp(micros=num), "a%d" % num)) + bench_state.add((timestamp.Timestamp(micros=num), "a%d" % 
num)) + elif 70 <= op < 95: + num1 = random.randint(lower, upper) + num2 = random.randint(lower, upper) + min_time = timestamp.Timestamp(micros=min(num1, num2)) + max_time = timestamp.Timestamp(micros=max(num1, num2)) + state.clear_range(min_time, max_time) + bench_state.clear_range(min_time, max_time) + elif op >= 95: + state.clear() + bench_state.clear() + + op = random.randint(1, 10) + if 1 <= op <= 9: + pass + else: + state.commit() + + a = list(bench_state.read()) + b = list(state.read()) + self.assertEqual( + a, + b, + "Mismatch occurred on seed=%d, step=%d, logs=%s" % + (seed, i, ';'.join(bench_state._logs))) + + def test_fuzz(self): + for _ in range(1000): + seed = random.randint(0, 0xffffffffffffffff) + try: + self.fuzz_test_helper(seed=seed) + except Exception as e: + raise RuntimeError("Exception occurred on seed=%d: %s" % (seed, e)) + + def test_min_max(self): + T_MIN = timestamp.Timestamp(micros=(-(1 << 63))) + T_MAX_MINUS_ONE = timestamp.Timestamp(micros=((1 << 63) - 2)) + T_MAX = timestamp.Timestamp(micros=((1 << 63) - 1)) + T0 = timestamp.Timestamp(micros=0) + INT64_MIN, INT64_MAX_MINUS_ONE, INT64_MAX = [(T_MIN, "min"), + (T_MAX_MINUS_ONE, "max"), + (T_MAX, "err")] + self.state.add(INT64_MIN) + self.state.add(INT64_MAX_MINUS_ONE) + self.assertRaises(ValueError, lambda: self.state.add(INT64_MAX)) + + self.assertEqual([INT64_MIN, INT64_MAX_MINUS_ONE], list(self.state.read())) + self.assertEqual([INT64_MIN], list(self.state.read_range(T_MIN, T0))) + self.assertEqual([INT64_MAX_MINUS_ONE], + list(self.state.read_range(T0, T_MAX))) + + def test_continuation_token(self): + T1 = timestamp.Timestamp.of(1) + T2 = timestamp.Timestamp.of(2) + T7 = timestamp.Timestamp.of(7) + T8 = timestamp.Timestamp.of(8) + A1, A2, A7, B7, A8 = [(T1, "a1"), (T2, "a2"), (T7, "a7"), + (T7, "b7"), (T8, "a8")] + self.state._state_handler._underlying._use_continuation_tokens = True + self.assertEqual([], list(self.state.read_range(T1, T8))) + + self.state.add(A1) + 
self.state.add(A2) + self.state.add(A7) + self.state.add(B7) + self.state.add(A8) + + self.assertEqual([A2, A7, B7], list(self.state.read_range(T2, T8))) + + self.state.commit() + self.assertEqual([A2, A7, B7], list(self.state.read_range(T2, T8))) + + self.assertEqual([A1, A2, A7, B7, A8], list(self.state.read())) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index 88cc3c9791d5..979c7cdb53be 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -125,7 +125,7 @@ def emit(self, record: logging.LogRecord) -> None: log_entry.message = ( "Failed to format '%s' with args '%s' during logging." % (str(record.msg), record.args)) - log_entry.thread = record.threadName + log_entry.thread = record.threadName # type: ignore[assignment] log_entry.log_location = '%s:%s' % ( record.pathname or record.module, record.lineno or record.funcName) (fraction, seconds) = math.modf(record.created) diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py index 51ca4cf0545b..5496bccd014e 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters.py +++ b/sdks/python/apache_beam/runners/worker/opcounters.py @@ -259,7 +259,7 @@ def do_sample(self, windowed_value): self.type_check(windowed_value.value) size, observables = ( - self.coder_impl.get_estimated_size_and_observables(windowed_value)) + self.coder_impl.get_estimated_size_and_observables(windowed_value)) # type: ignore[union-attr] if not observables: self.current_size = size else: diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 2a1423fccba9..b091220a06b5 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -1335,8 +1335,8 
@@ def _get_cache_token(self, state_key): return self._context.user_state_cache_token else: return self._context.bundle_cache_token - elif state_key.WhichOneof('type').endswith('_side_input'): - side_input = getattr(state_key, state_key.WhichOneof('type')) + elif state_key.WhichOneof('type').endswith('_side_input'): # type: ignore[union-attr] + side_input = getattr(state_key, state_key.WhichOneof('type')) # type: ignore[arg-type] return self._context.side_input_cache_tokens.get( (side_input.transform_id, side_input.side_input_id), self._context.bundle_cache_token) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index cd49c69a80aa..3389f0c7afb1 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -27,7 +27,7 @@ import sys import traceback -from google.protobuf import text_format # type: ignore # not in typeshed +from google.protobuf import text_format from apache_beam.internal import pickler from apache_beam.io import filesystems diff --git a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py index 87c631762287..ba3fae6819bd 100644 --- a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py +++ b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py @@ -29,7 +29,7 @@ from apache_beam.testing.load_tests.load_test_metrics_utils import MeasureTime from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader -from google.protobuf import text_format # type: ignore # typeshed out of date +from google.protobuf import text_format from trainer import taxi diff --git a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py 
index 6d84c995bfca..88ed53e11fc4 100644 --- a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py +++ b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py @@ -18,7 +18,7 @@ from tensorflow_transform import coders as tft_coders from tensorflow_transform.tf_metadata import schema_utils -from google.protobuf import text_format # type: ignore # typeshed out of date +from google.protobuf import text_format from tensorflow.python.lib.io import file_io from tensorflow_metadata.proto.v0 import schema_pb2 diff --git a/sdks/python/apache_beam/testing/benchmarks/wordcount/__init__.py b/sdks/python/apache_beam/testing/benchmarks/wordcount/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/wordcount/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py b/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py new file mode 100644 index 000000000000..513ede47e80a --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/wordcount/wordcount.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pytype: skip-file + +import logging + +from apache_beam.examples import wordcount +from apache_beam.testing.load_tests.dataflow_cost_benchmark import DataflowCostBenchmark + + +class WordcountCostBenchmark(DataflowCostBenchmark): + def __init__(self): + super().__init__() + + def test(self): + extra_opts = {} + extra_opts['output'] = self.pipeline.get_option('output_file') + self.result = wordcount.run( + self.pipeline.get_full_options_as_args(**extra_opts), + save_main_session=False) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + WordcountCostBenchmark().run() diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py new file mode 100644 index 000000000000..96a1cd31e298 --- /dev/null +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pytype: skip-file + +import logging +import time +from typing import Any +from typing import Optional + +import apache_beam.testing.load_tests.dataflow_cost_consts as costs +from apache_beam.metrics.execution import MetricResult +from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult +from apache_beam.runners.runner import PipelineState +from apache_beam.testing.load_tests.load_test import LoadTest + + +class DataflowCostBenchmark(LoadTest): + """Base class for Dataflow performance tests which export metrics to + external databases: BigQuery or/and InfluxDB. Calculates the expected cost + for running the job on Dataflow in region us-central1. + + Refer to :class:`~apache_beam.testing.load_tests.LoadTestOptions` for more + information on the required pipeline options. + + If using InfluxDB with Basic HTTP authentication enabled, provide the + following environment options: `INFLUXDB_USER` and `INFLUXDB_USER_PASSWORD`. + + If the hardware configuration for the job includes use of a GPU, please + specify the version in use with the Accelerator enumeration. This is used to + calculate the cost of the job later, as different accelerators have different + billing rates per hour of use. 
+ """ + def __init__( + self, + metrics_namespace: Optional[str] = None, + is_streaming: bool = False, + gpu: Optional[costs.Accelerator] = None): + self.is_streaming = is_streaming + self.gpu = gpu + super().__init__(metrics_namespace=metrics_namespace) + + def run(self): + try: + self.test() + if not hasattr(self, 'result'): + self.result = self.pipeline.run() + # Defaults to waiting forever unless timeout has been set + state = self.result.wait_until_finish(duration=self.timeout_ms) + assert state != PipelineState.FAILED + logging.info( + 'Pipeline complete, sleeping for 4 minutes to allow resource ' + 'metrics to populate.') + time.sleep(240) + self.extra_metrics = self._retrieve_cost_metrics(self.result) + self._metrics_monitor.publish_metrics(self.result, self.extra_metrics) + finally: + self.cleanup() + + def _retrieve_cost_metrics(self, + result: DataflowPipelineResult) -> dict[str, Any]: + job_id = result.job_id() + metrics = result.metrics().all_metrics(job_id) + metrics_dict = self._process_metrics_list(metrics) + logging.info(metrics_dict) + cost = 0.0 + if (self.is_streaming): + cost += metrics_dict.get( + "TotalVcpuTime", 0.0) / 3600 * costs.VCPU_PER_HR_STREAMING + cost += ( + metrics_dict.get("TotalMemoryUsage", 0.0) / + 1000) / 3600 * costs.MEM_PER_GB_HR_STREAMING + cost += metrics_dict.get( + "TotalStreamingDataProcessed", 0.0) * costs.SHUFFLE_PER_GB_STREAMING + else: + cost += metrics_dict.get( + "TotalVcpuTime", 0.0) / 3600 * costs.VCPU_PER_HR_BATCH + cost += ( + metrics_dict.get("TotalMemoryUsage", 0.0) / + 1000) / 3600 * costs.MEM_PER_GB_HR_BATCH + cost += metrics_dict.get( + "TotalStreamingDataProcessed", 0.0) * costs.SHUFFLE_PER_GB_BATCH + if (self.gpu): + rate = costs.ACCELERATOR_TO_COST[self.gpu] + cost += metrics_dict.get("TotalGpuTime", 0.0) / 3600 * rate + cost += metrics_dict.get("TotalPdUsage", 0.0) / 3600 * costs.PD_PER_GB_HR + cost += metrics_dict.get( + "TotalSsdUsage", 0.0) / 3600 * costs.PD_SSD_PER_GB_HR + 
metrics_dict["EstimatedCost"] = cost + return metrics_dict + + def _process_metrics_list(self, + metrics: list[MetricResult]) -> dict[str, Any]: + system_metrics = {} + for entry in metrics: + metric_key = entry.key + metric = metric_key.metric + if metric_key.step == '' and metric.namespace == 'dataflow/v1b3': + if entry.committed is None: + entry.committed = 0.0 + system_metrics[metric.name] = entry.committed + return system_metrics diff --git a/sdks/python/apache_beam/testing/load_tests/dataflow_cost_consts.py b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_consts.py new file mode 100644 index 000000000000..f291991b48bb --- /dev/null +++ b/sdks/python/apache_beam/testing/load_tests/dataflow_cost_consts.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These values are Dataflow costs for running jobs in us-central1. 
+# The cost values are found at https://cloud.google.com/dataflow/pricing + +from enum import Enum + +VCPU_PER_HR_BATCH = 0.056 +VCPU_PER_HR_STREAMING = 0.069 +MEM_PER_GB_HR_BATCH = 0.003557 +MEM_PER_GB_HR_STREAMING = 0.0035557 +PD_PER_GB_HR = 0.000054 +PD_SSD_PER_GB_HR = 0.000298 +SHUFFLE_PER_GB_BATCH = 0.011 +SHUFFLE_PER_GB_STREAMING = 0.018 + +# GPU Resource Pricing +P100_PER_GPU_PER_HOUR = 1.752 +V100_PER_GPU_PER_HOUR = 2.976 +T4_PER_GPU_PER_HOUR = 0.42 +P4_PER_GPU_PER_HOUR = 0.72 +L4_PER_GPU_PER_HOUR = 0.672 +A100_40GB_PER_GPU_PER_HOUR = 3.72 +A100_80GB_PER_GPU_PER_HOUR = 4.7137 + + +class Accelerator(Enum): + P100 = 1 + V100 = 2 + T4 = 3 + P4 = 4 + L4 = 5 + A100_40GB = 6 + A100_80GB = 7 + + +ACCELERATOR_TO_COST: dict[Accelerator, float] = { + Accelerator.P100: P100_PER_GPU_PER_HOUR, + Accelerator.V100: V100_PER_GPU_PER_HOUR, + Accelerator.T4: T4_PER_GPU_PER_HOUR, + Accelerator.P4: P4_PER_GPU_PER_HOUR, + Accelerator.L4: L4_PER_GPU_PER_HOUR, + Accelerator.A100_40GB: A100_40GB_PER_GPU_PER_HOUR, + Accelerator.A100_80GB: A100_80GB_PER_GPU_PER_HOUR, +} diff --git a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py index 745d961d2aac..3b5dfdf38cd9 100644 --- a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py +++ b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py @@ -111,7 +111,7 @@ class SequenceSideInputTestDoFn(beam.DoFn): def __init__(self, first_n: int): self._first_n = first_n - def process( # type: ignore[override] + def process( self, element: Any, side_input: Iterable[Tuple[bytes, bytes]]) -> None: i = 0 @@ -129,7 +129,7 @@ class MappingSideInputTestDoFn(beam.DoFn): def __init__(self, first_n: int): self._first_n = first_n - def process( # type: ignore[override] + def process( self, element: Any, dict_side_input: Dict[bytes, bytes]) -> None: i = 0 for key in dict_side_input: @@ -146,7 +146,7 @@ def __init__(self): # Avoid having to use save_main_session 
self.window = window - def process(self, element: int) -> Iterable[window.TimestampedValue]: # type: ignore[override] + def process(self, element: int) -> Iterable[window.TimestampedValue]: yield self.window.TimestampedValue(element, element) class GetSyntheticSDFOptions(beam.DoFn): @@ -156,7 +156,7 @@ def __init__( self.key_size = key_size self.value_size = value_size - def process(self, element: Any) -> Iterable[Dict[str, Union[int, str]]]: # type: ignore[override] + def process(self, element: Any) -> Iterable[Dict[str, Union[int, str]]]: yield { 'num_records': self.elements_per_record, 'key_size': self.key_size, diff --git a/sdks/python/apache_beam/testing/test_stream_service_test.py b/sdks/python/apache_beam/testing/test_stream_service_test.py index 5bfd0c104ba0..a04fa2303d08 100644 --- a/sdks/python/apache_beam/testing/test_stream_service_test.py +++ b/sdks/python/apache_beam/testing/test_stream_service_test.py @@ -30,9 +30,9 @@ # Nose automatically detects tests if they match a regex. Here, it mistakens # these protos as tests. 
For more info see the Nose docs at: # https://nose.readthedocs.io/en/latest/writing_tests.html -beam_runner_api_pb2.TestStreamPayload.__test__ = False # type: ignore[attr-defined] -beam_interactive_api_pb2.TestStreamFileHeader.__test__ = False # type: ignore[attr-defined] -beam_interactive_api_pb2.TestStreamFileRecord.__test__ = False # type: ignore[attr-defined] +beam_runner_api_pb2.TestStreamPayload.__test__ = False +beam_interactive_api_pb2.TestStreamFileHeader.__test__ = False +beam_interactive_api_pb2.TestStreamFileRecord.__test__ = False class EventsReader: diff --git a/sdks/python/apache_beam/transforms/__init__.py b/sdks/python/apache_beam/transforms/__init__.py index 4e66a290842c..b8b6839019e8 100644 --- a/sdks/python/apache_beam/transforms/__init__.py +++ b/sdks/python/apache_beam/transforms/__init__.py @@ -22,6 +22,7 @@ from apache_beam.transforms import combiners from apache_beam.transforms.core import * from apache_beam.transforms.external import * +from apache_beam.transforms.managed import * from apache_beam.transforms.ptransform import * from apache_beam.transforms.stats import * from apache_beam.transforms.timeutil import TimeDomain diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py index 3cb5f32c3114..56610e95297f 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py @@ -17,6 +17,7 @@ # pytype: skip-file +import math from typing import Set from typing import Tuple @@ -124,6 +125,38 @@ def run_combine(pipeline, input_elements=5, lift_combiners=True): assert_that(pcoll, equal_to([(expected_result, expected_result)])) +def run_combine_uncopyable_attr( + pipeline, input_elements=5, lift_combiners=True): + # Calculate the expected result, which is the sum of an arithmetic sequence. 
+ # By default, this is equal to: 0 + 1 + 2 + 3 + 4 = 10 + expected_result = input_elements * (input_elements - 1) / 2 + + # Enable runtime type checking in order to cover TypeCheckCombineFn by + # the test. + pipeline.get_pipeline_options().view_as(TypeOptions).runtime_type_check = True + pipeline.get_pipeline_options().view_as( + TypeOptions).allow_unsafe_triggers = True + + with pipeline as p: + pcoll = p | 'Start' >> beam.Create(range(input_elements)) + + # Certain triggers, such as AfterCount, are incompatible with combiner + # lifting. We can use that fact to prevent combiners from being lifted. + if not lift_combiners: + pcoll |= beam.WindowInto( + window.GlobalWindows(), + trigger=trigger.AfterCount(input_elements), + accumulation_mode=trigger.AccumulationMode.DISCARDING) + + combine_fn = CallSequenceEnforcingCombineFn() + # Modules are not deep copyable. Ensure fanout falls back to pickling for + # copying combine_fn. + combine_fn.module_attribute = math + pcoll |= 'Do' >> beam.CombineGlobally(combine_fn).with_fanout(fanout=1) + + assert_that(pcoll, equal_to([expected_result])) + + def run_pardo(pipeline, input_elements=10): with pipeline as p: _ = ( diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py index 62dbbc5fb77c..647e08db7aaa 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py @@ -31,6 +31,7 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.transforms.combinefn_lifecycle_pipeline import CallSequenceEnforcingCombineFn from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine +from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine_uncopyable_attr from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo @@ -53,15 +54,24 @@ def test_combining_value_state(self): @parameterized_class([ - 
{'runner': direct_runner.BundleBasedDirectRunner}, - {'runner': fn_api_runner.FnApiRunner}, -]) # yapf: disable + {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'}, + {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'}, + {'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'}, + {'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'}, + ]) # yapf: disable class LocalCombineFnLifecycleTest(unittest.TestCase): def tearDown(self): CallSequenceEnforcingCombineFn.instances.clear() def test_combine(self): - run_combine(TestPipeline(runner=self.runner())) + test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"]) + run_combine(TestPipeline(runner=self.runner(), options=test_options)) + self._assert_teardown_called() + + def test_combine_deepcopy_fails(self): + test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"]) + run_combine_uncopyable_attr( + TestPipeline(runner=self.runner(), options=test_options)) self._assert_teardown_called() def test_non_liftable_combine(self): diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 91ca4c8e33c3..b420d1d66d09 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -103,6 +103,7 @@ 'Windowing', 'WindowInto', 'Flatten', + 'FlattenWith', 'Create', 'Impulse', 'RestrictionProvider', @@ -1414,7 +1415,7 @@ class PartitionFn(WithTypeHints): def default_label(self): return self.__class__.__name__ - def partition_for(self, element, num_partitions, *args, **kwargs): + def partition_for(self, element, num_partitions, *args, **kwargs): # type: ignore[empty-body] # type: (T, int, *typing.Any, **typing.Any) -> int """Specify which partition will receive this element. 
@@ -1595,7 +1596,7 @@ def with_exception_handling( error_handler=None, on_failure_callback: typing.Optional[typing.Callable[ [Exception, typing.Any], None]] = None): - """Automatically provides a dead letter output for skipping bad records. + """Automatically provides a dead letter output for saving bad inputs. This can allow a pipeline to continue successfully rather than fail or continuously throw errors on retry when bad elements are encountered. @@ -1606,17 +1607,18 @@ def with_exception_handling( For example, one would write:: - good, bad = Map(maybe_error_raising_function).with_exception_handling() + good, bad = inputs | Map(maybe_erroring_fn).with_exception_handling() and `good` will be a PCollection of mapped records and `bad` will contain - those that raised exceptions. + tuples of the form `(input, error_string`) for each input that raised an + exception. Args: main_tag: tag to be used for the main (good) output of the DoFn, useful to avoid possible conflicts if this DoFn already produces multiple outputs. Optional, defaults to 'good'. - dead_letter_tag: tag to be used for the bad records, useful to avoid + dead_letter_tag: tag to be used for the bad inputs, useful to avoid possible conflicts if this DoFn already produces multiple outputs. Optional, defaults to 'bad'. exc_class: An exception class, or tuple of exception classes, to catch. @@ -1635,9 +1637,9 @@ def with_exception_handling( than a new process per element, so the overhead should be minimal (and can be amortized if there's any per-process or per-bundle initialization that needs to be done). Optional, defaults to False. - threshold: An upper bound on the ratio of records that can be bad before + threshold: An upper bound on the ratio of inputs that can be bad before aborting the entire pipeline. Optional, defaults to 1.0 (meaning - up to 100% of records can be bad and the pipeline will still succeed). + up to 100% of inputs can be bad and the pipeline will still succeed). 
threshold_windowing: Event-time windowing to use for threshold. Optional, defaults to the windowing of the input. timeout: If the element has not finished processing in timeout seconds, @@ -2115,15 +2117,13 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name r""":func:`MapTuple` is like :func:`Map` but expects tuple inputs and flattens them into multiple input arguments. - beam.MapTuple(lambda a, b, ...: ...) - In other words - beam.MapTuple(fn) + "SwapKV" >> beam.Map(lambda kv: (kv[1], kv[0])) is equivalent to - beam.Map(lambda element, ...: fn(\*element, ...)) + "SwapKV" >> beam.MapTuple(lambda k, v: (v, k)) This can be useful when processing a PCollection of tuples (e.g. key-value pairs). @@ -2189,19 +2189,13 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name r""":func:`FlatMapTuple` is like :func:`FlatMap` but expects tuple inputs and flattens them into multiple input arguments. - beam.FlatMapTuple(lambda a, b, ...: ...) - - is equivalent to Python 2 - - beam.FlatMap(lambda (a, b, ...), ...: ...) - In other words - beam.FlatMapTuple(fn) + beam.FlatMap(lambda start_end: range(start_end[0], start_end[1])) is equivalent to - beam.FlatMap(lambda element, ...: fn(\*element, ...)) + beam.FlatMapTuple(lambda start, end: range(start, end)) This can be useful when processing a PCollection of tuples (e.g. key-value pairs). @@ -2236,7 +2230,7 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name if defaults or args or kwargs: wrapper = lambda x, *args, **kwargs: fn(*(tuple(x) + args), **kwargs) else: - wrapper = lambda x: fn(*x) + wrapper = lambda x: fn(*tuple(x)) # Proxy the type-hint information from the original function to this new # wrapped function. 
@@ -3170,33 +3164,48 @@ def process(self, element): yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value)) class PreCombineFn(CombineFn): + def __init__(self): + # Deepcopy of the combine_fn to avoid sharing state between lifted + # stages when using cloudpickle. + try: + self._combine_fn_copy = copy.deepcopy(combine_fn) + except Exception: + self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn)) + + self.setup = self._combine_fn_copy.setup + self.create_accumulator = self._combine_fn_copy.create_accumulator + self.add_input = self._combine_fn_copy.add_input + self.merge_accumulators = self._combine_fn_copy.merge_accumulators + self.compact = self._combine_fn_copy.compact + self.teardown = self._combine_fn_copy.teardown + @staticmethod def extract_output(accumulator): # Boolean indicates this is an accumulator. return (True, accumulator) - setup = combine_fn.setup - create_accumulator = combine_fn.create_accumulator - add_input = combine_fn.add_input - merge_accumulators = combine_fn.merge_accumulators - compact = combine_fn.compact - teardown = combine_fn.teardown - class PostCombineFn(CombineFn): - @staticmethod - def add_input(accumulator, element): + def __init__(self): + # Deepcopy of the combine_fn to avoid sharing state between lifted + # stages when using cloudpickle. 
+ try: + self._combine_fn_copy = copy.deepcopy(combine_fn) + except Exception: + self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn)) + + self.setup = self._combine_fn_copy.setup + self.create_accumulator = self._combine_fn_copy.create_accumulator + self.merge_accumulators = self._combine_fn_copy.merge_accumulators + self.compact = self._combine_fn_copy.compact + self.extract_output = self._combine_fn_copy.extract_output + self.teardown = self._combine_fn_copy.teardown + + def add_input(self, accumulator, element): is_accumulator, value = element if is_accumulator: - return combine_fn.merge_accumulators([accumulator, value]) + return self._combine_fn_copy.merge_accumulators([accumulator, value]) else: - return combine_fn.add_input(accumulator, value) - - setup = combine_fn.setup - create_accumulator = combine_fn.create_accumulator - merge_accumulators = combine_fn.merge_accumulators - compact = combine_fn.compact - extract_output = combine_fn.extract_output - teardown = combine_fn.teardown + return self._combine_fn_copy.add_input(accumulator, value) def StripNonce(nonce_key_value): (_, key), value = nonce_key_value @@ -3434,7 +3443,7 @@ def _dynamic_named_tuple(type_name, field_names): type_name, field_names) # typing: can't override a method. also, self type is unknown and can't # be cast to tuple - result.__reduce__ = lambda self: ( # type: ignore[assignment] + result.__reduce__ = lambda self: ( # type: ignore[method-assign] _unpickle_dynamic_named_tuple, (type_name, field_names, tuple(self))) # type: ignore[arg-type] return result @@ -3865,6 +3874,33 @@ def from_runner_api_parameter( common_urns.primitives.FLATTEN.urn, None, Flatten.from_runner_api_parameter) +class FlattenWith(PTransform): + """A PTransform that flattens its input with other PCollections. + + This is equivalent to creating a tuple containing both the input and the + other PCollection(s), but has the advantage that it can be more easily used + inline. 
+ + Root PTransforms can be passed as well as PCollections, in which case their + outputs will be flattened. + """ + def __init__(self, *others): + self._others = others + + def expand(self, pcoll): + pcolls = [pcoll] + for other in self._others: + if isinstance(other, pvalue.PCollection): + pcolls.append(other) + elif isinstance(other, PTransform): + pcolls.append(pcoll.pipeline | other) + else: + raise TypeError( + 'FlattenWith only takes other PCollections and PTransforms, ' + f'got {other}') + return tuple(pcolls) | Flatten() + + class Create(PTransform): """A transform that creates a PCollection from an iterable.""" def __init__(self, values, reshuffle=True): diff --git a/sdks/python/apache_beam/transforms/display.py b/sdks/python/apache_beam/transforms/display.py index 86bbf101f567..14cd485d1f8e 100644 --- a/sdks/python/apache_beam/transforms/display.py +++ b/sdks/python/apache_beam/transforms/display.py @@ -173,7 +173,7 @@ def create_payload(dd) -> Optional[beam_runner_api_pb2.LabelledPayload]: elif isinstance(value, (float, complex)): return beam_runner_api_pb2.LabelledPayload( label=label, - double_value=value, + double_value=value, # type: ignore[arg-type] key=display_data_dict['key'], namespace=display_data_dict.get('namespace', '')) else: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 382ae123a81d..06b40bf38cc1 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -14,11 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from collections.abc import Callable +from collections.abc import Mapping from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping from typing import Optional from typing import Union @@ -30,7 +28,7 @@ from apache_beam.transforms.enrichment import EnrichmentSourceHandler QueryFn = Callable[[beam.Row], str] -ConditionValueFn = Callable[[beam.Row], List[Any]] +ConditionValueFn = Callable[[beam.Row], list[Any]] def _validate_bigquery_metadata( @@ -54,8 +52,8 @@ def _validate_bigquery_metadata( "`condition_value_fn`") -class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], - Union[Row, List[Row]]]): +class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]], + Union[Row, list[Row]]]): """Enrichment handler for Google Cloud BigQuery. Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment` @@ -83,8 +81,8 @@ def __init__( *, table_name: str = "", row_restriction_template: str = "", - fields: Optional[List[str]] = None, - column_names: Optional[List[str]] = None, + fields: Optional[list[str]] = None, + column_names: Optional[list[str]] = None, condition_value_fn: Optional[ConditionValueFn] = None, query_fn: Optional[QueryFn] = None, min_batch_size: int = 1, @@ -107,10 +105,10 @@ def __init__( row_restriction_template (str): A template string for the `WHERE` clause in the BigQuery query with placeholders (`{}`) to dynamically filter rows based on input data. - fields: (Optional[List[str]]) List of field names present in the input + fields: (Optional[list[str]]) List of field names present in the input `beam.Row`. These are used to construct the WHERE clause (if `condition_value_fn` is not provided). - column_names: (Optional[List[str]]) Names of columns to select from the + column_names: (Optional[list[str]]) Names of columns to select from the BigQuery table. If not provided, all columns (`*`) are selected. 
condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function that takes a `beam.Row` and returns a list of value to populate in the @@ -171,16 +169,24 @@ def _execute_query(self, query: str): except RuntimeError as e: raise RuntimeError(f"Could not complete the query request: {query}. {e}") - def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): - if isinstance(request, List): + def create_row_key(self, row: beam.Row): + if self.condition_value_fn: + return tuple(self.condition_value_fn(row)) + if self.fields: + row_dict = row._asdict() + return (tuple(row_dict[field] for field in self.fields)) + raise ValueError("Either fields or condition_value_fn must be specified") + + def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): + if isinstance(request, list): values = [] responses = [] - requests_map: Dict[Any, Any] = {} + requests_map: dict[Any, Any] = {} batch_size = len(request) raw_query = self.query_template if batch_size > 1: batched_condition_template = ' or '.join( - [self.row_restriction_template] * batch_size) + [fr'({self.row_restriction_template})'] * batch_size) raw_query = self.query_template.replace( self.row_restriction_template, batched_condition_template) for req in request: @@ -194,14 +200,15 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): "Make sure the values passed in `fields` are the " "keys in the input `beam.Row`." 
+ str(e)) values.extend(current_values) - requests_map.update((val, req) for val in current_values) + requests_map[self.create_row_key(req)] = req query = raw_query.format(*values) responses_dict = self._execute_query(query) for response in responses_dict: - for value in response.values(): - if value in requests_map: - responses.append((requests_map[value], beam.Row(**response))) + response_row = beam.Row(**response) + response_key = self.create_row_key(response_row) + if response_key in requests_map: + responses.append((requests_map[response_key], response_row)) return responses else: request_dict = request._asdict() @@ -221,8 +228,8 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): def __exit__(self, exc_type, exc_val, exc_tb): self.client.close() - def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]): - if isinstance(request, List): + def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]): + if isinstance(request, list): cache_keys = [] for req in request: req_dict = req._asdict() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py index 0b8a384b934d..dd99e386555e 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py @@ -14,7 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import functools import logging +import secrets +import time import unittest from unittest.mock import MagicMock @@ -22,7 +25,11 @@ import apache_beam as beam from apache_beam.coders import coders +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to # pylint: disable=ungrouped-imports try: @@ -31,8 +38,7 @@ from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.bigquery import \ BigQueryEnrichmentHandler - from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import \ - ValidateResponse + from apitools.base.py.exceptions import HttpError except ImportError: raise unittest.SkipTest( 'Google Cloud BigQuery dependencies are not installed.') @@ -40,24 +46,101 @@ _LOGGER = logging.getLogger(__name__) -def query_fn(row: beam.Row): - query = ( - "SELECT * FROM " - "`apache-beam-testing.my_ecommerce.product_details`" - " WHERE id = '{}'".format(row.id)) # type: ignore[attr-defined] - return query - - def condition_value_fn(row: beam.Row): return [row.id] # type: ignore[attr-defined] +def query_fn(table, row: beam.Row): + return f"SELECT * FROM `{table}` WHERE id = {row.id}" # type: ignore[attr-defined] + + +@pytest.mark.uses_testcontainer +class BigQueryEnrichmentIT(unittest.TestCase): + bigquery_dataset_id = 'python_enrichment_transform_read_table_' + project = "apache-beam-testing" + + @classmethod + def setUpClass(cls): + cls.bigquery_client = BigQueryWrapper() + cls.dataset_id = '%s%d%s' % ( + cls.bigquery_dataset_id, int(time.time()), secrets.token_hex(3)) + cls.bigquery_client.get_or_create_dataset(cls.project, cls.dataset_id) + _LOGGER.info( + "Created dataset %s in project %s", cls.dataset_id, cls.project) + + @classmethod + def tearDownClass(cls): + request = 
bigquery.BigqueryDatasetsDeleteRequest( + projectId=cls.project, datasetId=cls.dataset_id, deleteContents=True) + try: + _LOGGER.debug( + "Deleting dataset %s in project %s", cls.dataset_id, cls.project) + cls.bigquery_client.client.datasets.Delete(request) + except HttpError: + _LOGGER.warning( + 'Failed to clean up dataset %s in project %s', + cls.dataset_id, + cls.project) + + @pytest.mark.uses_testcontainer -class TestBigQueryEnrichmentIT(unittest.TestCase): +class TestBigQueryEnrichmentIT(BigQueryEnrichmentIT): + table_data = [ + { + "id": 1, "name": "A", 'quantity': 2, 'distribution_center_id': 3 + }, + { + "id": 2, "name": "B", 'quantity': 3, 'distribution_center_id': 1 + }, + { + "id": 3, "name": "C", 'quantity': 10, 'distribution_center_id': 4 + }, + { + "id": 4, "name": "D", 'quantity': 1, 'distribution_center_id': 3 + }, + { + "id": 5, "name": "C", 'quantity': 100, 'distribution_center_id': 4 + }, + { + "id": 6, "name": "D", 'quantity': 11, 'distribution_center_id': 3 + }, + { + "id": 7, "name": "C", 'quantity': 7, 'distribution_center_id': 1 + }, + { + "id": 8, "name": "D", 'quantity': 4, 'distribution_center_id': 1 + }, + ] + + @classmethod + def create_table(cls, table_name): + fields = [('id', 'INTEGER'), ('name', 'STRING'), ('quantity', 'INTEGER'), + ('distribution_center_id', 'INTEGER')] + table_schema = bigquery.TableSchema() + for name, field_type in fields: + table_field = bigquery.TableFieldSchema() + table_field.name = name + table_field.type = field_type + table_schema.fields.append(table_field) + table = bigquery.Table( + tableReference=bigquery.TableReference( + projectId=cls.project, datasetId=cls.dataset_id, + tableId=table_name), + schema=table_schema) + request = bigquery.BigqueryTablesInsertRequest( + projectId=cls.project, datasetId=cls.dataset_id, table=table) + cls.bigquery_client.client.tables.Insert(request) + cls.bigquery_client.insert_rows( + cls.project, cls.dataset_id, table_name, cls.table_data) + cls.table_name = 
f"{cls.project}.{cls.dataset_id}.{table_name}" + + @classmethod + def setUpClass(cls): + super(TestBigQueryEnrichmentIT, cls).setUpClass() + cls.create_table('product_details') + def setUp(self) -> None: - self.project = 'apache-beam-testing' - self.condition_template = "id = '{}'" - self.table_name = "`apache-beam-testing.my_ecommerce.product_details`" + self.condition_template = "id = {}" self.retries = 3 self._start_container() @@ -82,123 +165,119 @@ def tearDown(self) -> None: self.client = None def test_bigquery_enrichment(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] fields = ['id'] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] handler = BigQueryEnrichmentHandler( project=self.project, - row_restriction_template=self.condition_template, + row_restriction_template="id = {}", table_name=self.table_name, fields=fields, - min_batch_size=2, + min_batch_size=1, max_batch_size=100, ) + with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) - def test_bigquery_enrichment_with_query_fn(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_batched(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] + fields = 
['id'] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] - handler = BigQueryEnrichmentHandler(project=self.project, query_fn=query_fn) + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template="id = {}", + table_name=self.table_name, + fields=fields, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) - def test_bigquery_enrichment_with_condition_value_fn(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_batched_multiple_fields(self): + expected_rows = [ + beam.Row(id=1, distribution_center_id=3, name="A", quantity=2), + beam.Row(id=2, distribution_center_id=1, name="B", quantity=3) ] + fields = ['id', 'distribution_center_id'] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, distribution_center_id=3), + beam.Row(id=2, distribution_center_id=1), ] handler = BigQueryEnrichmentHandler( project=self.project, - row_restriction_template=self.condition_template, + row_restriction_template="id = {} AND distribution_center_id = {}", table_name=self.table_name, - condition_value_fn=condition_value_fn, - min_batch_size=2, + fields=fields, + min_batch_size=8, max_batch_size=100, ) + with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - 
| beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) - def test_bigquery_enrichment_with_condition_without_batch(self): - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_with_query_fn(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + fn = functools.partial(query_fn, self.table_name) + handler = BigQueryEnrichmentHandler(project=self.project, query_fn=fn) + with TestPipeline(is_integration_test=True) as test_pipeline: + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) + + def test_bigquery_enrichment_with_condition_value_fn(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) + ] + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] handler = BigQueryEnrichmentHandler( project=self.project, row_restriction_template=self.condition_template, table_name=self.table_name, condition_value_fn=condition_value_fn, + min_batch_size=2, + max_batch_size=100, ) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | beam.Create(requests) - | Enrichment(handler) - | beam.ParDo(ValidateResponse(expected_fields))) + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) def 
test_bigquery_enrichment_bad_request(self): requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] handler = BigQueryEnrichmentHandler( project=self.project, @@ -231,18 +310,13 @@ def test_bigquery_enrichment_with_redis(self): requests. Since all requests are cached, it will return from there without making calls to the BigQuery service. """ - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' - ] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] handler = BigQueryEnrichmentHandler( project=self.project, @@ -253,11 +327,12 @@ def test_bigquery_enrichment_with_redis(self): max_batch_size=100, ) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_populate_cache = ( test_pipeline | beam.Create(requests) - | Enrichment(handler).with_redis_cache(self.host, self.port) - | beam.ParDo(ValidateResponse(expected_fields))) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_populate_cache, equal_to(expected_rows)) # manually check cache entry c = coders.StrUtf8Coder() @@ -268,20 +343,15 @@ def test_bigquery_enrichment_with_redis(self): raise ValueError("No cache entry found for %s" % key) actual = BigQueryEnrichmentHandler.__call__ - BigQueryEnrichmentHandler.__call__ = MagicMock( - return_value=( - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - 
quantity=1), - beam.Row())) + BigQueryEnrichmentHandler.__call__ = MagicMock(return_value=(beam.Row())) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_cached = ( test_pipeline | beam.Create(requests) - | Enrichment(handler).with_redis_cache(self.host, self.port) - | beam.ParDo(ValidateResponse(expected_fields))) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_cached, equal_to(expected_rows)) BigQueryEnrichmentHandler.__call__ = actual diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index ddb62c2f60d5..c251ab05ecab 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -15,9 +15,8 @@ # limitations under the License. # import logging +from collections.abc import Callable from typing import Any -from typing import Callable -from typing import Dict from typing import Optional from google.api_core.exceptions import NotFound @@ -115,7 +114,7 @@ def __call__(self, request: beam.Row, *args, **kwargs): Args: request: the input `beam.Row` to enrich. 
""" - response_dict: Dict[str, Any] = {} + response_dict: dict[str, Any] = {} row_key_str: str = "" try: if self._row_key_fn: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index 79d73178e94e..6bf57cefacbe 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -18,10 +18,7 @@ import datetime import logging import unittest -from typing import Dict -from typing import List from typing import NamedTuple -from typing import Tuple from unittest.mock import MagicMock import pytest @@ -57,8 +54,8 @@ class ValidateResponse(beam.DoFn): def __init__( self, n_fields: int, - fields: List[str], - enriched_fields: Dict[str, List[str]], + fields: list[str], + enriched_fields: dict[str, list[str]], include_timestamp: bool = False, ): self.n_fields = n_fields @@ -88,7 +85,7 @@ def process(self, element: beam.Row, *args, **kwargs): "Response from bigtable should contain a %s column_family with " "%s columns." 
% (column_family, columns)) if (self._include_timestamp and - not isinstance(element_dict[column_family][key][0], Tuple)): # type: ignore[arg-type] + not isinstance(element_dict[column_family][key][0], tuple)): raise BeamAssertException( "Response from bigtable should contain timestamp associated with " "its value.") diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py index dc2a71786f65..f8e8b4db1d7f 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py @@ -16,11 +16,10 @@ # import logging import tempfile +from collections.abc import Callable +from collections.abc import Mapping from pathlib import Path from typing import Any -from typing import Callable -from typing import List -from typing import Mapping from typing import Optional import apache_beam as beam @@ -95,7 +94,7 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row, def __init__( self, feature_store_yaml_path: str, - feature_names: Optional[List[str]] = None, + feature_names: Optional[list[str]] = None, feature_service_name: Optional[str] = "", full_feature_names: Optional[bool] = False, entity_id: str = "", diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py index 89cb39c2c19c..9c4dab3d68b8 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py @@ -22,8 +22,8 @@ """ import unittest +from collections.abc import Mapping from typing import Any -from typing import Mapping import pytest diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py 
b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py index 753b04e1793d..b6de3aa1c826 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py @@ -15,7 +15,6 @@ # limitations under the License. # import logging -from typing import List import proto from google.api_core.exceptions import NotFound @@ -209,7 +208,7 @@ def __init__( api_endpoint: str, feature_store_id: str, entity_type_id: str, - feature_ids: List[str], + feature_ids: list[str], row_key: str, *, exception_level: ExceptionLevel = ExceptionLevel.WARN, @@ -224,7 +223,7 @@ def __init__( Vertex AI Feature Store (Legacy). feature_store_id (str): The id of the Vertex AI Feature Store (Legacy). entity_type_id (str): The entity type of the feature store. - feature_ids (List[str]): A list of feature-ids to fetch + feature_ids (list[str]): A list of feature-ids to fetch from the Feature Store. row_key (str): The row key field name containing the entity id for the feature values. 
diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py index dbb227802925..77704e0522b2 100644 --- a/sdks/python/apache_beam/transforms/environments.py +++ b/sdks/python/apache_beam/transforms/environments.py @@ -895,6 +895,7 @@ def _python_sdk_capabilities_iter(): yield common_urns.primitives.TO_STRING.urn yield common_urns.protocols.DATA_SAMPLING.urn yield common_urns.protocols.SDK_CONSUMING_RECEIVED_DATA.urn + yield common_urns.protocols.ORDERED_LIST_STATE.urn def python_sdk_dependencies(options, tmp_dir=None): diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 8a04e7efb195..e44f7482dc61 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -653,8 +653,8 @@ def __init__(self, urn, payload, expansion_service=None): payload.payload() if isinstance(payload, PayloadBuilder) else payload) self._expansion_service = expansion_service self._external_namespace = self._fresh_namespace() - self._inputs = {} # type: Dict[str, pvalue.PCollection] - self._outputs = {} # type: Dict[str, pvalue.PCollection] + self._inputs: Dict[str, pvalue.PCollection] = {} + self._outputs: Dict[str, pvalue.PCollection] = {} def with_output_types(self, *args, **kwargs): return WithTypeHints.with_output_types(self, *args, **kwargs) @@ -691,13 +691,11 @@ def outer_namespace(cls, namespace): cls._external_namespace.value = prev @classmethod - def _fresh_namespace(cls): - # type: () -> str + def _fresh_namespace(cls) -> str: ExternalTransform._namespace_counter += 1 return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter) - def expand(self, pvalueish): - # type: (pvalue.PCollection) -> pvalue.PCollection + def expand(self, pvalueish: pvalue.PCollection) -> pvalue.PCollection: if isinstance(pvalueish, pvalue.PBegin): self._inputs = {} elif isinstance(pvalueish, (list, tuple)): @@ -740,7 +738,7 @@ def 
expand(self, pvalueish): components = context.to_runner_api() request = beam_expansion_api_pb2.ExpansionRequest( components=components, - namespace=self._external_namespace, + namespace=self._external_namespace, # type: ignore[arg-type] transform=transform_proto, output_coder_requests=output_coders, pipeline_options=pipeline._options.to_runner_api()) @@ -923,6 +921,7 @@ def _normalize(coder_proto): for tag, pcoll in self._expanded_transform.outputs.items() }, + annotations=self._expanded_transform.annotations, environment_id=self._expanded_transform.environment_id) diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index fe2914a08699..c95a5d19f0cd 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -52,6 +52,7 @@ from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.utils import proto_utils from apache_beam.utils.subprocess_server import JavaJarServer +from apache_beam.utils.subprocess_server import SubprocessServer # Protect against environments where apitools library is not available. # pylint: disable=wrong-import-order, wrong-import-position @@ -718,6 +719,9 @@ def test_implicit_builder_with_constructor_method(self): class JavaJarExpansionServiceTest(unittest.TestCase): + def setUp(self): + SubprocessServer._cache._live_owners = set() + def test_classpath(self): with tempfile.TemporaryDirectory() as temp_dir: try: diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py new file mode 100644 index 000000000000..cbcb6de56ed7 --- /dev/null +++ b/sdks/python/apache_beam/transforms/managed.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Managed Transforms. + +This module builds and instantiates turnkey transforms that can be managed by +the underlying runner. This means the runner can upgrade the transform to a +more optimal/updated version without requiring the user to do anything. It may +also replace the transform with something entirely different if it chooses to. +By default, however, the specified transform will remain unchanged. + +Using Managed Transforms +======================== +Managed turnkey transforms have a defined configuration and can be built using +an inline :class:`dict` like so:: + + results = p | beam.managed.Read( + beam.managed.ICEBERG, + config={"table": "foo", + "catalog_name": "bar", + "catalog_properties": { + "warehouse": "path/to/warehouse", + "catalog-impl": "org.apache.my.CatalogImpl"}}) + +A YAML configuration file can also be used to build a Managed transform. Say we +have the following `config.yaml` file:: + + topic: "foo" + bootstrap_servers: "localhost:1234" + format: "AVRO" + +Simply provide the location to the file like so:: + + input_rows = p | beam.Create(...) 
+ input_rows | beam.managed.Write( + beam.managed.KAFKA, + config_url="path/to/config.yaml") + +Available transforms +==================== +Available transforms are: + +- **Kafka Read and Write** +- **Iceberg Read and Write** + +**Note:** inputs and outputs need to be PCollection(s) of Beam +:py:class:`apache_beam.pvalue.Row` elements. + +**Note:** Today, all managed transforms are essentially cross-language +transforms, and Java's ManagedSchemaTransform is used under the hood. +""" + +from typing import Any +from typing import Dict +from typing import Optional + +import yaml + +from apache_beam.portability.common_urns import ManagedTransforms +from apache_beam.transforms.external import BeamJarExpansionService +from apache_beam.transforms.external import SchemaAwareExternalTransform +from apache_beam.transforms.ptransform import PTransform + +ICEBERG = "iceberg" +KAFKA = "kafka" +BIGQUERY = "bigquery" +_MANAGED_IDENTIFIER = "beam:transform:managed:v1" +_EXPANSION_SERVICE_JAR_TARGETS = { + "sdks:java:io:expansion-service:shadowJar": [KAFKA, ICEBERG], + "sdks:java:io:google-cloud-platform:expansion-service:shadowJar": [ + BIGQUERY + ] +} + +__all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"] + + +class Read(PTransform): + """Read using Managed Transforms""" + _READ_TRANSFORMS = { + ICEBERG: ManagedTransforms.Urns.ICEBERG_READ.urn, + KAFKA: ManagedTransforms.Urns.KAFKA_READ.urn, + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_READ.urn + } + + def __init__( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + config_url: Optional[str] = None, + expansion_service=None): + super().__init__() + self._source = source + identifier = self._READ_TRANSFORMS.get(source.lower()) + if not identifier: + raise ValueError( + f"An unsupported source was specified: '{source}'. 
Please specify " + f"one of the following sources: {list(self._READ_TRANSFORMS.keys())}") + + self._expansion_service = _resolve_expansion_service( + source, identifier, expansion_service) + self._underlying_identifier = identifier + self._yaml_config = yaml.dump(config) + self._config_url = config_url + + def expand(self, input): + return input | SchemaAwareExternalTransform( + identifier=_MANAGED_IDENTIFIER, + expansion_service=self._expansion_service, + rearrange_based_on_discovery=True, + transform_identifier=self._underlying_identifier, + config=self._yaml_config, + config_url=self._config_url) + + def default_label(self) -> str: + return "Managed Read(%s)" % self._source.upper() + + +class Write(PTransform): + """Write using Managed Transforms""" + _WRITE_TRANSFORMS = { + ICEBERG: ManagedTransforms.Urns.ICEBERG_WRITE.urn, + KAFKA: ManagedTransforms.Urns.KAFKA_WRITE.urn, + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_WRITE.urn + } + + def __init__( + self, + sink: str, + config: Optional[Dict[str, Any]] = None, + config_url: Optional[str] = None, + expansion_service=None): + super().__init__() + self._sink = sink + identifier = self._WRITE_TRANSFORMS.get(sink.lower()) + if not identifier: + raise ValueError( + f"An unsupported sink was specified: '{sink}'. 
Please specify " + f"one of the following sinks: {list(self._WRITE_TRANSFORMS.keys())}") + + self._expansion_service = _resolve_expansion_service( + sink, identifier, expansion_service) + self._underlying_identifier = identifier + self._yaml_config = yaml.dump(config) + self._config_url = config_url + + def expand(self, input): + return input | SchemaAwareExternalTransform( + identifier=_MANAGED_IDENTIFIER, + expansion_service=self._expansion_service, + rearrange_based_on_discovery=True, + transform_identifier=self._underlying_identifier, + config=self._yaml_config, + config_url=self._config_url) + + def default_label(self) -> str: + return "Managed Write(%s)" % self._sink.upper() + + +def _resolve_expansion_service( + transform_name: str, identifier: str, expansion_service): + if expansion_service: + return expansion_service + + default_target = None + for gradle_target, transforms in _EXPANSION_SERVICE_JAR_TARGETS.items(): + if transform_name.lower() in transforms: + default_target = gradle_target + break + if not default_target: + raise ValueError( + "No expansion service was specified and could not find a " + f"default expansion service for {transform_name}: '{identifier}'.") + return BeamJarExpansionService(default_target) diff --git a/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py new file mode 100644 index 000000000000..a09203f313eb --- /dev/null +++ b/sdks/python/apache_beam/transforms/managed_iceberg_it_test.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import time +import unittest + +import pytest + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + + +@pytest.mark.uses_io_java_expansion_service +@unittest.skipUnless( + os.environ.get('EXPANSION_JARS'), + "EXPANSION_JARS environment var is not provided, " + "indicating that jars have not been built") +class ManagedIcebergIT(unittest.TestCase): + WAREHOUSE = "gs://temp-storage-for-end-to-end-tests/xlang-python-using-java" + + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.args = self.test_pipeline.get_full_options_as_args() + self.args.extend([ + '--experiments=enable_managed_transforms', + '--dataflow_endpoint=https://dataflow-staging.sandbox.googleapis.com', + ]) + + def _create_row(self, num: int): + return beam.Row( + int_=num, + str_=str(num), + bytes_=bytes(num), + bool_=(num % 2 == 0), + float_=(num + float(num) / 100)) + + def test_write_read_pipeline(self): + iceberg_config = { + "table": "test_iceberg_write_read.test_" + str(int(time.time())), + "catalog_name": "default", + "catalog_properties": { + "type": "hadoop", + "warehouse": self.WAREHOUSE, + } + } + + rows = [self._create_row(i) for i in range(100)] + expected_dicts = [row.as_dict() for row in rows] + + with beam.Pipeline(argv=self.args) as write_pipeline: + _ = ( + write_pipeline + | beam.Create(rows) + | beam.managed.Write(beam.managed.ICEBERG, config=iceberg_config)) + + with 
beam.Pipeline(argv=self.args) as read_pipeline: + output_dicts = ( + read_pipeline + | beam.managed.Read(beam.managed.ICEBERG, config=iceberg_config) + | beam.Map(lambda row: row._asdict())) + + assert_that(output_dicts, equal_to(expected_dicts)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 8554ebce5dbd..4848dc4aade8 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -497,13 +497,12 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output): at_context = ' %s %s' % (input_or_output, context) if context else '' raise TypeCheckError( '{type} type hint violation at {label}{context}: expected {hint}, ' - 'got {actual_type}\nFull type hint:\n{debug_str}'.format( + 'got {actual_type}'.format( type=input_or_output.title(), label=self.label, context=at_context, hint=hint, - actual_type=pvalue_.element_type, - debug_str=type_hints.debug_str())) + actual_type=pvalue_.element_type)) def _infer_output_coder(self, input_type=None, input_coder=None): # type: (...) -> Optional[coders.Coder] @@ -748,7 +747,7 @@ def to_runner_api(self, context, has_parts=False, **extra_kwargs): # type: (PipelineContext, bool, Any) -> beam_runner_api_pb2.FunctionSpec from apache_beam.portability.api import beam_runner_api_pb2 # typing: only ParDo supports extra_kwargs - urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs) # type: ignore[call-arg] + urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs) if urn == python_urns.GENERIC_COMPOSITE_TRANSFORM and not has_parts: # TODO(https://github.com/apache/beam/issues/18713): Remove this fallback. 
urn, typed_param = self.to_runner_api_pickled(context) @@ -939,7 +938,25 @@ def element_type(side_input): bindings = getcallargs_forhints(argspec_fn, *arg_types, **kwargs_types) hints = getcallargs_forhints( argspec_fn, *input_types[0], **input_types[1]) - for arg, hint in hints.items(): + + # First check the main input. + arg_hints = iter(hints.items()) + element_arg, element_hint = next(arg_hints) + if not typehints.is_consistent_with( + bindings.get(element_arg, typehints.Any), element_hint): + transform_nest_level = self.label.count("/") + split_producer_label = pvalueish.producer.full_label.split("/") + producer_label = "/".join( + split_producer_label[:transform_nest_level + 1]) + raise TypeCheckError( + f"The transform '{self.label}' requires " + f"PCollections of type '{element_hint}' " + f"but was applied to a PCollection of type" + f" '{bindings[element_arg]}' " + f"(produced by the transform '{producer_label}'). ") + + # Now check the side inputs. + for arg, hint in arg_hints: if arg.startswith('__unknown__'): continue if hint is None: diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index d760ef74fb14..f9f6b230866e 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -787,6 +787,18 @@ def split_even_odd(element): assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even') assert_that(odd_length, equal_to(['BBB']), label='assert:odd') + def test_flatten_with(self): + with TestPipeline() as pipeline: + input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC']) + + result = ( + input + | 'WithPCollection' >> beam.FlattenWith(input | beam.Map(str.lower)) + | 'WithPTransform' >> beam.FlattenWith(beam.Create(['x', 'y']))) + + assert_that( + result, equal_to(['AA', 'BBB', 'CC', 'aa', 'bbb', 'cc', 'x', 'y'])) + def test_group_by_key_input_must_be_kv_pairs(self): with 
self.assertRaises(typehints.TypeCheckError) as e: with TestPipeline() as pipeline: @@ -1099,7 +1111,7 @@ def SamplePTransform(pcoll): class PTransformLabelsTest(unittest.TestCase): class CustomTransform(beam.PTransform): - pardo = None # type: Optional[beam.PTransform] + pardo: Optional[beam.PTransform] = None def expand(self, pcoll): self.pardo = '*Do*' >> beam.FlatMap(lambda x: [x + 1]) @@ -1286,17 +1298,13 @@ class ToUpperCaseWithPrefix(beam.DoFn): def process(self, element, prefix): return [prefix + element.upper()] - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int) | 'Upper' >> beam.ParDo(ToUpperCaseWithPrefix(), 'hello')) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for element".format(str, int)) - def test_do_fn_pipeline_runtime_type_check_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -1323,18 +1331,14 @@ class AddWithNum(beam.DoFn): def process(self, element, num): return [element + num] - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Add.*requires.*int.*applied.*str'): ( self.p | 'T' >> beam.Create(['1', '2', '3']).with_output_types(str) | 'Add' >> beam.ParDo(AddWithNum(), 5)) self.p.run() - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Add': " - "requires {} but got {} for element".format(int, str)) - def test_pardo_does_not_type_check_using_type_hint_decorators(self): @with_input_types(a=int) @with_output_types(typing.List[str]) @@ -1343,17 +1347,13 @@ def int_to_str(a): # The function above is expecting an int for its only parameter. However, it # will receive a str instead, which should result in a raised exception. 
- with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'ToStr.*requires.*int.*applied.*str'): ( self.p | 'S' >> beam.Create(['b', 'a', 'r']).with_output_types(str) | 'ToStr' >> beam.FlatMap(int_to_str)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'ToStr': " - "requires {} but got {} for a".format(int, str)) - def test_pardo_properly_type_checks_using_type_hint_decorators(self): @with_input_types(a=str) @with_output_types(typing.List[str]) @@ -1375,7 +1375,8 @@ def to_all_upper_case(a): def test_pardo_does_not_type_check_using_type_hint_methods(self): # The first ParDo outputs pcoll's of type int, however the second ParDo is # expecting pcoll's of type str instead. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str) @@ -1386,11 +1387,6 @@ def test_pardo_does_not_type_check_using_type_hint_methods(self): 'Upper' >> beam.FlatMap(lambda x: [x.upper()]).with_input_types( str).with_output_types(str))) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for x".format(str, int)) - def test_pardo_properly_type_checks_using_type_hint_methods(self): # Pipeline should be created successfully without an error d = ( @@ -1407,18 +1403,14 @@ def test_pardo_properly_type_checks_using_type_hint_methods(self): def test_map_does_not_type_check_using_type_hints_methods(self): # The transform before 'Map' has indicated that it outputs PCollections with # int's, while Map is expecting one of str. 
- with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int) | 'Upper' >> beam.Map(lambda x: x.upper()).with_input_types( str).with_output_types(str)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for x".format(str, int)) - def test_map_properly_type_checks_using_type_hints_methods(self): # No error should be raised if this type-checks properly. d = ( @@ -1437,17 +1429,13 @@ def upper(s): # Hinted function above expects a str at pipeline construction. # However, 'Map' should detect that Create has hinted an int instead. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int) | 'Upper' >> beam.Map(upper)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for s".format(str, int)) - def test_map_properly_type_checks_using_type_hints_decorator(self): @with_input_types(a=bool) @with_output_types(int) @@ -1465,7 +1453,8 @@ def bool_to_int(a): def test_filter_does_not_type_check_using_type_hints_method(self): # Filter is expecting an int but instead looks to the 'left' and sees a str # incoming. 
- with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Below 3.*requires.*int.*applied.*str'): ( self.p | 'Strs' >> beam.Create(['1', '2', '3', '4', '5' @@ -1474,11 +1463,6 @@ def test_filter_does_not_type_check_using_type_hints_method(self): str).with_output_types(str) | 'Below 3' >> beam.Filter(lambda x: x < 3).with_input_types(int)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Below 3': " - "requires {} but got {} for x".format(int, str)) - def test_filter_type_checks_using_type_hints_method(self): # No error should be raised if this type-checks properly. d = ( @@ -1496,17 +1480,13 @@ def more_than_half(a): return a > 0.50 # Func above was hinted to only take a float, yet a str will be passed. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Half.*requires.*float.*applied.*str'): ( self.p | 'Ints' >> beam.Create(['1', '2', '3', '4']).with_output_types(str) | 'Half' >> beam.Filter(more_than_half)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Half': " - "requires {} but got {} for a".format(float, str)) - def test_filter_type_checks_using_type_hints_decorator(self): @with_input_types(b=int) def half(b): @@ -2116,14 +2096,10 @@ def test_mean_globally_pipeline_checking_violated(self): self.p | 'C' >> beam.Create(['test']).with_output_types(str) | 'Mean' >> combine.Mean.Globally()) - - expected_msg = \ - "Type hint violation for 'CombinePerKey': " \ - "requires Tuple[TypeVariable[K], Union[, , " \ - ", ]] " \ - "but got Tuple[None, ] for element" - - self.assertStartswith(e.exception.args[0], expected_msg) + err_msg = e.exception.args[0] + assert "CombinePerKey" in err_msg + assert "Tuple[TypeVariable[K]" in err_msg + assert "Tuple[None, " in err_msg def test_mean_globally_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ 
-2183,14 +2159,12 @@ def test_mean_per_key_pipeline_checking_violated(self): typing.Tuple[str, str])) | 'EvenMean' >> combine.Mean.PerKey()) self.p.run() - - expected_msg = \ - "Type hint violation for 'CombinePerKey(MeanCombineFn)': " \ - "requires Tuple[TypeVariable[K], Union[, , " \ - ", ]] " \ - "but got Tuple[, ] for element" - - self.assertStartswith(e.exception.args[0], expected_msg) + err_msg = e.exception.args[0] + assert "CombinePerKey(MeanCombineFn)" in err_msg + assert "requires" in err_msg + assert "Tuple[TypeVariable[K]" in err_msg + assert "applied" in err_msg + assert "Tuple[, ]" in err_msg def test_mean_per_key_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True diff --git a/sdks/python/apache_beam/transforms/sql_test.py b/sdks/python/apache_beam/transforms/sql_test.py index 854aec078ce5..a7da253c4617 100644 --- a/sdks/python/apache_beam/transforms/sql_test.py +++ b/sdks/python/apache_beam/transforms/sql_test.py @@ -20,6 +20,7 @@ # pytype: skip-file import logging +import subprocess import typing import unittest @@ -69,6 +70,22 @@ class SqlTransformTest(unittest.TestCase): """ _multiprocess_can_split_ = True + @staticmethod + def _disable_zetasql_test(): + # disable if run on Java8 which is no longer supported by ZetaSQL + try: + result = subprocess.run(['java', '-version'], + check=True, + capture_output=True, + text=True) + version_line = result.stderr.splitlines()[0] + version = version_line.split()[2].strip('\"') + if version.startswith("1."): + return True + return False + except: # pylint: disable=bare-except + return False + def test_generate_data(self): with TestPipeline() as p: out = p | SqlTransform( @@ -150,6 +167,9 @@ def test_row(self): assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)])) def test_zetasql_generate_data(self): + if self._disable_zetasql_test(): + raise unittest.SkipTest("ZetaSQL tests need Java11+") + with TestPipeline() as p: out = p | SqlTransform( """SELECT diff --git 
a/sdks/python/apache_beam/transforms/stats.py b/sdks/python/apache_beam/transforms/stats.py index d389463e55a2..0d56b60b050f 100644 --- a/sdks/python/apache_beam/transforms/stats.py +++ b/sdks/python/apache_beam/transforms/stats.py @@ -919,7 +919,7 @@ def _offset(self, new_weight): # TODO(https://github.com/apache/beam/issues/19737): Signature incompatible # with supertype - def create_accumulator(self): # type: ignore[override] + def create_accumulator(self): # type: () -> _QuantileState self._qs = _QuantileState( unbuffered_elements=[], diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index ada0b755bd6c..3b876bf9dbfb 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -150,6 +150,17 @@ def to_runner_api( urn=common_urns.user_state.BAG.urn)) +class OrderedListStateSpec(StateSpec): + """Specification for a user DoFn ordered list state cell.""" + def to_runner_api( + self, context: 'PipelineContext') -> beam_runner_api_pb2.StateSpec: + return beam_runner_api_pb2.StateSpec( + ordered_list_spec=beam_runner_api_pb2.OrderedListStateSpec( + element_coder_id=context.coders.get_id(self.coder)), + protocol=beam_runner_api_pb2.FunctionSpec( + urn=common_urns.user_state.ORDERED_LIST.urn)) + + # TODO(BEAM-9562): Update Timer to have of() and clear() APIs. 
Timer = NamedTuple( 'Timer', @@ -288,7 +299,7 @@ def validate_stateful_dofn(dofn: 'DoFn') -> None: 'callback: %s.') % (dofn, timer_spec)) method_name = timer_spec._attached_callback.__name__ if (timer_spec._attached_callback != getattr(dofn, method_name, - None).__func__): + None).__func__): # type: ignore[union-attr] raise ValueError(( 'The on_timer callback for %s is not the specified .%s method ' 'for DoFn %r (perhaps it was overwritten?).') % @@ -303,7 +314,7 @@ def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: raise NotImplementedError -_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp')) +_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp')) # type: ignore[name-match] class RuntimeTimer(BaseTimer): @@ -372,6 +383,24 @@ class CombiningValueRuntimeState(AccumulatingRuntimeState): """Combining value state interface object passed to user code.""" +class OrderedListRuntimeState(AccumulatingRuntimeState): + """Ordered list state interface object passed to user code.""" + def read(self) -> Iterable[Tuple[Timestamp, Any]]: + raise NotImplementedError(type(self)) + + def add(self, value: Tuple[Timestamp, Any]) -> None: + raise NotImplementedError(type(self)) + + def read_range( + self, min_time_stamp: Timestamp, + limit_time_stamp: Timestamp) -> Iterable[Tuple[Timestamp, Any]]: + raise NotImplementedError(type(self)) + + def clear_range( + self, min_time_stamp: Timestamp, limit_time_stamp: Timestamp) -> None: + raise NotImplementedError(type(self)) + + class UserStateContext(object): """Wrapper allowing user state and timers to be accessed by a DoFnInvoker.""" def get_timer( diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index a27c7aca9e20..a03652de2496 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -30,6 +30,7 @@ import uuid from typing import TYPE_CHECKING from typing import Any 
+from typing import Callable from typing import Iterable from typing import List from typing import Tuple @@ -44,6 +45,7 @@ from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsSideInput +from apache_beam.pvalue import PCollection from apache_beam.transforms import window from apache_beam.transforms.combiners import CountCombineFn from apache_beam.transforms.core import CombinePerKey @@ -92,6 +94,7 @@ 'RemoveDuplicates', 'Reshuffle', 'ToString', + 'Tee', 'Values', 'WithKeys', 'GroupIntoBatches' @@ -1665,6 +1668,37 @@ def _process(element): return pcoll | FlatMap(_process) +@typehints.with_input_types(T) +@typehints.with_output_types(T) +class Tee(PTransform): + """A PTransform that returns its input, but also applies its input elsewhere. + + Similar to the shell {@code tee} command. This can be useful to write out or + otherwise process an intermediate transform without breaking the linear flow + of a chain of transforms, e.g.:: + + (input + | SomePTransform() + | ... + | Tee(SomeSideTransform()) + | ...) 
+ """ + def __init__( + self, + *consumers: Union[PTransform[PCollection[T], Any], + Callable[[PCollection[T]], Any]]): + self._consumers = consumers + + def expand(self, input): + for consumer in self._consumers: + print("apply", consumer) + if callable(consumer): + _ = input | ptransform_fn(consumer)() + else: + _ = input | consumer + return input + + @typehints.with_input_types(T) @typehints.with_output_types(T) class WaitOn(PTransform): diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 9c70be7900da..d86509c7dde3 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -19,6 +19,8 @@ # pytype: skip-file +import collections +import importlib import logging import math import random @@ -27,6 +29,7 @@ import unittest import warnings from datetime import datetime +from typing import Mapping import pytest import pytz @@ -1812,6 +1815,32 @@ def test_split_without_empty(self): assert_that(result, equal_to(expected_result)) +class TeeTest(unittest.TestCase): + _side_effects: Mapping[str, int] = collections.defaultdict(int) + + def test_tee(self): + # The imports here are to avoid issues with the class (and its attributes) + # possibly being pickled rather than referenced. 
+ def cause_side_effect(element): + importlib.import_module(__name__).TeeTest._side_effects[element] += 1 + + def count_side_effects(element): + return importlib.import_module(__name__).TeeTest._side_effects[element] + + with TestPipeline() as p: + result = ( + p + | beam.Create(['a', 'b', 'c']) + | 'TeePTransform' >> beam.Tee(beam.Map(cause_side_effect)) + | 'TeeCallable' >> beam.Tee( + lambda pcoll: pcoll | beam.Map( + lambda element: cause_side_effect('X' + element)))) + assert_that(result, equal_to(['a', 'b', 'c'])) + + self.assertEqual(count_side_effects('a'), 1) + self.assertEqual(count_side_effects('Xa'), 1) + + class WaitOnTest(unittest.TestCase): def test_find(self): # We need shared reference that survives pickling. diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index 592164a5ef49..fc20174ca1e2 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -449,8 +449,8 @@ def to_runner_api_parameter(self, context): standard_window_fns_pb2.FixedWindowsPayload) def from_runner_api_parameter(fn_parameter, unused_context) -> 'FixedWindows': return FixedWindows( - size=Duration(micros=fn_parameter.size.ToMicroseconds()), - offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds())) + size=Duration(micros=proto_utils.to_micros(fn_parameter.size)), + offset=Timestamp(micros=proto_utils.to_micros(fn_parameter.offset))) class SlidingWindows(NonMergingWindowFn): @@ -522,9 +522,9 @@ def to_runner_api_parameter(self, context): def from_runner_api_parameter( fn_parameter, unused_context) -> 'SlidingWindows': return SlidingWindows( - size=Duration(micros=fn_parameter.size.ToMicroseconds()), - offset=Timestamp(micros=fn_parameter.offset.ToMicroseconds()), - period=Duration(micros=fn_parameter.period.ToMicroseconds())) + size=Duration(micros=proto_utils.to_micros(fn_parameter.size)), + offset=Timestamp(micros=proto_utils.to_micros(fn_parameter.offset)), + 
period=Duration(micros=proto_utils.to_micros(fn_parameter.period))) class Sessions(WindowFn): @@ -589,4 +589,4 @@ def to_runner_api_parameter(self, context): standard_window_fns_pb2.SessionWindowsPayload) def from_runner_api_parameter(fn_parameter, unused_context) -> 'Sessions': return Sessions( - gap_size=Duration(micros=fn_parameter.gap_size.ToMicroseconds())) + gap_size=Duration(micros=proto_utils.to_micros(fn_parameter.gap_size))) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 9c0cc2b8af4e..7050df7016e5 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -82,7 +82,6 @@ def foo((a, b)): import inspect import itertools import logging -import sys import traceback import types from typing import Any @@ -686,9 +685,6 @@ def get_type_hints(fn: Any) -> IOTypeHints: # Can't add arbitrary attributes to this object, # but might have some restrictions anyways... hints = IOTypeHints.empty() - # Python 3.7 introduces annotations for _MethodDescriptorTypes. - if isinstance(fn, _MethodDescriptorType) and sys.version_info < (3, 7): - hints = hints.with_input_types(fn.__objclass__) # type: ignore return hints return fn._type_hints # pylint: enable=protected-access diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 3baf9fa8322f..dd110ced5bb8 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -20,7 +20,6 @@ # pytype: skip-file import functools -import sys import typing import unittest @@ -70,14 +69,7 @@ def test_from_callable_builtin(self): def test_from_callable_method_descriptor(self): # from_callable() injects an annotation in this special type of builtin. 
th = decorators.IOTypeHints.from_callable(str.strip) - if sys.version_info >= (3, 7): - self.assertEqual(th.input_types, ((str, Any), {})) - else: - self.assertEqual( - th.input_types, - ((str, decorators._ANY_VAR_POSITIONAL), { - '__unknown__keywords': decorators._ANY_VAR_KEYWORD - })) + self.assertEqual(th.input_types, ((str, Any), {})) self.assertEqual(th.output_types, ((Any, ), {})) def test_strip_iterable_not_simple_output_noop(self): @@ -409,7 +401,7 @@ def fn(a: int) -> int: return a with self.assertRaisesRegex(TypeCheckError, - r'requires .*int.* but got .*str'): + r'requires .*int.* but was applied .*str'): _ = ['a', 'b', 'c'] | Map(fn) # Same pipeline doesn't raise without annotations on fn. @@ -423,7 +415,7 @@ def fn(a: int) -> int: _ = [1, 2, 3] | Map(fn) # Doesn't raise - correct types. with self.assertRaisesRegex(TypeCheckError, - r'requires .*int.* but got .*str'): + r'requires .*int.* but was applied .*str'): _ = ['a', 'b', 'c'] | Map(fn) @decorators.no_annotations diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 621adc44507e..6f704b37a969 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -101,10 +101,7 @@ def _match_issubclass(match_against): def _match_is_exactly_mapping(user_type): # Avoid unintentionally catching all subtypes (e.g. strings and mappings). - if sys.version_info < (3, 7): - expected_origin = typing.Mapping - else: - expected_origin = collections.abc.Mapping + expected_origin = collections.abc.Mapping return getattr(user_type, '__origin__', None) is expected_origin @@ -112,10 +109,7 @@ def _match_is_exactly_iterable(user_type): if user_type is typing.Iterable: return True # Avoid unintentionally catching all subtypes (e.g. strings and mappings). 
- if sys.version_info < (3, 7): - expected_origin = typing.Iterable - else: - expected_origin = collections.abc.Iterable + expected_origin = collections.abc.Iterable return getattr(user_type, '__origin__', None) is expected_origin @@ -244,11 +238,10 @@ def convert_to_beam_type(typ): sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)): typ = typing.Union[typ] - if sys.version_info >= (3, 9) and isinstance(typ, types.GenericAlias): + if isinstance(typ, types.GenericAlias): typ = convert_builtin_to_typing(typ) - if sys.version_info >= (3, 9) and getattr(typ, '__module__', - None) == 'collections.abc': + if getattr(typ, '__module__', None) == 'collections.abc': typ = convert_collections_to_typing(typ) typ_module = getattr(typ, '__module__', None) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index 2e6db6a7733c..ae8e1a0b2906 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -21,7 +21,6 @@ import collections.abc import enum -import sys import typing import unittest @@ -128,105 +127,98 @@ def test_convert_to_beam_type(self): self.assertEqual(converted_typing_type, typing_type, description) def test_convert_to_beam_type_with_builtin_types(self): - if sys.version_info >= (3, 9): - test_cases = [ - ('builtin dict', dict[str, int], typehints.Dict[str, int]), - ('builtin list', list[str], typehints.List[str]), - ('builtin tuple', tuple[str], - typehints.Tuple[str]), ('builtin set', set[str], typehints.Set[str]), - ('builtin frozenset', frozenset[int], typehints.FrozenSet[int]), - ( - 'nested builtin', - dict[str, list[tuple[float]]], - typehints.Dict[str, typehints.List[typehints.Tuple[float]]]), - ( - 'builtin nested tuple', - tuple[str, list], - typehints.Tuple[str, typehints.List[typehints.Any]], - ) - ] - - for test_case in test_cases: - 
description = test_case[0] - builtins_type = test_case[1] - expected_beam_type = test_case[2] - converted_beam_type = convert_to_beam_type(builtins_type) - self.assertEqual(converted_beam_type, expected_beam_type, description) + test_cases = [ + ('builtin dict', dict[str, int], typehints.Dict[str, int]), + ('builtin list', list[str], typehints.List[str]), + ('builtin tuple', tuple[str], + typehints.Tuple[str]), ('builtin set', set[str], typehints.Set[str]), + ('builtin frozenset', frozenset[int], typehints.FrozenSet[int]), + ( + 'nested builtin', + dict[str, list[tuple[float]]], + typehints.Dict[str, typehints.List[typehints.Tuple[float]]]), + ( + 'builtin nested tuple', + tuple[str, list], + typehints.Tuple[str, typehints.List[typehints.Any]], + ) + ] + + for test_case in test_cases: + description = test_case[0] + builtins_type = test_case[1] + expected_beam_type = test_case[2] + converted_beam_type = convert_to_beam_type(builtins_type) + self.assertEqual(converted_beam_type, expected_beam_type, description) def test_convert_to_beam_type_with_collections_types(self): - if sys.version_info >= (3, 9): - test_cases = [ - ( - 'collection iterable', - collections.abc.Iterable[int], - typehints.Iterable[int]), - ( - 'collection generator', - collections.abc.Generator[int], - typehints.Generator[int]), - ( - 'collection iterator', - collections.abc.Iterator[int], - typehints.Iterator[int]), - ( - 'nested iterable', - tuple[bytes, collections.abc.Iterable[int]], - typehints.Tuple[bytes, typehints.Iterable[int]]), - ( - 'iterable over tuple', - collections.abc.Iterable[tuple[str, int]], - typehints.Iterable[typehints.Tuple[str, int]]), - ( - 'mapping not caught', - collections.abc.Mapping[str, int], - collections.abc.Mapping[str, int]), - ('set', collections.abc.Set[str], typehints.Set[str]), - ('mutable set', collections.abc.MutableSet[int], typehints.Set[int]), - ( - 'enum set', - collections.abc.Set[_TestEnum], - typehints.Set[_TestEnum]), - ( - 'enum mutable set', - 
collections.abc.MutableSet[_TestEnum], - typehints.Set[_TestEnum]), - ( - 'collection enum', - collections.abc.Collection[_TestEnum], - typehints.Collection[_TestEnum]), - ( - 'collection of tuples', - collections.abc.Collection[tuple[str, int]], - typehints.Collection[typehints.Tuple[str, int]]), - ] - - for test_case in test_cases: - description = test_case[0] - builtins_type = test_case[1] - expected_beam_type = test_case[2] - converted_beam_type = convert_to_beam_type(builtins_type) - self.assertEqual(converted_beam_type, expected_beam_type, description) + test_cases = [ + ( + 'collection iterable', + collections.abc.Iterable[int], + typehints.Iterable[int]), + ( + 'collection generator', + collections.abc.Generator[int], + typehints.Generator[int]), + ( + 'collection iterator', + collections.abc.Iterator[int], + typehints.Iterator[int]), + ( + 'nested iterable', + tuple[bytes, collections.abc.Iterable[int]], + typehints.Tuple[bytes, typehints.Iterable[int]]), + ( + 'iterable over tuple', + collections.abc.Iterable[tuple[str, int]], + typehints.Iterable[typehints.Tuple[str, int]]), + ( + 'mapping not caught', + collections.abc.Mapping[str, int], + collections.abc.Mapping[str, int]), + ('set', collections.abc.Set[str], typehints.Set[str]), + ('mutable set', collections.abc.MutableSet[int], typehints.Set[int]), + ('enum set', collections.abc.Set[_TestEnum], typehints.Set[_TestEnum]), + ( + 'enum mutable set', + collections.abc.MutableSet[_TestEnum], + typehints.Set[_TestEnum]), + ( + 'collection enum', + collections.abc.Collection[_TestEnum], + typehints.Collection[_TestEnum]), + ( + 'collection of tuples', + collections.abc.Collection[tuple[str, int]], + typehints.Collection[typehints.Tuple[str, int]]), + ] + + for test_case in test_cases: + description = test_case[0] + builtins_type = test_case[1] + expected_beam_type = test_case[2] + converted_beam_type = convert_to_beam_type(builtins_type) + self.assertEqual(converted_beam_type, expected_beam_type, 
description) def test_convert_builtin_to_typing(self): - if sys.version_info >= (3, 9): - test_cases = [ - ('dict', dict[str, int], typing.Dict[str, int]), - ('list', list[str], typing.List[str]), - ('tuple', tuple[str], typing.Tuple[str]), - ('set', set[str], typing.Set[str]), - ( - 'nested', - dict[str, list[tuple[float]]], - typing.Dict[str, typing.List[typing.Tuple[float]]]), - ] - - for test_case in test_cases: - description = test_case[0] - builtin_type = test_case[1] - expected_typing_type = test_case[2] - converted_typing_type = convert_builtin_to_typing(builtin_type) - self.assertEqual( - converted_typing_type, expected_typing_type, description) + test_cases = [ + ('dict', dict[str, int], typing.Dict[str, int]), + ('list', list[str], typing.List[str]), + ('tuple', tuple[str], typing.Tuple[str]), + ('set', set[str], typing.Set[str]), + ( + 'nested', + dict[str, list[tuple[float]]], + typing.Dict[str, typing.List[typing.Tuple[float]]]), + ] + + for test_case in test_cases: + description = test_case[0] + builtin_type = test_case[1] + expected_typing_type = test_case[2] + converted_typing_type = convert_builtin_to_typing(builtin_type) + self.assertEqual(converted_typing_type, expected_typing_type, description) def test_generator_converted_to_iterator(self): self.assertEqual( @@ -293,14 +285,11 @@ def test_convert_bare_types(self): typing.Tuple[typing.Iterator], typehints.Tuple[typehints.Iterator[typehints.TypeVariable('T_co')]] ), + ( + 'bare generator', + typing.Generator, + typehints.Generator[typehints.TypeVariable('T_co')]), ] - if sys.version_info >= (3, 7): - test_cases += [ - ( - 'bare generator', - typing.Generator, - typehints.Generator[typehints.TypeVariable('T_co')]), - ] for test_case in test_cases: description = test_case[0] typing_type = test_case[1] diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index 62c7a8fadc35..7bea621841f6 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ 
b/sdks/python/apache_beam/typehints/opcodes.py @@ -246,14 +246,10 @@ def set_add(state, arg): def map_add(state, arg): - if sys.version_info >= (3, 8): - # PEP 572 The MAP_ADD expects the value as the first element in the stack - # and the key as the second element. - new_value_type = Const.unwrap(state.stack.pop()) - new_key_type = Const.unwrap(state.stack.pop()) - else: - new_key_type = Const.unwrap(state.stack.pop()) - new_value_type = Const.unwrap(state.stack.pop()) + # PEP 572 The MAP_ADD expects the value as the first element in the stack + # and the key as the second element. + new_value_type = Const.unwrap(state.stack.pop()) + new_key_type = Const.unwrap(state.stack.pop()) state.stack[-arg] = Dict[Union[state.stack[-arg].key_type, new_key_type], Union[state.stack[-arg].value_type, new_value_type]] diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index fd7885ad59c4..880a897bbbe8 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -49,7 +49,8 @@ def __init__( fields: Sequence[Tuple[str, type]], user_type, schema_options: Optional[Sequence[Tuple[str, Any]]] = None, - field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None): + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, + field_descriptions: Optional[Dict[str, str]] = None): """For internal use only, no backwards comatibility guaratees. See https://beam.apache.org/documentation/programming-guide/#schemas-for-pl-types for guidance on creating PCollections with inferred schemas. 
@@ -96,6 +97,7 @@ def __init__( self._schema_options = schema_options or [] self._field_options = field_options or {} + self._field_descriptions = field_descriptions or {} @staticmethod def from_user_type( @@ -107,12 +109,15 @@ def from_user_type( fields = [(name, user_type.__annotations__[name]) for name in user_type._fields] + field_descriptions = getattr(user_type, '_field_descriptions', None) + if _user_type_is_generated(user_type): return RowTypeConstraint.from_fields( fields, schema_id=getattr(user_type, _BEAM_SCHEMA_ID), schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) # TODO(https://github.com/apache/beam/issues/22125): Add user API for # specifying schema/field options @@ -120,7 +125,8 @@ def from_user_type( fields=fields, user_type=user_type, schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) return None @@ -131,13 +137,15 @@ def from_fields( schema_options: Optional[Sequence[Tuple[str, Any]]] = None, field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, schema_registry: Optional[SchemaTypeRegistry] = None, + field_descriptions: Optional[Dict[str, str]] = None, ) -> RowTypeConstraint: return GeneratedClassRowTypeConstraint( fields, schema_id=schema_id, schema_options=schema_options, field_options=field_options, - schema_registry=schema_registry) + schema_registry=schema_registry, + field_descriptions=field_descriptions) def __call__(self, *args, **kwargs): # We make RowTypeConstraint callable (defers to constructing the user type) @@ -206,6 +214,7 @@ def __init__( schema_options: Optional[Sequence[Tuple[str, Any]]] = None, field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, schema_registry: Optional[SchemaTypeRegistry] = None, + field_descriptions: Optional[Dict[str, str]] = None, ): from apache_beam.typehints.schemas import named_fields_to_schema from 
apache_beam.typehints.schemas import named_tuple_from_schema @@ -224,7 +233,8 @@ def __init__( fields, user_type, schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) def __reduce__(self): return ( diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index ef82ca91044c..2c5d35a68cc2 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -274,6 +274,7 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: self.option_to_runner_api(option_tuple) for option_tuple in type_.field_options(field_name) ], + description=type_._field_descriptions.get(field_name, None), ) for (field_name, field_type) in type_._fields ], id=schema_id, @@ -512,13 +513,17 @@ def typing_from_runner_api( # generate a NamedTuple type to use. fields = named_fields_from_schema(schema) + descriptions = { + field.name: field.description + for field in schema.fields + } result = row_type.RowTypeConstraint.from_fields( fields=fields, schema_id=schema.id, schema_options=schema_options, field_options=field_options, schema_registry=self.schema_registry, - ) + field_descriptions=descriptions or None) return result else: return row_type.RowTypeConstraint.from_user_type( diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 5d38b16d9783..15144c6c2c17 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -489,6 +489,44 @@ def test_row_type_constraint_to_schema_with_field_options(self): ] self.assertEqual(list(field.options), expected) + def test_row_type_constraint_to_schema_with_field_descriptions(self): + row_type_with_options = row_type.RowTypeConstraint.from_fields( + [ + ('foo', np.int8), + ('bar', float), + ('baz', bytes), + ], + field_descriptions={ + 'foo': 'foo 
description', + 'bar': 'bar description', + 'baz': 'baz description', + }) + result_type = typing_to_runner_api(row_type_with_options) + + self.assertIsInstance(result_type, schema_pb2.FieldType) + self.assertEqual(result_type.WhichOneof("type_info"), "row_type") + + fields = result_type.row_type.schema.fields + + expected = [ + schema_pb2.Field( + name='foo', + description='foo description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.BYTE), + ), + schema_pb2.Field( + name='bar', + description='bar description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE), + ), + schema_pb2.Field( + name='baz', + description='baz description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.BYTES), + ), + ] + self.assertEqual(list(fields), expected) + def assert_namedtuple_equivalent(self, actual, expected): # Two types are only considered equal if they are literally the same # object (i.e. `actual == expected` is the same as `actual is expected` in diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index 8b6f43abaa83..fe9007ed63ca 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -492,7 +492,7 @@ def infer_return_type_func(f, input_types, debug=False, depth=0): # stack[-has_kwargs]: Map of keyword args. # stack[-1 - has_kwargs]: Iterable of positional args. # stack[-2 - has_kwargs]: Function to call. - has_kwargs = arg & 1 # type: int + has_kwargs: int = arg & 1 pop_count = has_kwargs + 2 if has_kwargs: # TODO(BEAM-24755): Unimplemented. 
Requires same functionality as a diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 72aed46f5e78..820f78fa9ef5 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -19,9 +19,9 @@ # pytype: skip-file -import sys import typing import unittest +from typing import Tuple import apache_beam as beam from apache_beam import pvalue @@ -88,11 +88,11 @@ def process(self, element): self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_dofn_method(self): @@ -104,11 +104,11 @@ def process(self, element: int) -> typehints.Tuple[str]: self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_dofn_method_with_class_decorators(self): @@ -124,12 +124,12 @@ def process(self, element: int) -> typehints.Tuple[str]: with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires.*Tuple\[, \].*got.*str'): + r'requires.*Tuple\[, \].*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires.*Tuple\[, \].*got.*int'): + r'requires.*Tuple\[, \].*applied.*int'): _ = [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def 
test_typed_callable_iterable_output(self): @@ -156,11 +156,11 @@ def process(self, element: typehints.Tuple[int, int]) -> \ self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(my_do_fn) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (beam.ParDo(my_do_fn) | 'again' >> beam.ParDo(my_do_fn)) def test_typed_callable_instance(self): @@ -177,11 +177,11 @@ def do_fn(element: typehints.Tuple[int, int]) -> typehints.Generator[str]: self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | pardo with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (pardo | 'again' >> pardo) def test_filter_type_hint(self): @@ -430,7 +430,7 @@ def fn(element: float): return pcoll | beam.ParDo(fn) with self.assertRaisesRegex(typehints.TypeCheckError, - r'ParDo.*requires.*float.*got.*str'): + r'ParDo.*requires.*float.*applied.*str'): _ = ['1', '2', '3'] | MyMap() with self.assertRaisesRegex(typehints.TypeCheckError, r'MyMap.*expected.*str.*got.*bytes'): @@ -632,14 +632,14 @@ def produces_unkown(e): return e @typehints.with_input_types(int) - def requires_int(e): + def accepts_int(e): return e class MyPTransform(beam.PTransform): def expand(self, pcoll): unknowns = pcoll | beam.Map(produces_unkown) ints = pcoll | beam.Map(int) - return (unknowns, ints) | beam.Flatten() | beam.Map(requires_int) + return (unknowns, ints) | beam.Flatten() | beam.Map(accepts_int) _ = [1, 2, 3] | MyPTransform() @@ -761,8 +761,8 @@ def test_var_positional_only_side_input_hint(self): with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires Tuple\[Union\[, \], 
...\] but ' - r'got Tuple\[Union\[, \], ...\]'): + r'requires.*Tuple\[Union\[, \], ...\].*' + r'applied.*Tuple\[Union\[, \], ...\]'): _ = [1.2] | beam.Map(lambda *_: 'a', 5).with_input_types(int, str) def test_var_keyword_side_input_hint(self): @@ -874,12 +874,7 @@ def test_flat_type_hint(self): class AnnotationsTest(unittest.TestCase): def test_pardo_wrapper_builtin_method(self): th = beam.ParDo(str.strip).get_type_hints() - if sys.version_info < (3, 7): - self.assertEqual(th.input_types, ((str, ), {})) - else: - # Python 3.7+ has annotations for CPython builtins - # (_MethodDescriptorType). - self.assertEqual(th.input_types, ((str, typehints.Any), {})) + self.assertEqual(th.input_types, ((str, typehints.Any), {})) self.assertEqual(th.output_types, ((typehints.Any, ), {})) def test_pardo_wrapper_builtin_type(self): @@ -1005,5 +1000,22 @@ def filter_fn(element: int) -> bool: self.assertEqual(th.output_types, ((int, ), {})) +class TestFlatMapTuple(unittest.TestCase): + def test_flatmaptuple(self): + # Regression test. See + # https://github.com/apache/beam/issues/33014 + + def identity(x: Tuple[str, int]) -> Tuple[str, int]: + return x + + with beam.Pipeline() as p: + # Just checking that this doesn't raise an exception. 
+ ( + p + | "Generate input" >> beam.Create([('P1', [2])]) + | "Flat" >> beam.FlatMapTuple(lambda k, vs: [(k, v) for v in vs]) + | "Identity" >> beam.Map(identity)) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 912cb78dc095..0e18e887c2a0 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -391,12 +391,6 @@ def validate_composite_type_param(type_param, error_msg_prefix): if sys.version_info.major == 3 and sys.version_info.minor >= 10: if isinstance(type_param, types.UnionType): is_not_type_constraint = False - # Pre-Python 3.9 compositve type-hinting with built-in types was not - # supported, the typing module equivalents should be used instead. - if sys.version_info.major == 3 and sys.version_info.minor < 9: - is_not_type_constraint = is_not_type_constraint or ( - isinstance(type_param, type) and - type_param in DISALLOWED_PRIMITIVE_TYPES) if is_not_type_constraint: raise TypeError( @@ -1266,7 +1260,7 @@ def normalize(x, none_as_type=False): # Avoid circular imports from apache_beam.typehints import native_type_compatibility - if sys.version_info >= (3, 9) and isinstance(x, types.GenericAlias): + if isinstance(x, types.GenericAlias): x = native_type_compatibility.convert_builtin_to_typing(x) if none_as_type and x is None: diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 843c1498cac5..6611dcecab01 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -388,15 +388,10 @@ def test_getitem_params_must_be_type_or_constraint(self): typehints.Tuple[5, [1, 3]] self.assertTrue(e.exception.args[0].startswith(expected_error_prefix)) - if sys.version_info < (3, 9): - with self.assertRaises(TypeError) as e: - typehints.Tuple[list, dict] - 
self.assertTrue(e.exception.args[0].startswith(expected_error_prefix)) - else: - try: - typehints.Tuple[list, dict] - except TypeError: - self.fail("built-in composite raised TypeError unexpectedly") + try: + typehints.Tuple[list, dict] + except TypeError: + self.fail("built-in composite raised TypeError unexpectedly") def test_compatibility_arbitrary_length(self): self.assertNotCompatible( @@ -548,15 +543,13 @@ def test_type_check_invalid_composite_type_arbitrary_length(self): e.exception.args[0]) def test_normalize_with_builtin_tuple(self): - if sys.version_info >= (3, 9): - expected_beam_type = typehints.Tuple[int, int] - converted_beam_type = typehints.normalize(tuple[int, int], False) - self.assertEqual(converted_beam_type, expected_beam_type) + expected_beam_type = typehints.Tuple[int, int] + converted_beam_type = typehints.normalize(tuple[int, int], False) + self.assertEqual(converted_beam_type, expected_beam_type) def test_builtin_and_type_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(tuple, typing.Tuple) - self.assertCompatible(tuple[int, int], typing.Tuple[int, int]) + self.assertCompatible(tuple, typing.Tuple) + self.assertCompatible(tuple[int, int], typing.Tuple[int, int]) class ListHintTestCase(TypeHintTestCase): @@ -618,22 +611,19 @@ def test_enforce_list_type_constraint_invalid_composite_type(self): e.exception.args[0]) def test_normalize_with_builtin_list(self): - if sys.version_info >= (3, 9): - expected_beam_type = typehints.List[int] - converted_beam_type = typehints.normalize(list[int], False) - self.assertEqual(converted_beam_type, expected_beam_type) + expected_beam_type = typehints.List[int] + converted_beam_type = typehints.normalize(list[int], False) + self.assertEqual(converted_beam_type, expected_beam_type) def test_builtin_and_type_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(list, typing.List) - self.assertCompatible(list[int], typing.List[int]) + 
self.assertCompatible(list, typing.List) + self.assertCompatible(list[int], typing.List[int]) def test_is_typing_generic(self): self.assertTrue(typehints.is_typing_generic(typing.List[str])) def test_builtin_is_typing_generic(self): - if sys.version_info >= (3, 9): - self.assertTrue(typehints.is_typing_generic(list[str])) + self.assertTrue(typehints.is_typing_generic(list[str])) class KVHintTestCase(TypeHintTestCase): @@ -687,14 +677,10 @@ def test_getitem_param_must_have_length_2(self): e.exception.args[0]) def test_key_type_must_be_valid_composite_param(self): - if sys.version_info < (3, 9): - with self.assertRaises(TypeError): - typehints.Dict[list, int] - else: - try: - typehints.Tuple[list, int] - except TypeError: - self.fail("built-in composite raised TypeError unexpectedly") + try: + typehints.Tuple[list, int] + except TypeError: + self.fail("built-in composite raised TypeError unexpectedly") def test_value_type_must_be_valid_composite_param(self): with self.assertRaises(TypeError): @@ -777,35 +763,24 @@ def test_match_type_variables(self): hint.match_type_variables(typehints.Dict[int, str])) def test_normalize_with_builtin_dict(self): - if sys.version_info >= (3, 9): - expected_beam_type = typehints.Dict[str, int] - converted_beam_type = typehints.normalize(dict[str, int], False) - self.assertEqual(converted_beam_type, expected_beam_type) + expected_beam_type = typehints.Dict[str, int] + converted_beam_type = typehints.normalize(dict[str, int], False) + self.assertEqual(converted_beam_type, expected_beam_type) def test_builtin_and_type_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(dict, typing.Dict) - self.assertCompatible(dict[str, int], typing.Dict[str, int]) - self.assertCompatible( - dict[str, list[int]], typing.Dict[str, typing.List[int]]) + self.assertCompatible(dict, typing.Dict) + self.assertCompatible(dict[str, int], typing.Dict[str, int]) + self.assertCompatible( + dict[str, list[int]], typing.Dict[str, 
typing.List[int]]) class BaseSetHintTest: class CommonTests(TypeHintTestCase): def test_getitem_invalid_composite_type_param(self): - if sys.version_info < (3, 9): - with self.assertRaises(TypeError) as e: - self.beam_type[list] - self.assertEqual( - "Parameter to a {} hint must be a non-sequence, a " - "type, or a TypeConstraint. {} is an instance of " - "type.".format(self.string_type, list), - e.exception.args[0]) - else: - try: - self.beam_type[list] - except TypeError: - self.fail("built-in composite raised TypeError unexpectedly") + try: + self.beam_type[list] + except TypeError: + self.fail("built-in composite raised TypeError unexpectedly") def test_non_typing_generic(self): testCase = DummyTestClass1() @@ -855,16 +830,14 @@ class SetHintTestCase(BaseSetHintTest.CommonTests): string_type = 'Set' def test_builtin_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(set[int], collections.abc.Set[int]) - self.assertCompatible(set[int], collections.abc.MutableSet[int]) + self.assertCompatible(set[int], collections.abc.Set[int]) + self.assertCompatible(set[int], collections.abc.MutableSet[int]) def test_collections_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible( - collections.abc.Set[int], collections.abc.MutableSet[int]) - self.assertCompatible( - collections.abc.MutableSet[int], collections.abc.Set[int]) + self.assertCompatible( + collections.abc.Set[int], collections.abc.MutableSet[int]) + self.assertCompatible( + collections.abc.MutableSet[int], collections.abc.Set[int]) class FrozenSetHintTestCase(BaseSetHintTest.CommonTests): @@ -1416,37 +1389,16 @@ def func(a, b_c, *d): func, *[Any, Any, Tuple[str, ...], int])) def test_getcallargs_forhints_builtins(self): - if sys.version_info < (3, 7): - # Signatures for builtins are not supported in 3.5 and 3.6. 
- self.assertEqual({ - '_': str, - '__unknown__varargs': Tuple[Any, ...], - '__unknown__keywords': typehints.Dict[Any, Any] - }, - getcallargs_forhints(str.upper, str)) - self.assertEqual({ - '_': str, - '__unknown__varargs': Tuple[str, ...], - '__unknown__keywords': typehints.Dict[Any, Any] - }, - getcallargs_forhints(str.strip, str, str)) - self.assertEqual({ - '_': str, - '__unknown__varargs': Tuple[typehints.List[int], ...], - '__unknown__keywords': typehints.Dict[Any, Any] - }, - getcallargs_forhints(str.join, str, typehints.List[int])) - else: - self.assertEqual({'self': str}, getcallargs_forhints(str.upper, str)) - # str.strip has an optional second argument. - self.assertEqual({ - 'self': str, 'chars': Any - }, - getcallargs_forhints(str.strip, str)) - self.assertEqual({ - 'self': str, 'iterable': typehints.List[int] - }, - getcallargs_forhints(str.join, str, typehints.List[int])) + self.assertEqual({'self': str}, getcallargs_forhints(str.upper, str)) + # str.strip has an optional second argument. 
+ self.assertEqual({ + 'self': str, 'chars': Any + }, + getcallargs_forhints(str.strip, str)) + self.assertEqual({ + 'self': str, 'iterable': typehints.List[int] + }, + getcallargs_forhints(str.join, str, typehints.List[int])) class TestGetYieldedType(unittest.TestCase): diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index c75fdcc5878d..61c2371bd07d 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -104,7 +104,9 @@ def __exit__(self, *args): self.profile.create_stats() self.profile_output = self._upload_profile_data( # typing: seems stats attr is missing from typeshed - self.profile_location, 'cpu_profile', self.profile.stats) # type: ignore[attr-defined] + self.profile_location, + 'cpu_profile', + self.profile.stats) if self.enable_memory_profiling: if not self.hpy: diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py index cc637dead477..60c0af2ebac0 100644 --- a/sdks/python/apache_beam/utils/proto_utils.py +++ b/sdks/python/apache_beam/utils/proto_utils.py @@ -36,6 +36,9 @@ message_types = (message.Message, ) +_SECONDS_TO_MICROS = 10**6 +_MICROS_TO_NANOS = 10**3 + @overload def pack_Any(msg: message.Message) -> any_pb2.Any: @@ -43,7 +46,7 @@ def pack_Any(msg: message.Message) -> any_pb2.Any: @overload -def pack_Any(msg: None) -> None: +def pack_Any(msg: None) -> None: # type: ignore[overload-cannot-match] pass @@ -115,8 +118,29 @@ def pack_Struct(**kwargs) -> struct_pb2.Struct: def from_micros(cls: Type[TimeMessageT], micros: int) -> TimeMessageT: result = cls() - result.FromMicroseconds(micros) - return result + if isinstance(result, duration_pb2.Duration): + result.FromMicroseconds(micros) + return result + # Protobuf 5.x enforces a maximum timestamp value less than the Beam + # maximum allowable timestamp, so we cannot use the built-in conversion. 
+ elif isinstance(result, timestamp_pb2.Timestamp): + result.seconds = micros // _SECONDS_TO_MICROS + result.nanos = (micros % _SECONDS_TO_MICROS) * _MICROS_TO_NANOS + return result + else: + raise RuntimeError('cannot convert the micro seconds to %s' % cls) + + +def to_micros(value: Union[duration_pb2.Duration, timestamp_pb2.Timestamp]): + if isinstance(value, duration_pb2.Duration): + return value.ToMicroseconds() + # Protobuf 5.x enforces a maximum timestamp value less than the Beam + # maximum allowable timestamp, so we cannot use the built-in conversion. + elif isinstance(value, timestamp_pb2.Timestamp): + micros = value.seconds * _SECONDS_TO_MICROS + return micros + (value.nanos // _MICROS_TO_NANOS) + else: + raise RuntimeError('cannot convert %s to micro seconds' % value) def to_Timestamp(time: Union[int, float]) -> timestamp_pb2.Timestamp: diff --git a/sdks/python/apache_beam/utils/proto_utils_test.py b/sdks/python/apache_beam/utils/proto_utils_test.py new file mode 100644 index 000000000000..c40967cd2c0f --- /dev/null +++ b/sdks/python/apache_beam/utils/proto_utils_test.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 + +from apache_beam.utils import proto_utils +from apache_beam.utils.timestamp import MAX_TIMESTAMP + + +class TestProtoUtils(unittest.TestCase): + def test_from_micros_duration(self): + ts = proto_utils.from_micros(duration_pb2.Duration, MAX_TIMESTAMP.micros) + expected = duration_pb2.Duration( + seconds=MAX_TIMESTAMP.seconds(), nanos=775000000) + self.assertEqual(ts, expected) + + def test_from_micros_timestamp(self): + ts = proto_utils.from_micros(timestamp_pb2.Timestamp, MAX_TIMESTAMP.micros) + expected = timestamp_pb2.Timestamp( + seconds=MAX_TIMESTAMP.seconds(), nanos=775000000) + self.assertEqual(ts, expected) + + def test_to_micros_duration(self): + dur = duration_pb2.Duration( + seconds=MAX_TIMESTAMP.seconds(), nanos=775000000) + ts = proto_utils.to_micros(dur) + expected = MAX_TIMESTAMP.micros + self.assertEqual(ts, expected) + + def test_to_micros_timestamp(self): + dur = timestamp_pb2.Timestamp( + seconds=MAX_TIMESTAMP.seconds(), nanos=775000000) + ts = proto_utils.to_micros(dur) + expected = MAX_TIMESTAMP.micros + self.assertEqual(ts, expected) + + def test_round_trip_duration(self): + expected = 919336704 + dur = proto_utils.from_micros(duration_pb2.Duration, expected) + ms = proto_utils.to_micros(dur) + self.assertEqual(ms, expected) + + def test_round_trip_timestamp(self): + expected = 919336704 + ts = proto_utils.from_micros(timestamp_pb2.Timestamp, expected) + ms = proto_utils.to_micros(ts) + self.assertEqual(ms, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py index 944c12625d7c..b1080cb643af 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -266,11 +266,17 @@ class JavaJarServer(SubprocessServer): 'local', (threading.local, ), 
def spanner_data():
  """Returns the rows used as the in-memory stand-in for a Spanner table.

  Each call builds a fresh list of fresh dicts, one per shipment, with the
  keys in the column order listed below.
  """
  columns = (
      'shipment_id',
      'customer_id',
      'shipment_date',
      'shipment_cost',
      'customer_name',
      'customer_email',
  )
  shipments = (
      ('S1', 'C1', '2023-05-01', 150.0, 'Alice', 'alice@example.com'),
      ('S2', 'C2', '2023-06-12', 300.0, 'Bob', 'bob@example.com'),
      ('S3', 'C1', '2023-05-10', 20.0, 'Alice', 'alice@example.com'),
      ('S4', 'C4', '2024-07-01', 150.0, 'Derek', 'derek@example.com'),
      ('S5', 'C5', '2023-05-09', 300.0, 'Erin', 'erin@example.com'),
      ('S6', 'C4', '2024-07-02', 150.0, 'Derek', 'derek@example.com'),
  )
  return [dict(zip(columns, shipment)) for shipment in shipments]
@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml'])
def _spanner_io_read_test_preprocessor(
    test_spec: dict, expected: List[str], env: TestEnvironment):
  """Swaps every ReadFromSpanner transform for a Create of canned rows.

  For each transform whose type starts with 'ReadFromSpanner', the rows are
  looked up in INPUT_TABLES under the key (instance_id, database_id, table),
  all converted with str(). When the config has no 'table' entry, the table
  name is taken as the text following the last 'FROM' in the 'query' string.
  Only config entries whose key starts with '__' are carried over to the
  replacement Create transform.

  Args:
    test_spec: the parsed example spec; it is modified in place.
    expected: unused here.
    env: unused here.

  Returns:
    The same ``test_spec`` object.

  Raises:
    KeyError: if the config lacks 'instance_id' / 'database_id', or the
      resulting key is not present in INPUT_TABLES.
  """
  if pipeline := test_spec.get('pipeline', None):
    for transform in pipeline.get('transforms', []):
      if transform.get('type', '').startswith('ReadFromSpanner'):
        config = transform['config']
        instance, database = config['instance_id'], config['database_id']
        # Bind the table name itself. The previous
        # `if table := config.get('table', None) is None:` bound `table` to
        # the boolean result of the `is None` comparison, so an explicitly
        # configured table was replaced by False.
        table = config.get('table', None)
        if table is None:
          table = config.get('query', '').split('FROM')[-1].strip()
        transform['type'] = 'Create'
        transform['config'] = {
            k: v
            for k, v in config.items() if k.startswith('__')
        }
        transform['config']['elements'] = INPUT_TABLES[(
            str(instance), str(database), str(table))]

  return test_spec
a/sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml similarity index 73% rename from sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml rename to sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml index c86d42c1e0c6..26f68b68d931 100644 --- a/sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml +++ b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml @@ -18,10 +18,10 @@ pipeline: transforms: - # Reading data from a Spanner database. The table used here has the following columns: - # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String) - # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query - # A table with a list of columns can also be specified instead of a query + # Reading data from a Spanner database. The table used here has the following columns: + # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String) + # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query + # A table with a list of columns can also be specified instead of a query - type: ReadFromSpanner name: ReadShipments config: @@ -30,8 +30,8 @@ pipeline: database_id: 'shipment' query: 'SELECT * FROM shipments' - # Filtering the data based on a specific condition - # Here, the condition is used to keep only the rows where the customer_id is 'C1' + # Filtering the data based on a specific condition + # Here, the condition is used to keep only the rows where the customer_id is 'C1' - type: Filter name: FilterShipments input: ReadShipments @@ -39,9 +39,9 @@ pipeline: language: python keep: "customer_id == 'C1'" - # Mapping the data fields and applying transformations - # A new field 'shipment_cost_category' is added with a 
custom transformation - # A callable is defined to categorize shipment cost + # Mapping the data fields and applying transformations + # A new field 'shipment_cost_category' is added with a custom transformation + # A callable is defined to categorize shipment cost - type: MapToFields name: MapFieldsForSpanner input: FilterShipments @@ -65,7 +65,7 @@ pipeline: else: return 'High Cost' - # Writing the transformed data to a CSV file + # Writing the transformed data to a CSV file - type: WriteToCsv name: WriteBig input: MapFieldsForSpanner @@ -73,8 +73,7 @@ pipeline: path: shipments.csv - # On executing the above pipeline, a new CSV file is created with the following records - +# On executing the above pipeline, a new CSV file is created with the following records # Expected: # Row(shipment_id='S1', customer_id='C1', shipment_date='2023-05-01', shipment_cost=150.0, customer_name='Alice', customer_email='alice@example.com', shipment_cost_category='Medium Cost') # Row(shipment_id='S3', customer_id='C1', shipment_date='2023-05-10', shipment_cost=20.0, customer_name='Alice', customer_email='alice@example.com', shipment_cost_category='Low Cost') diff --git a/sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml similarity index 69% rename from sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml rename to sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml index 74ac35de260f..1667fcfcc163 100644 --- a/sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml +++ b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml @@ -18,8 +18,8 @@ pipeline: transforms: - # Step 1: Creating rows to be written to Spanner - # The element names correspond to the column names in the Spanner table + # Step 1: Creating rows to be written to Spanner + # The element names correspond to the column names in the Spanner table - type: Create name: CreateRows config: @@ -31,10 
+31,10 @@ pipeline: customer_name: "Erin" customer_email: "erin@example.com" - # Step 2: Writing the created rows to a Spanner database - # We require the project ID, instance ID, database ID and table ID to connect to Spanner - # Error handling can be specified optionally to ensure any failed operations aren't lost - # The failed data is passed on in the pipeline and can be handled + # Step 2: Writing the created rows to a Spanner database + # We require the project ID, instance ID, database ID and table ID to connect to Spanner + # Error handling can be specified optionally to ensure any failed operations aren't lost + # The failed data is passed on in the pipeline and can be handled - type: WriteToSpanner name: WriteSpanner input: CreateRows @@ -46,8 +46,11 @@ pipeline: error_handling: output: my_error_output - # Step 3: Writing the failed records to a JSON file + # Step 3: Writing the failed records to a JSON file - type: WriteToJson input: WriteSpanner.my_error_output config: path: errors.json + +# Expected: +# Row(shipment_id='S5', customer_id='C5', shipment_date='2023-05-09', shipment_cost=300.0, customer_name='Erin', customer_email='erin@example.com') diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 4719bc3e66aa..27e17029f387 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -20,12 +20,17 @@ import itertools import re +import docstring_parser import yaml from apache_beam.portability.api import schema_pb2 +from apache_beam.typehints import schemas from apache_beam.utils import subprocess_server +from apache_beam.utils.python_callable import PythonCallableWithSource +from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils from apache_beam.yaml import yaml_provider +from apache_beam.yaml.yaml_errors import ErrorHandlingConfig def _singular(name): @@ -134,8 +139,29 @@ def 
def normalize_error_handling(f):
  """Rewrites an `error_handling` field to the ErrorHandlingConfig row type.

  Fields with any other name are returned untouched. For `error_handling`,
  a new nullable row-typed field is built with one sub-field per parameter
  documented in ErrorHandlingConfig's docstring; the original description is
  kept when present, otherwise the docstring's short description is used.

  Args:
    f: a schema field (it is read via ``f.name`` and ``f.description``).

  Returns:
    ``f`` itself, or the replacement ``schema_pb2.Field``.
  """
  if f.name != "error_handling":
    return f
  # Parse the docstring only when it is needed; previously it was parsed for
  # every field of every schema even though just this branch reads it.
  doc = docstring_parser.parse(
      ErrorHandlingConfig.__doc__, docstring_parser.DocstringStyle.GOOGLE)
  return schema_pb2.Field(
      name="error_handling",
      type=schema_pb2.FieldType(
          row_type=schema_pb2.RowType(
              schema=schema_pb2.Schema(
                  fields=[
                      schemas.schema_field(
                          param.arg_name,
                          # The documented type name (e.g. "str") is evaluated
                          # as an expression to obtain the Python type.
                          PythonCallableWithSource.load_from_expression(
                              param.type_name),
                          param.description) for param in doc.params
                  ])),
          nullable=True),
      description=f.description or doc.short_description)
overflow-x: hidden; + width: 300px; + height: 100%; + padding-bottom: 2em; + color: #9b9b9b; + background: #343131; } - .nav a { - color: #333; - padding: .2em; + .nav-header { display: block; - text-decoration: none; + width: 300px; + padding: 1em; + background-color: #2980B9; + text-align: center; + color: #fcfcfc; + } + .nav-header a { + color: #fcfcfc; + font-weight: bold; + display: inline-block; + padding: 4px 6px; + margin-bottom: 1em; + text-decoration:none; } - .nav a:hover { - color: #888; + .nav-header>div.version { + margin-top: -.5em; + margin-bottom: 1em; + font-weight: normal; + color: rgba(255, 255, 255, 0.3); } - .nav li { - list-style-type: none; + .toc { + width: 300px; + text-align: left; + overflow-y: auto; + max-height: calc(100% - 4.3em); + scrollbar-width: thin; + scrollbar-color: #9b9b9b #343131; + } + .toc ul { margin: 0; padding: 0; + list-style: none; + } + .toc li { + border-bottom: 1px solid #4e4a4a; + margin-left: 1em; + } + .toc a { + display: block; + line-height: 36px; + font-size: 90%; + color: #d9d9d9; + padding: .1em 0.6em; + text-decoration: none; + transition: background-color 0.3s ease, color 0.3s ease; + } + .toc a:hover { + background-color: #4e4a4a; + color: #ffffff; + } + .transform-content-wrap { + margin-left: 300px; + background: #fcfcfc; + } + .transform-content { + padding: 1.5em 3em; + margin: 20px; + padding-bottom: 2em; + } + .transform-content li::marker { + display: inline-block; + width: 0.5em; + } + .transform-content h1 { + font-size: 40px; + } + .transform-content ul { + margin-left: 0.75em; + text-align: left; + list-style-type: disc; + } + hr { + color: gray; + display: block; + height: 1px; + border: 0; + border-top: 1px solid #e1e4e5; + margin-bottom: 3em; + margin-top: 3em; + padding: 0; } - .content { - margin-left: 12em; + .codehilite { + background: #f5f5f5; + border: 1px solid #ccc; + border-radius: 4px; + padding: 0.2em 1em; + overflow: auto; + font-family: monospace; + font-size: 14px; + 
line-height: 1.5; } - h2 { - margin-top: 2em; + p code, li code { + white-space: nowrap; + max-width: 100%; + background: #fff; + border: solid 1px #e1e4e5; + padding: 0 5px; + font-family: monospace; + color: #404040; + font-weight: bold; + padding: 2px 5px; } ''' - with open(options.html_file, 'w') as fout: - fout.write( + html = md.convert(markdown_out.getvalue()) + with open(options.html_file, 'w') as html_out: + html_out.write( f''' @@ -329,13 +454,23 @@ def main(): {extra_style} - - -
    -

    {title}

    - {html} + +
    + +
    +
    +

    {title}

    + {html.replace(' +
    diff --git a/sdks/python/apache_beam/yaml/integration_tests.py b/sdks/python/apache_beam/yaml/integration_tests.py index af1be7b1e8e5..72b3918195da 100644 --- a/sdks/python/apache_beam/yaml/integration_tests.py +++ b/sdks/python/apache_beam/yaml/integration_tests.py @@ -69,7 +69,7 @@ def temp_bigquery_table(project, prefix='yaml_bq_it_'): dataset_id = '%s_%s' % (prefix, uuid.uuid4().hex) bigquery_client.get_or_create_dataset(project, dataset_id) logging.info("Created dataset %s in project %s", dataset_id, project) - yield f'{project}:{dataset_id}.tmp_table' + yield f'{project}.{dataset_id}.tmp_table' request = bigquery.BigqueryDatasetsDeleteRequest( projectId=project, datasetId=dataset_id, deleteContents=True) logging.info("Deleting dataset %s in project %s", dataset_id, project) diff --git a/sdks/python/apache_beam/yaml/json_utils.py b/sdks/python/apache_beam/yaml/json_utils.py index 40e515ee6946..76cc80bc2036 100644 --- a/sdks/python/apache_beam/yaml/json_utils.py +++ b/sdks/python/apache_beam/yaml/json_utils.py @@ -106,6 +106,18 @@ def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType: raise ValueError(f'Unable to convert {json_type} to a Beam schema.') +def beam_schema_to_json_schema( + beam_schema: schema_pb2.Schema) -> Dict[str, Any]: + return { + 'type': 'object', + 'properties': { + field.name: beam_type_to_json_type(field.type) + for field in beam_schema.fields + }, + 'additionalProperties': False + } + + def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> Dict[str, Any]: type_info = beam_type.WhichOneof("type_info") if type_info == "atomic_type": @@ -267,3 +279,52 @@ def json_formater( convert = row_to_json( schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) return lambda row: json.dumps(convert(row), sort_keys=True).encode('utf-8') + + +def _validate_compatible(weak_schema, strong_schema): + if not weak_schema: + return + if weak_schema['type'] != strong_schema['type']: + raise ValueError( + 
'Incompatible types: %r vs %r' % + (weak_schema['type'] != strong_schema['type'])) + if weak_schema['type'] == 'array': + _validate_compatible(weak_schema['items'], strong_schema['items']) + elif weak_schema == 'object': + for required in strong_schema.get('required', []): + if required not in weak_schema['properties']: + raise ValueError('Missing or unkown property %r' % required) + for name, spec in weak_schema.get('properties', {}): + if name in strong_schema['properties']: + try: + _validate_compatible(spec, strong_schema['properties'][name]) + except Exception as exn: + raise ValueError('Incompatible schema for %r' % name) from exn + elif not strong_schema.get('additionalProperties'): + raise ValueError( + 'Prohibited property: {property}; ' + 'perhaps additionalProperties: False is missing?') + + +def row_validator(beam_schema: schema_pb2.Schema, + json_schema: Dict[str, Any]) -> Callable[[Any], Any]: + """Returns a callable that will fail on elements not respecting json_schema. + """ + if not json_schema: + return lambda x: None + + # Validate that this compiles, but avoid pickling the validator itself. 
+ _ = jsonschema.validators.validator_for(json_schema)(json_schema) + _validate_compatible(beam_schema_to_json_schema(beam_schema), json_schema) + validator = None + + convert = row_to_json( + schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) + + def validate(row): + nonlocal validator + if validator is None: + validator = jsonschema.validators.validator_for(json_schema)(json_schema) + validator.validate(convert(row)) + + return validate diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index 4de36b3dc9e0..305e6877ad90 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -225,6 +225,7 @@ driver_jars: 'driver_jars' connection_properties: 'connection_properties' connection_init_sql: 'connection_init_sql' + batch_size: 'batch_size' 'ReadFromMySql': 'ReadFromJdbc' 'WriteToMySql': 'WriteToJdbc' 'ReadFromPostgres': 'ReadFromJdbc' @@ -235,21 +236,21 @@ 'WriteToSqlServer': 'WriteToJdbc' defaults: 'ReadFromMySql': - jdbcType: 'mysql' + jdbc_type: 'mysql' 'WriteToMySql': - jdbcType: 'mysql' + jdbc_type: 'mysql' 'ReadFromPostgres': - jdbcType: 'postgres' + jdbc_type: 'postgres' 'WriteToPostgres': - jdbcType: 'postgres' + jdbc_type: 'postgres' 'ReadFromOracle': - jdbcType: 'oracle' + jdbc_type: 'oracle' 'WriteToOracle': - jdbcType: 'oracle' + jdbc_type: 'oracle' 'ReadFromSqlServer': - jdbcType: 'mssql' + jdbc_type: 'mssql' 'WriteToSqlServer': - jdbcType: 'mssql' + jdbc_type: 'mssql' underlying_provider: type: beamJar transforms: @@ -271,6 +272,8 @@ table: 'table_id' query: 'query' columns: 'columns' + index: 'index' + batching: 'batching' 'WriteToSpanner': project: 'project_id' instance: 'instance_id' diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 574179805959..242faaa9a77b 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ 
b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -101,3 +101,8 @@ Explode: "beam:schematransform:org.apache.beam:yaml:explode:v1" config: gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar' + +- type: 'python' + config: {} + transforms: + Enrichment: 'apache_beam.yaml.yaml_enrichment.enrichment_transform' diff --git a/sdks/python/apache_beam/yaml/tests/enrichment.yaml b/sdks/python/apache_beam/yaml/tests/enrichment.yaml new file mode 100644 index 000000000000..6469c094b8b4 --- /dev/null +++ b/sdks/python/apache_beam/yaml/tests/enrichment.yaml @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +fixtures: + - name: BQ_TABLE + type: "apache_beam.yaml.integration_tests.temp_bigquery_table" + config: + project: "apache-beam-testing" + - name: TEMP_DIR + type: "apache_beam.yaml.integration_tests.gcs_temp_dir" + config: + bucket: "gs://temp-storage-for-end-to-end-tests/temp-it" + +pipelines: + - pipeline: + type: chain + transforms: + - type: Create + name: Rows + config: + elements: + - {label: '11a', rank: 0} + - {label: '37a', rank: 1} + - {label: '389a', rank: 2} + + - type: WriteToBigQuery + config: + table: "{BQ_TABLE}" + + - pipeline: + type: chain + transforms: + - type: Create + name: Data + config: + elements: + - {label: '11a', name: 'S1'} + - {label: '37a', name: 'S2'} + - {label: '389a', name: 'S3'} + - type: Enrichment + name: Enriched + config: + enrichment_handler: 'BigQuery' + handler_config: + project: apache-beam-testing + table_name: "{BQ_TABLE}" + fields: ['label'] + row_restriction_template: "label = '37a'" + timeout: 30 + + - type: MapToFields + config: + language: python + fields: + label: + callable: 'lambda x: x.label' + output_type: string + rank: + callable: 'lambda x: x.rank' + output_type: integer + name: + callable: 'lambda x: x.name' + output_type: string + + - type: AssertEqual + config: + elements: + - {label: '37a', rank: 1, name: 'S2'} + options: + yaml_experimental_features: [ 'Enrichment' ] \ No newline at end of file diff --git a/sdks/python/apache_beam/yaml/tests/map.yaml b/sdks/python/apache_beam/yaml/tests/map.yaml index b676966ad6bd..04f057cb2e82 100644 --- a/sdks/python/apache_beam/yaml/tests/map.yaml +++ b/sdks/python/apache_beam/yaml/tests/map.yaml @@ -30,6 +30,7 @@ pipelines: config: append: true fields: + # TODO(https://github.com/apache/beam/issues/32832): Figure out why Java sometimes re-orders these fields. 
named_field: element literal_int: 10 literal_float: 1.5 diff --git a/sdks/python/apache_beam/yaml/yaml_combine.py b/sdks/python/apache_beam/yaml/yaml_combine.py index a28bef52ea31..b7499f3b0c7a 100644 --- a/sdks/python/apache_beam/yaml/yaml_combine.py +++ b/sdks/python/apache_beam/yaml/yaml_combine.py @@ -94,6 +94,12 @@ class PyJsYamlCombine(beam.PTransform): See also the documentation on [YAML Aggregation](https://beam.apache.org/documentation/sdks/yaml-combine/). + + Args: + group_by: The field(s) to aggregate on. + combine: The aggregation function to use. + language: The language used to define (and execute) the + custom callables in `combine`. Defaults to generic. """ def __init__( self, diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py new file mode 100644 index 000000000000..c17f29bea90f --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@beam.ptransform.ptransform_fn
def enrichment_transform(
    pcoll,
    enrichment_handler: str,
    handler_config: Dict[str, Any],
    timeout: Optional[float] = 30):
  # pylint: disable=line-too-long

  """
  The Enrichment transform allows one to dynamically enhance elements in a
  pipeline by performing key-value lookups against external services like
  APIs or databases.

  Example using BigTable: ::

    - type: Enrichment
      config:
        enrichment_handler: 'BigTable'
        handler_config:
          project_id: 'apache-beam-testing'
          instance_id: 'beam-test'
          table_id: 'bigtable-enrichment-test'
          row_key: 'product_id'
        timeout: 30

  For more information on Enrichment, see the [Beam docs](
  https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/).

  Args:
    enrichment_handler (str): Specifies the source from where data needs
      to be extracted into the pipeline for enriching data. One of
      "BigQuery", "BigTable", "FeastFeatureStore" or "VertexAIFeatureStore".
    handler_config (dict): Specifies the parameters for the respective
      enrichment_handler in a YAML/JSON format. To see the full set of
      handler_config parameters, see their corresponding doc pages:

        - [BigQueryEnrichmentHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigquery.html#apache_beam.transforms.enrichment_handlers.bigquery.BigQueryEnrichmentHandler)
        - [BigTableEnrichmentHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler)
        - [FeastFeatureStoreEnrichmentHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.feast_feature_store.html#apache_beam.transforms.enrichment_handlers.feast_feature_store.FeastFeatureStoreEnrichmentHandler)
        - [VertexAIFeatureStoreEnrichmentHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store.html#apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store.VertexAIFeatureStoreEnrichmentHandler)
    timeout (float): Timeout for source requests in seconds. Defaults to 30
      seconds.
  """
  # Enrichment is gated behind the 'Enrichment' experimental-feature option.
  options.YamlOptions.check_enabled(pcoll.pipeline, 'Enrichment')

  # Enrichment is None when the guarded import at module load failed; every
  # handler is rejected in that case, whichever one was requested.
  if not Enrichment:
    raise ValueError(
        f"gcp dependencies not installed. Cannot use {enrichment_handler} "
        f"handler. Please install using 'pip install apache-beam[gcp]'.")

  # The Feast handler is imported separately, so it can be missing even when
  # the other handlers are available.
  if (enrichment_handler == 'FeastFeatureStore' and
      not FeastFeatureStoreEnrichmentHandler):
    raise ValueError(
        "FeastFeatureStore handler requires 'feast' package to be installed. " +
        "Please install using 'pip install feast[gcp]' and try again.")

  # Maps the user-facing handler name to the handler class to instantiate.
  handler_map = {
      'BigQuery': BigQueryEnrichmentHandler,
      'BigTable': BigTableEnrichmentHandler,
      'FeastFeatureStore': FeastFeatureStoreEnrichmentHandler,
      'VertexAIFeatureStore': VertexAIFeatureStoreEnrichmentHandler
  }

  if enrichment_handler not in handler_map:
    raise ValueError(f"Unknown enrichment source: {enrichment_handler}")

  # handler_config is passed straight through as the handler's keyword
  # arguments, so its keys must match that handler's constructor parameters.
  handler = handler_map[enrichment_handler](**handler_config)
  return pcoll | Enrichment(source_handler=handler, timeout=timeout)
class FakeEnrichmentTransform:
  """Test double patched in place of ``enrichment_transform``.

  It is constructed with the arguments the test expects to be forwarded, and
  when called it asserts that the arguments it actually receives match them.
  """
  def __init__(self, enrichment_handler, handler_config, timeout=30):
    # Expected values, compared against the call-time arguments in __call__.
    self._enrichment_handler = enrichment_handler
    self._handler_config = handler_config
    self._timeout = timeout

  def __call__(self, enrichment_handler, *, handler_config, timeout=30):
    assert enrichment_handler == self._enrichment_handler
    assert handler_config == self._handler_config
    assert timeout == self._timeout
    # No lookup is performed: each element is rebuilt as a beam.Row with the
    # same fields, so the output equals the input.
    return beam.Map(lambda x: beam.Row(**x._asdict()))


class EnrichmentTransformTest(unittest.TestCase):
  """Checks that a YAML `Enrichment` transform forwards its config."""
  def test_enrichment_with_bigquery(self):
    input_data = [
        Row(label="item1", rank=0),
        Row(label="item2", rank=1),
    ]

    handler = 'BigQuery'
    config = {
        "project": "apache-beam-testing",
        "table_name": "project.database.table",
        "row_restriction_template": "label='item1' or label='item2'",
        "fields": ["label"]
    }

    with beam.Pipeline() as p:
      # Replace the real transform so no external service is contacted; the
      # fake verifies the handler name and config it is called with.
      with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform',
                      FakeEnrichmentTransform(enrichment_handler=handler,
                                              handler_config=config)):
        input_pcoll = p | 'CreateInput' >> beam.Create(input_data)
        # `config` is interpolated as its Python dict repr inside the YAML
        # text below.
        result = input_pcoll | YamlTransform(
            f'''
            type: Enrichment
            config:
              enrichment_handler: {handler}
              handler_config: {config}
            ''')
        assert_that(result, equal_to(input_data))
license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import inspect +from typing import NamedTuple + +import apache_beam as beam +from apache_beam.typehints.row_type import RowTypeConstraint + + +class ErrorHandlingConfig(NamedTuple): + """This option specifies whether and where to output error rows. + + Args: + output (str): Name to use for the output error collection + """ + output: str + # TODO: Other parameters are valid here too, but not common to Java. + + +def exception_handling_args(error_handling_spec): + if error_handling_spec: + return { + 'dead_letter_tag' if k == 'output' else k: v + for (k, v) in error_handling_spec.items() + } + else: + return None + + +def map_errors_to_standard_format(input_type): + # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. 
+ + return beam.Map( + lambda x: beam.Row( + element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2])) + ).with_output_types( + RowTypeConstraint.from_fields([("element", input_type), ("msg", str), + ("stack", str)])) + + +def maybe_with_exception_handling(inner_expand): + def expand(self, pcoll): + wrapped_pcoll = beam.core._MaybePValueWithErrors( + pcoll, self._exception_handling_args) + return inner_expand(self, wrapped_pcoll).as_result( + map_errors_to_standard_format(pcoll.element_type)) + + return expand + + +def maybe_with_exception_handling_transform_fn(transform_fn): + @functools.wraps(transform_fn) + def expand(pcoll, error_handling=None, **kwargs): + wrapped_pcoll = beam.core._MaybePValueWithErrors( + pcoll, exception_handling_args(error_handling)) + return transform_fn(wrapped_pcoll, **kwargs).as_result( + map_errors_to_standard_format(pcoll.element_type)) + + original_signature = inspect.signature(transform_fn) + new_parameters = list(original_signature.parameters.values()) + error_handling_param = inspect.Parameter( + 'error_handling', + inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=ErrorHandlingConfig) + if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + new_parameters.insert(-1, error_handling_param) + else: + new_parameters.append(error_handling_param) + expand.__signature__ = original_signature.replace(parameters=new_parameters) + + return expand diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py index 22663bdb8461..a6525aef9877 100644 --- a/sdks/python/apache_beam/yaml/yaml_io.py +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -45,7 +45,7 @@ from apache_beam.portability.api import schema_pb2 from apache_beam.typehints import schemas from apache_beam.yaml import json_utils -from apache_beam.yaml import yaml_mapping +from apache_beam.yaml import yaml_errors from apache_beam.yaml import yaml_provider @@ -289,7 +289,7 @@ def formatter(row): @beam.ptransform_fn 
-@yaml_mapping.maybe_with_exception_handling_transform_fn +@yaml_errors.maybe_with_exception_handling_transform_fn def read_from_pubsub( root, *, @@ -393,7 +393,7 @@ def mapper(msg): @beam.ptransform_fn -@yaml_mapping.maybe_with_exception_handling_transform_fn +@yaml_errors.maybe_with_exception_handling_transform_fn def write_to_pubsub( pcoll, *, diff --git a/sdks/python/apache_beam/yaml/yaml_join.py b/sdks/python/apache_beam/yaml/yaml_join.py index 5124ef56b49c..b22e452b27f9 100644 --- a/sdks/python/apache_beam/yaml/yaml_join.py +++ b/sdks/python/apache_beam/yaml/yaml_join.py @@ -62,9 +62,11 @@ def _validate_equalities(equalities, pcolls): error_prefix = f'Invalid value "{equalities}" for "equalities".' valid_cols = { - name: set(dict(pcoll.element_type._fields).keys()) - for name, - pcoll in pcolls.items() + name: set( + dict(fields).keys() if fields and all( + isinstance(field, tuple) for field in fields) else fields) + for (name, pcoll) in pcolls.items() + for fields in [getattr(pcoll.element_type, '_fields', [])] } if isinstance(equalities, str): diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 377bcac0e31a..8f4a2118c236 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -16,8 +16,6 @@ # """This module defines the basic MapToFields operation.""" -import functools -import inspect import itertools import re from collections import abc @@ -27,7 +25,6 @@ from typing import Dict from typing import List from typing import Mapping -from typing import NamedTuple from typing import Optional from typing import TypeVar from typing import Union @@ -41,12 +38,17 @@ from apache_beam.typehints import trivial_inference from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_to_beam_type -from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import 
named_fields_from_element_type +from apache_beam.typehints.schemas import schema_from_element_type +from apache_beam.typehints.schemas import typing_from_runner_api from apache_beam.utils import python_callable from apache_beam.yaml import json_utils from apache_beam.yaml import options from apache_beam.yaml import yaml_provider +from apache_beam.yaml.yaml_errors import exception_handling_args +from apache_beam.yaml.yaml_errors import map_errors_to_standard_format +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn from apache_beam.yaml.yaml_provider import dicts_to_rows # Import js2py package if it exists @@ -416,63 +418,102 @@ def checking_func(row): return func -class ErrorHandlingConfig(NamedTuple): - output: str - # TODO: Other parameters are valid here too, but not common to Java. +class _StripErrorMetadata(beam.PTransform): + """Strips error metadata from outputs returned via error handling. + Generally the error outputs for transformations return information about + the error encountered (e.g. error messages and tracebacks) in addition to the + failing element itself. This transformation attempts to remove that metadata + and returns the bad element alone which can be useful for re-processing. -def exception_handling_args(error_handling_spec): - if error_handling_spec: - return { - 'dead_letter_tag' if k == 'output' else k: v - for (k, v) in error_handling_spec.items() - } - else: - return None + For example, in the following pipeline snippet:: + + - name: MyMappingTransform + type: MapToFields + input: SomeInput + config: + language: python + fields: + ... 
+ error_handling: + output: errors + + - name: RecoverOriginalElements + type: StripErrorMetadata + input: MyMappingTransform.errors + the output of `RecoverOriginalElements` will contain exactly those elements + from SomeInput that failed to processes (whereas `MyMappingTransform.errors` + would contain those elements paired with error information). -def _map_errors_to_standard_format(input_type): - # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. + Note that this relies on the preceding transform actually returning the + failing input in a schema'd way. Most built-in transformation follow the + correct conventions. + """ - return beam.Map( - lambda x: beam.Row(element=x[0], msg=str(x[1][1]), stack=str(x[1][2])) - ).with_output_types( - RowTypeConstraint.from_fields([("element", input_type), ("msg", str), - ("stack", str)])) + _ERROR_FIELD_NAMES = ('failed_row', 'element', 'record') + def __init__(self): + super().__init__(label=None) -def maybe_with_exception_handling(inner_expand): def expand(self, pcoll): - wrapped_pcoll = beam.core._MaybePValueWithErrors( - pcoll, self._exception_handling_args) - return inner_expand(self, wrapped_pcoll).as_result( - _map_errors_to_standard_format(pcoll.element_type)) - - return expand - - -def maybe_with_exception_handling_transform_fn(transform_fn): - @functools.wraps(transform_fn) - def expand(pcoll, error_handling=None, **kwargs): - wrapped_pcoll = beam.core._MaybePValueWithErrors( - pcoll, exception_handling_args(error_handling)) - return transform_fn(wrapped_pcoll, **kwargs).as_result( - _map_errors_to_standard_format(pcoll.element_type)) - - original_signature = inspect.signature(transform_fn) - new_parameters = list(original_signature.parameters.values()) - error_handling_param = inspect.Parameter( - 'error_handling', - inspect.Parameter.KEYWORD_ONLY, - default=None, - annotation=ErrorHandlingConfig) - if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: - new_parameters.insert(-1, 
error_handling_param) - else: - new_parameters.append(error_handling_param) - expand.__signature__ = original_signature.replace(parameters=new_parameters) + try: + existing_fields = { + fld.name: fld.type + for fld in schema_from_element_type(pcoll.element_type).fields + } + except TypeError: + fld = None + else: + for fld in self._ERROR_FIELD_NAMES: + if fld in existing_fields: + break + else: + raise ValueError( + 'The input to this transform does not appear to be an error ' + + "output. Expected a schema'd input with a field named " + + ' or '.join(repr(fld) for fld in self._ERROR_FIELD_NAMES)) + + if fld is None: + # This handles with_exception_handling() that returns bare tuples. + return pcoll | beam.Map(lambda x: x[0]) + else: + return pcoll | beam.Map(lambda x: getattr(x, fld)).with_output_types( + typing_from_runner_api(existing_fields[fld])) + + +class _Validate(beam.PTransform): + """Validates each element of a PCollection against a json schema. + + Args: + schema: A json schema against which to validate each element. + error_handling: Whether and how to handle errors during iteration. + If this is not set, invalid elements will fail the pipeline, otherwise + invalid elements will be passed to the specified error output along + with information about how the schema was invalidated. + """ + def __init__( + self, + schema: Dict[str, Any], + error_handling: Optional[Mapping[str, Any]] = None): + self._schema = schema + self._exception_handling_args = exception_handling_args(error_handling) + + @maybe_with_exception_handling + def expand(self, pcoll): + validator = json_utils.row_validator( + schema_from_element_type(pcoll.element_type), self._schema) + + def invoke_validator(x): + validator(x) + return x + + return pcoll | beam.Map(invoke_validator) - return expand + def with_exception_handling(self, **kwargs): + # It's possible there's an error in iteration... 
+ self._exception_handling_args = kwargs + return self class _Explode(beam.PTransform): @@ -744,9 +785,8 @@ def split(element): splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T) result = {out: getattr(splits, out) for out in output_set} if error_output: - result[ - error_output] = result[error_output] | _map_errors_to_standard_format( - pcoll.element_type) + result[error_output] = result[error_output] | map_errors_to_standard_format( + pcoll.element_type) return result @@ -797,6 +837,8 @@ def create_mapping_providers(): 'Partition-python': _Partition, 'Partition-javascript': _Partition, 'Partition-generic': _Partition, + 'StripErrorMetadata': _StripErrorMetadata, + 'ValidateWithSchema': _Validate, }), yaml_provider.SqlBackedProvider({ 'Filter-sql': _SqlFilterTransform, diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 1b74a765e54b..2c5feec18278 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -134,6 +134,43 @@ def test_explode(self): beam.Row(a=3, b='y', c=.125, range=2), ])) + def test_validate(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(key='good', small=[5], nested=beam.Row(big=100)), + beam.Row(key='bad1', small=[500], nested=beam.Row(big=100)), + beam.Row(key='bad2', small=[5], nested=beam.Row(big=1)), + ]) + result = elements | YamlTransform( + ''' + type: ValidateWithSchema + config: + schema: + type: object + properties: + small: + type: array + items: + type: integer + maximum: 10 + nested: + type: object + properties: + big: + type: integer + minimum: 10 + error_handling: + output: bad + ''') + + assert_that( + result['good'] | beam.Map(lambda x: x.key), equal_to(['good'])) + assert_that( + result['bad'] | beam.Map(lambda x: x.element.key), + equal_to(['bad1', 'bad2']), 
+ label='Errors') + def test_validate_explicit_types(self): with self.assertRaisesRegex(TypeError, r'.*violates schema.*'): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index c2cba936abce..a07638953551 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -63,6 +63,7 @@ from apache_beam.utils import subprocess_server from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn class Provider: @@ -117,7 +118,7 @@ def affinity(self, other: "Provider"): (e.g. to encourage fusion). """ # TODO(yaml): This is a very rough heuristic. Consider doing better. - # E.g. we could look at the the expected environments themselves. + # E.g. we could look at the expected environments themselves. # Possibly, we could provide multiple expansions and have the runner itself # choose the actual implementation based on fusion (and other) criteria. a = self.underlying_provider() @@ -876,8 +877,10 @@ def _parse_window_spec(spec): return beam.WindowInto(window_fn) @staticmethod + @beam.ptransform_fn + @maybe_with_exception_handling_transform_fn def log_for_testing( - level: Optional[str] = 'INFO', prefix: Optional[str] = ''): + pcoll, *, level: Optional[str] = 'INFO', prefix: Optional[str] = ''): """Logs each element of its input PCollection. 
The output of this transform is a copy of its input for ease of use in @@ -918,7 +921,7 @@ def log_and_return(x): logger(prefix + json.dumps(to_loggable_json_recursive(x))) return x - return "LogForTesting" >> beam.Map(log_and_return) + return pcoll | "LogForTesting" >> beam.Map(log_and_return) @staticmethod def create_builtin_provider(): diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index ab86a2aaff56..b8e49e81c579 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -493,7 +493,7 @@ def expand_leaf_transform(spec, scope): outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform except Exception as exn: raise ValueError( - f"Error apply transform {identify_object(spec)}: {exn}") from exn + f"Error applying transform {identify_object(spec)}: {exn}") from exn if isinstance(outputs, dict): # TODO: Handle (or at least reject) nested case. return outputs @@ -576,8 +576,8 @@ def is_not_output_of_last_transform(new_transforms, value): pass else: raise ValueError( - f'Transform {identify_object(transform)} is part of a chain, ' - 'must have implicit inputs and outputs.') + f'Transform {identify_object(transform)} is part of a chain. 
' + 'Cannot define explicit inputs on chain pipeline') if ix == 0: if is_explicitly_empty(transform.get('input', None)): pass diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index f00403b07e2a..2a5a96aa42df 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -72,10 +72,12 @@ def test_get_pcollection_output(self): str(scope.get_pcollection("Create"))) self.assertEqual( - "PCollection[Square.None]", str(scope.get_pcollection("Square"))) + "PCollection[Square/LogForTesting.None]", + str(scope.get_pcollection("Square"))) self.assertEqual( - "PCollection[Square.None]", str(scope.get_pcollection("LogForTesting"))) + "PCollection[Square/LogForTesting.None]", + str(scope.get_pcollection("LogForTesting"))) self.assertTrue( scope.get_pcollection("Square") == scope.get_pcollection( diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index fbdae6679e96..7fcea7e2b662 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -401,6 +401,51 @@ def test_error_handling_outputs(self): assert_that(result['good'], equal_to(['a', 'b']), label="CheckGood") assert_that(result['bad'], equal_to(["ValueError('biiiiig')"])) + def test_strip_error_metadata(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + result = p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + config: + elements: ['a', 'b', 'biiiiig'] + + - type: SizeLimiter + input: Create + config: + limit: 5 + error_handling: + output: errors + - type: StripErrorMetadata + name: StripErrorMetadata1 + input: SizeLimiter.errors + + - type: MapToFields + input: Create + config: + language: python + fields: + out: "1/(1-len(element))" + 
error_handling: + output: errors + - type: StripErrorMetadata + name: StripErrorMetadata2 + input: MapToFields.errors + + output: + good: SizeLimiter + bad1: StripErrorMetadata1 + bad2: StripErrorMetadata2 + ''', + providers=TEST_PROVIDERS) + assert_that(result['good'], equal_to(['a', 'b']), label="CheckGood") + assert_that( + result['bad1'] | beam.Map(lambda x: x.element), equal_to(['biiiiig'])) + assert_that( + result['bad2'] | beam.Map(lambda x: x.element), equal_to(['a', 'b'])) + def test_must_handle_error_output(self): with self.assertRaisesRegex(Exception, 'Unconsumed error output .*line 7'): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 8c4b00351b24..084e03cdb197 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -213,7 +213,7 @@ def test_expand_composite_transform_with_name_input(self): inputs={'elements': elements}) self.assertRegex( str(expand_composite_transform(spec, scope)['output']), - r"PCollection.*Composite/LogForTesting.*") + r"PCollection.*Composite/log_for_testing/LogForTesting.*") def test_expand_composite_transform_root(self): with new_pipeline() as p: @@ -325,6 +325,23 @@ def test_chain_as_composite_with_input(self): self.assertEqual( chain_as_composite(spec)['transforms'][0]['input'], {"input": "input"}) + def test_chain_as_composite_with_transform_input(self): + spec = ''' + type: chain + transforms: + - type: Create + config: + elements: [0,1,2] + - type: LogForTesting + input: Create + ''' + spec = yaml.load(spec, Loader=SafeLineLoader) + with self.assertRaisesRegex( + ValueError, + r"Transform .* is part of a chain. 
" + r"Cannot define explicit inputs on chain pipeline"): + chain_as_composite(spec) + def test_normalize_source_sink(self): spec = ''' source: diff --git a/sdks/python/container/Dockerfile-distroless b/sdks/python/container/Dockerfile-distroless new file mode 100644 index 000000000000..2799bd2be81f --- /dev/null +++ b/sdks/python/container/Dockerfile-distroless @@ -0,0 +1,50 @@ +############################################################################### +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################### + +ARG BASE +FROM ${BASE} AS base +ENV LANG=C.UTF8 +ARG TARGETARCH +LABEL Author="Apache Beam " + +RUN if [ -z "${TARGETARCH}" ]; then echo "fatal: TARGETARCH not set; run as docker buildx build or use --build-arg=TARGETARCH=amd64|arm64" >&2; exit 1; fi + +FROM gcr.io/distroless/python3-debian12:latest-${TARGETARCH} AS distroless + +# Prevents internal errors found with distroless container images and Flex templates. +COPY --from=base /usr/lib/locale /usr/lib/locale + +# Contains header files needed by the Python interpreter. +COPY --from=base /usr/local/include /usr/local/include + +# Contains the Python interpreter executables. 
+COPY --from=base /usr/local/bin /usr/local/bin + +# Contains the Python library dependencies. +COPY --from=base /usr/local/lib /usr/local/lib + +# Python standard library modules. +COPY --from=base /usr/lib/python* /usr/lib/. + +# Contains the boot entrypoint and related files such as licenses. +COPY --from=base /opt /opt + +ENV PATH "$PATH:/usr/local/bin" + +# Despite the ENTRYPOINT set in base image, need to reset since deriving the layer derives from a different image. +ENTRYPOINT ["/opt/apache/beam/boot"] diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index 696604c64886..b7cbc07dca68 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -189,7 +189,12 @@ func launchSDKProcess() error { fmtErr := fmt.Errorf("failed to retrieve staged files: %v", err) // Send error message to logging service before returning up the call stack logger.Errorf(ctx, fmtErr.Error()) - return fmtErr + // No need to fail the job if submission_environment_dependencies.txt cannot be loaded + if strings.Contains(fmtErr.Error(), "submission_environment_dependencies.txt") { + logger.Printf(ctx, "Ignore the error when loading submission_environment_dependencies.txt.") + } else { + return fmtErr + } } // TODO(herohde): the packages to install should be specified explicitly. 
It diff --git a/sdks/python/container/build.gradle b/sdks/python/container/build.gradle index f07b6f743fa4..14c08a3a539b 100644 --- a/sdks/python/container/build.gradle +++ b/sdks/python/container/build.gradle @@ -20,7 +20,7 @@ plugins { id 'org.apache.beam.module' } applyGoNature() description = "Apache Beam :: SDKs :: Python :: Container" -int min_python_version=8 +int min_python_version=9 int max_python_version=12 configurations { diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index 435d20d383ec..3442b92f3583 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -22,23 +22,23 @@ # Reach out to a committer if you need help. annotated-types==0.7.0 -async-timeout==4.0.3 +async-timeout==5.0.1 attrs==24.2.0 backports.tarfile==1.2.0 beautifulsoup4==4.12.3 bs4==0.0.2 -build==1.2.2 +build==1.2.2.post1 cachetools==5.5.0 certifi==2024.8.30 cffi==1.17.1 -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 click==8.1.7 cloudpickle==2.2.1 -cramjam==2.8.4 +cramjam==2.9.0 crcmod==1.7 -cryptography==43.0.1 +cryptography==43.0.3 Cython==3.0.11 -Deprecated==1.2.14 +Deprecated==1.2.15 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 @@ -51,31 +51,31 @@ fastavro==1.9.7 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.20.0 -google-api-python-client==2.147.0 +google-api-core==2.23.0 +google-api-python-client==2.153.0 google-apitools==0.5.31 -google-auth==2.35.0 +google-auth==2.36.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.69.0 -google-cloud-bigquery==3.26.0 -google-cloud-bigquery-storage==2.26.0 -google-cloud-bigtable==2.26.0 +google-cloud-aiplatform==1.72.0 +google-cloud-bigquery==3.27.0 +google-cloud-bigquery-storage==2.27.0 +google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 google-cloud-datastore==2.20.1 -google-cloud-dlp==3.23.0 -google-cloud-language==2.14.0 
+google-cloud-dlp==3.25.1 +google-cloud-language==2.15.1 google-cloud-profiler==4.1.0 -google-cloud-pubsub==2.25.2 +google-cloud-pubsub==2.27.1 google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.12 -google-cloud-resource-manager==1.12.5 -google-cloud-spanner==3.49.1 +google-cloud-recommendations-ai==0.10.14 +google-cloud-resource-manager==1.13.1 +google-cloud-spanner==3.50.1 google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.13.5 -google-cloud-vision==3.7.4 +google-cloud-videointelligence==2.14.1 +google-cloud-vision==3.8.1 google-crc32c==1.6.0 google-resumable-media==2.7.2 -googleapis-common-protos==1.65.0 +googleapis-common-protos==1.66.0 greenlet==3.1.1 grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 @@ -84,9 +84,9 @@ grpcio-status==1.62.3 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.112.3 +hypothesis==6.119.1 idna==3.10 -importlib_metadata==8.4.0 +importlib_metadata==8.5.0 iniconfig==2.0.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 @@ -94,12 +94,12 @@ jaraco.functools==4.1.0 jeepney==0.8.0 Jinja2==3.1.4 joblib==1.4.2 -jsonpickle==3.3.0 +jsonpickle==3.4.2 jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -keyring==25.4.1 +jsonschema-specifications==2024.10.1 +keyring==25.5.0 keyrings.google-artifactregistry-auth==1.1.2 -MarkupSafe==2.1.5 +MarkupSafe==3.0.2 mmh3==5.0.1 mock==5.1.0 more-itertools==10.5.0 @@ -108,16 +108,16 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 -opentelemetry-semantic-conventions==0.48b0 -orjson==3.10.7 +opentelemetry-api==1.28.1 +opentelemetry-sdk==1.28.1 +opentelemetry-semantic-conventions==0.49b1 +orjson==3.10.11 overrides==7.7.0 -packaging==24.1 +packaging==24.2 pandas==2.1.4 parameterized==0.9.0 pluggy==1.5.0 -proto-plus==1.24.0 +proto-plus==1.25.0 protobuf==4.25.5 psycopg2-binary==2.9.9 pyarrow==16.1.0 @@ -131,7 +131,7 @@ pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 PyMySQL==1.1.1 -pyparsing==3.1.4 
+pyparsing==3.2.0 pyproject_hooks==1.2.0 pytest==7.4.4 pytest-timeout==2.3.1 @@ -140,12 +140,12 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.1.1 +redis==5.2.0 referencing==0.35.1 -regex==2024.9.11 +regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.20.0 +rpds-py==0.21.0 rsa==4.9 scikit-learn==1.5.2 scipy==1.14.1 @@ -154,17 +154,17 @@ shapely==2.0.6 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.6 -SQLAlchemy==2.0.35 -sqlparse==0.5.1 +SQLAlchemy==2.0.36 +sqlparse==0.5.2 tenacity==8.5.0 testcontainers==3.7.1 threadpoolctl==3.5.0 -tomli==2.0.2 -tqdm==4.66.5 +tomli==2.1.0 +tqdm==4.67.0 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 urllib3==2.2.3 wrapt==1.16.0 -zipp==3.20.2 +zipp==3.21.0 zstandard==0.23.0 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 07c869c31d75..93f579b14dd8 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -26,18 +26,18 @@ attrs==24.2.0 backports.tarfile==1.2.0 beautifulsoup4==4.12.3 bs4==0.0.2 -build==1.2.2 +build==1.2.2.post1 cachetools==5.5.0 certifi==2024.8.30 cffi==1.17.1 -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 click==8.1.7 cloudpickle==2.2.1 -cramjam==2.8.4 +cramjam==2.9.0 crcmod==1.7 -cryptography==43.0.1 +cryptography==43.0.3 Cython==3.0.11 -Deprecated==1.2.14 +Deprecated==1.2.15 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 @@ -49,31 +49,31 @@ fastavro==1.9.7 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.20.0 -google-api-python-client==2.147.0 +google-api-core==2.23.0 +google-api-python-client==2.153.0 google-apitools==0.5.31 -google-auth==2.35.0 +google-auth==2.36.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.69.0 -google-cloud-bigquery==3.26.0 -google-cloud-bigquery-storage==2.26.0 -google-cloud-bigtable==2.26.0 +google-cloud-aiplatform==1.72.0 
+google-cloud-bigquery==3.27.0 +google-cloud-bigquery-storage==2.27.0 +google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 google-cloud-datastore==2.20.1 -google-cloud-dlp==3.23.0 -google-cloud-language==2.14.0 +google-cloud-dlp==3.25.1 +google-cloud-language==2.15.1 google-cloud-profiler==4.1.0 -google-cloud-pubsub==2.25.2 +google-cloud-pubsub==2.27.1 google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.12 -google-cloud-resource-manager==1.12.5 -google-cloud-spanner==3.49.1 +google-cloud-recommendations-ai==0.10.14 +google-cloud-resource-manager==1.13.1 +google-cloud-spanner==3.50.1 google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.13.5 -google-cloud-vision==3.7.4 +google-cloud-videointelligence==2.14.1 +google-cloud-vision==3.8.1 google-crc32c==1.6.0 google-resumable-media==2.7.2 -googleapis-common-protos==1.65.0 +googleapis-common-protos==1.66.0 greenlet==3.1.1 grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 @@ -82,9 +82,9 @@ grpcio-status==1.62.3 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.112.3 +hypothesis==6.119.1 idna==3.10 -importlib_metadata==8.4.0 +importlib_metadata==8.5.0 iniconfig==2.0.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 @@ -92,12 +92,12 @@ jaraco.functools==4.1.0 jeepney==0.8.0 Jinja2==3.1.4 joblib==1.4.2 -jsonpickle==3.3.0 +jsonpickle==3.4.2 jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -keyring==25.4.1 +jsonschema-specifications==2024.10.1 +keyring==25.5.0 keyrings.google-artifactregistry-auth==1.1.2 -MarkupSafe==2.1.5 +MarkupSafe==3.0.2 mmh3==5.0.1 mock==5.1.0 more-itertools==10.5.0 @@ -106,16 +106,16 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 -opentelemetry-semantic-conventions==0.48b0 -orjson==3.10.7 +opentelemetry-api==1.28.1 +opentelemetry-sdk==1.28.1 +opentelemetry-semantic-conventions==0.49b1 +orjson==3.10.11 overrides==7.7.0 -packaging==24.1 +packaging==24.2 pandas==2.1.4 
parameterized==0.9.0 pluggy==1.5.0 -proto-plus==1.24.0 +proto-plus==1.25.0 protobuf==4.25.5 psycopg2-binary==2.9.9 pyarrow==16.1.0 @@ -129,7 +129,7 @@ pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 PyMySQL==1.1.1 -pyparsing==3.1.4 +pyparsing==3.2.0 pyproject_hooks==1.2.0 pytest==7.4.4 pytest-timeout==2.3.1 @@ -138,12 +138,12 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.1.1 +redis==5.2.0 referencing==0.35.1 -regex==2024.9.11 +regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.20.0 +rpds-py==0.21.0 rsa==4.9 scikit-learn==1.5.2 scipy==1.14.1 @@ -152,16 +152,16 @@ shapely==2.0.6 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.6 -SQLAlchemy==2.0.35 -sqlparse==0.5.1 +SQLAlchemy==2.0.36 +sqlparse==0.5.2 tenacity==8.5.0 testcontainers==3.7.1 threadpoolctl==3.5.0 -tqdm==4.66.5 +tqdm==4.67.0 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 urllib3==2.2.3 wrapt==1.16.0 -zipp==3.20.2 +zipp==3.21.0 zstandard==0.23.0 diff --git a/sdks/python/container/py312/base_image_requirements.txt b/sdks/python/container/py312/base_image_requirements.txt index 79e23ef25852..069005318cdb 100644 --- a/sdks/python/container/py312/base_image_requirements.txt +++ b/sdks/python/container/py312/base_image_requirements.txt @@ -25,18 +25,18 @@ annotated-types==0.7.0 attrs==24.2.0 beautifulsoup4==4.12.3 bs4==0.0.2 -build==1.2.2 +build==1.2.2.post1 cachetools==5.5.0 certifi==2024.8.30 cffi==1.17.1 -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 click==8.1.7 cloudpickle==2.2.1 -cramjam==2.8.4 +cramjam==2.9.0 crcmod==1.7 -cryptography==43.0.1 +cryptography==43.0.3 Cython==3.0.11 -Deprecated==1.2.14 +Deprecated==1.2.15 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 @@ -48,31 +48,31 @@ fastavro==1.9.7 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.20.0 -google-api-python-client==2.147.0 +google-api-core==2.23.0 +google-api-python-client==2.153.0 google-apitools==0.5.31 -google-auth==2.35.0 
+google-auth==2.36.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.69.0 -google-cloud-bigquery==3.26.0 -google-cloud-bigquery-storage==2.26.0 -google-cloud-bigtable==2.26.0 +google-cloud-aiplatform==1.72.0 +google-cloud-bigquery==3.27.0 +google-cloud-bigquery-storage==2.27.0 +google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 google-cloud-datastore==2.20.1 -google-cloud-dlp==3.23.0 -google-cloud-language==2.14.0 +google-cloud-dlp==3.25.1 +google-cloud-language==2.15.1 google-cloud-profiler==4.1.0 -google-cloud-pubsub==2.25.2 +google-cloud-pubsub==2.27.1 google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.12 -google-cloud-resource-manager==1.12.5 -google-cloud-spanner==3.49.1 +google-cloud-recommendations-ai==0.10.14 +google-cloud-resource-manager==1.13.1 +google-cloud-spanner==3.50.1 google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.13.5 -google-cloud-vision==3.7.4 +google-cloud-videointelligence==2.14.1 +google-cloud-vision==3.8.1 google-crc32c==1.6.0 google-resumable-media==2.7.2 -googleapis-common-protos==1.65.0 +googleapis-common-protos==1.66.0 greenlet==3.1.1 grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 @@ -81,9 +81,9 @@ grpcio-status==1.62.3 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.112.3 +hypothesis==6.119.1 idna==3.10 -importlib_metadata==8.4.0 +importlib_metadata==8.5.0 iniconfig==2.0.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 @@ -91,12 +91,12 @@ jaraco.functools==4.1.0 jeepney==0.8.0 Jinja2==3.1.4 joblib==1.4.2 -jsonpickle==3.3.0 +jsonpickle==3.4.2 jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -keyring==25.4.1 +jsonschema-specifications==2024.10.1 +keyring==25.5.0 keyrings.google-artifactregistry-auth==1.1.2 -MarkupSafe==2.1.5 +MarkupSafe==3.0.2 mmh3==5.0.1 mock==5.1.0 more-itertools==10.5.0 @@ -105,16 +105,16 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 
-opentelemetry-semantic-conventions==0.48b0 -orjson==3.10.7 +opentelemetry-api==1.28.1 +opentelemetry-sdk==1.28.1 +opentelemetry-semantic-conventions==0.49b1 +orjson==3.10.11 overrides==7.7.0 -packaging==24.1 +packaging==24.2 pandas==2.1.4 parameterized==0.9.0 pluggy==1.5.0 -proto-plus==1.24.0 +proto-plus==1.25.0 protobuf==4.25.5 psycopg2-binary==2.9.9 pyarrow==16.1.0 @@ -128,7 +128,7 @@ pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 PyMySQL==1.1.1 -pyparsing==3.1.4 +pyparsing==3.2.0 pyproject_hooks==1.2.0 pytest==7.4.4 pytest-timeout==2.3.1 @@ -137,32 +137,32 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.1.1 +redis==5.2.0 referencing==0.35.1 -regex==2024.9.11 +regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.20.0 +rpds-py==0.21.0 rsa==4.9 scikit-learn==1.5.2 scipy==1.14.1 SecretStorage==3.3.3 -setuptools==75.1.0 +setuptools==75.5.0 shapely==2.0.6 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.6 -SQLAlchemy==2.0.35 -sqlparse==0.5.1 +SQLAlchemy==2.0.36 +sqlparse==0.5.2 tenacity==8.5.0 testcontainers==3.7.1 threadpoolctl==3.5.0 -tqdm==4.66.5 +tqdm==4.67.0 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 urllib3==2.2.3 -wheel==0.44.0 +wheel==0.45.0 wrapt==1.16.0 -zipp==3.20.2 +zipp==3.21.0 zstandard==0.23.0 diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt deleted file mode 100644 index 0a67a3666d25..000000000000 --- a/sdks/python/container/py38/base_image_requirements.txt +++ /dev/null @@ -1,172 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Autogenerated requirements file for Apache Beam py38 container image. -# Run ./gradlew :sdks:python:container:generatePythonRequirementsAll to update. -# Do not edit manually, adjust ../base_image_requirements_manual.txt or -# Apache Beam's setup.py instead, and regenerate the list. -# You will need Python interpreters for all versions supported by Beam, see: -# https://s.apache.org/beam-python-dev-wiki -# Reach out to a committer if you need help. - -annotated-types==0.7.0 -async-timeout==4.0.3 -attrs==24.2.0 -backports.tarfile==1.2.0 -beautifulsoup4==4.12.3 -bs4==0.0.2 -build==1.2.2 -cachetools==5.5.0 -certifi==2024.8.30 -cffi==1.17.1 -charset-normalizer==3.3.2 -click==8.1.7 -cloudpickle==2.2.1 -cramjam==2.8.4 -crcmod==1.7 -cryptography==43.0.1 -Cython==3.0.11 -Deprecated==1.2.14 -deprecation==2.1.0 -dill==0.3.1.1 -dnspython==2.6.1 -docker==7.1.0 -docopt==0.6.2 -docstring_parser==0.16 -exceptiongroup==1.2.2 -execnet==2.1.1 -fastavro==1.9.7 -fasteners==0.19 -freezegun==1.5.1 -future==1.0.0 -google-api-core==2.20.0 -google-api-python-client==2.147.0 -google-apitools==0.5.31 -google-auth==2.35.0 -google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.69.0 -google-cloud-bigquery==3.26.0 -google-cloud-bigquery-storage==2.26.0 -google-cloud-bigtable==2.26.0 -google-cloud-core==2.4.1 -google-cloud-datastore==2.20.1 -google-cloud-dlp==3.23.0 -google-cloud-language==2.14.0 -google-cloud-profiler==4.1.0 -google-cloud-pubsub==2.25.2 -google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.12 -google-cloud-resource-manager==1.12.5 
-google-cloud-spanner==3.49.1 -google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.13.5 -google-cloud-vision==3.7.4 -google-crc32c==1.5.0 -google-resumable-media==2.7.2 -googleapis-common-protos==1.65.0 -greenlet==3.1.1 -grpc-google-iam-v1==0.13.1 -grpc-interceptor==0.15.4 -grpcio==1.65.5 -grpcio-status==1.62.3 -guppy3==3.1.4.post1 -hdfs==2.7.3 -httplib2==0.22.0 -hypothesis==6.112.3 -idna==3.10 -importlib_metadata==8.4.0 -importlib_resources==6.4.5 -iniconfig==2.0.0 -jaraco.classes==3.4.0 -jaraco.context==6.0.1 -jaraco.functools==4.1.0 -jeepney==0.8.0 -Jinja2==3.1.4 -joblib==1.4.2 -jsonpickle==3.3.0 -jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -keyring==25.4.1 -keyrings.google-artifactregistry-auth==1.1.2 -MarkupSafe==2.1.5 -mmh3==5.0.1 -mock==5.1.0 -more-itertools==10.5.0 -nltk==3.9.1 -nose==1.3.7 -numpy==1.24.4 -oauth2client==4.1.3 -objsize==0.7.0 -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 -opentelemetry-semantic-conventions==0.48b0 -orjson==3.10.7 -overrides==7.7.0 -packaging==24.1 -pandas==2.0.3 -parameterized==0.9.0 -pkgutil_resolve_name==1.3.10 -pluggy==1.5.0 -proto-plus==1.24.0 -protobuf==4.25.5 -psycopg2-binary==2.9.9 -pyarrow==16.1.0 -pyarrow-hotfix==0.6 -pyasn1==0.6.1 -pyasn1_modules==0.4.1 -pycparser==2.22 -pydantic==2.9.2 -pydantic_core==2.23.4 -pydot==1.4.2 -PyHamcrest==2.1.0 -pymongo==4.10.1 -PyMySQL==1.1.1 -pyparsing==3.1.4 -pyproject_hooks==1.2.0 -pytest==7.4.4 -pytest-timeout==2.3.1 -pytest-xdist==3.6.1 -python-dateutil==2.9.0.post0 -python-snappy==0.7.3 -pytz==2024.2 -PyYAML==6.0.2 -redis==5.1.1 -referencing==0.35.1 -regex==2024.9.11 -requests==2.32.3 -requests-mock==1.12.1 -rpds-py==0.20.0 -rsa==4.9 -scikit-learn==1.3.2 -scipy==1.10.1 -SecretStorage==3.3.3 -shapely==2.0.6 -six==1.16.0 -sortedcontainers==2.4.0 -soupsieve==2.6 -SQLAlchemy==2.0.35 -sqlparse==0.5.1 -tenacity==8.5.0 -testcontainers==3.7.1 -threadpoolctl==3.5.0 -tomli==2.0.2 -tqdm==4.66.5 -typing_extensions==4.12.2 -tzdata==2024.2 -uritemplate==4.1.1 
-urllib3==2.2.3 -wrapt==1.16.0 -zipp==3.20.2 -zstandard==0.23.0 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index f726406cfd13..22ab0a2fbcf8 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -22,23 +22,23 @@ # Reach out to a committer if you need help. annotated-types==0.7.0 -async-timeout==4.0.3 +async-timeout==5.0.1 attrs==24.2.0 backports.tarfile==1.2.0 beautifulsoup4==4.12.3 bs4==0.0.2 -build==1.2.2 +build==1.2.2.post1 cachetools==5.5.0 certifi==2024.8.30 cffi==1.17.1 -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 click==8.1.7 cloudpickle==2.2.1 -cramjam==2.8.4 +cramjam==2.9.0 crcmod==1.7 -cryptography==43.0.1 +cryptography==43.0.3 Cython==3.0.11 -Deprecated==1.2.14 +Deprecated==1.2.15 deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 @@ -51,31 +51,31 @@ fastavro==1.9.7 fasteners==0.19 freezegun==1.5.1 future==1.0.0 -google-api-core==2.20.0 -google-api-python-client==2.147.0 +google-api-core==2.23.0 +google-api-python-client==2.153.0 google-apitools==0.5.31 -google-auth==2.35.0 +google-auth==2.36.0 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.69.0 -google-cloud-bigquery==3.26.0 -google-cloud-bigquery-storage==2.26.0 -google-cloud-bigtable==2.26.0 +google-cloud-aiplatform==1.72.0 +google-cloud-bigquery==3.27.0 +google-cloud-bigquery-storage==2.27.0 +google-cloud-bigtable==2.27.0 google-cloud-core==2.4.1 google-cloud-datastore==2.20.1 -google-cloud-dlp==3.23.0 -google-cloud-language==2.14.0 +google-cloud-dlp==3.25.1 +google-cloud-language==2.15.1 google-cloud-profiler==4.1.0 -google-cloud-pubsub==2.25.2 +google-cloud-pubsub==2.27.1 google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.12 -google-cloud-resource-manager==1.12.5 -google-cloud-spanner==3.49.1 +google-cloud-recommendations-ai==0.10.14 +google-cloud-resource-manager==1.13.1 +google-cloud-spanner==3.50.1 
google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.13.5 -google-cloud-vision==3.7.4 +google-cloud-videointelligence==2.14.1 +google-cloud-vision==3.8.1 google-crc32c==1.6.0 google-resumable-media==2.7.2 -googleapis-common-protos==1.65.0 +googleapis-common-protos==1.66.0 greenlet==3.1.1 grpc-google-iam-v1==0.13.1 grpc-interceptor==0.15.4 @@ -84,9 +84,9 @@ grpcio-status==1.62.3 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.112.3 +hypothesis==6.119.1 idna==3.10 -importlib_metadata==8.4.0 +importlib_metadata==8.5.0 iniconfig==2.0.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 @@ -94,12 +94,12 @@ jaraco.functools==4.1.0 jeepney==0.8.0 Jinja2==3.1.4 joblib==1.4.2 -jsonpickle==3.3.0 +jsonpickle==3.4.2 jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -keyring==25.4.1 +jsonschema-specifications==2024.10.1 +keyring==25.5.0 keyrings.google-artifactregistry-auth==1.1.2 -MarkupSafe==2.1.5 +MarkupSafe==3.0.2 mmh3==5.0.1 mock==5.1.0 more-itertools==10.5.0 @@ -108,16 +108,16 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 -opentelemetry-semantic-conventions==0.48b0 -orjson==3.10.7 +opentelemetry-api==1.28.1 +opentelemetry-sdk==1.28.1 +opentelemetry-semantic-conventions==0.49b1 +orjson==3.10.11 overrides==7.7.0 -packaging==24.1 +packaging==24.2 pandas==2.1.4 parameterized==0.9.0 pluggy==1.5.0 -proto-plus==1.24.0 +proto-plus==1.25.0 protobuf==4.25.5 psycopg2-binary==2.9.9 pyarrow==16.1.0 @@ -131,7 +131,7 @@ pydot==1.4.2 PyHamcrest==2.1.0 pymongo==4.10.1 PyMySQL==1.1.1 -pyparsing==3.1.4 +pyparsing==3.2.0 pyproject_hooks==1.2.0 pytest==7.4.4 pytest-timeout==2.3.1 @@ -140,12 +140,12 @@ python-dateutil==2.9.0.post0 python-snappy==0.7.3 pytz==2024.2 PyYAML==6.0.2 -redis==5.1.1 +redis==5.2.0 referencing==0.35.1 -regex==2024.9.11 +regex==2024.11.6 requests==2.32.3 requests-mock==1.12.1 -rpds-py==0.20.0 +rpds-py==0.21.0 rsa==4.9 scikit-learn==1.5.2 scipy==1.13.1 @@ -154,17 +154,17 @@ 
shapely==2.0.6 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.6 -SQLAlchemy==2.0.35 -sqlparse==0.5.1 +SQLAlchemy==2.0.36 +sqlparse==0.5.2 tenacity==8.5.0 testcontainers==3.7.1 threadpoolctl==3.5.0 -tomli==2.0.2 -tqdm==4.66.5 +tomli==2.1.0 +tqdm==4.67.0 typing_extensions==4.12.2 tzdata==2024.2 uritemplate==4.1.1 urllib3==2.2.3 wrapt==1.16.0 -zipp==3.20.2 +zipp==3.21.0 zstandard==0.23.0 diff --git a/sdks/python/expansion-service-container/Dockerfile b/sdks/python/expansion-service-container/Dockerfile index 5a5ef0f410bc..4e82165f594c 100644 --- a/sdks/python/expansion-service-container/Dockerfile +++ b/sdks/python/expansion-service-container/Dockerfile @@ -17,7 +17,7 @@ ############################################################################### # We just need to support one Python version supported by Beam. -# Picking the current default Beam Python version which is Python 3.8. +# Picking the current default Beam Python version which is Python 3.9. FROM python:3.9-bookworm as expansion-service LABEL Author "Apache Beam " ARG TARGETOS diff --git a/sdks/python/expansion-service-container/build.gradle b/sdks/python/expansion-service-container/build.gradle index 3edcaee35b4a..4e46f060e59f 100644 --- a/sdks/python/expansion-service-container/build.gradle +++ b/sdks/python/expansion-service-container/build.gradle @@ -40,7 +40,7 @@ task copyDockerfileDependencies(type: Copy) { } task copyRequirementsFile(type: Copy) { - from project(':sdks:python:container:py38').fileTree("./") + from project(':sdks:python:container:py39').fileTree("./") include 'base_image_requirements.txt' rename 'base_image_requirements.txt', 'requirements.txt' setDuplicatesStrategy(DuplicatesStrategy.INCLUDE) diff --git a/sdks/python/mypy.ini b/sdks/python/mypy.ini index 562cb8d56dcc..ee76089fec0b 100644 --- a/sdks/python/mypy.ini +++ b/sdks/python/mypy.ini @@ -28,11 +28,17 @@ files = apache_beam color_output = true # uncomment this to see how close we are to being complete # check_untyped_defs = 
true -disable_error_code = var-annotated +disable_error_code = var-annotated, import-untyped, valid-type, truthy-function, attr-defined, annotation-unchecked + +[tool.mypy] +ignore_missing_imports = true [mypy-apache_beam.coders.proto2_coder_test_messages_pb2] ignore_errors = true +[mypy-apache_beam.dataframe.*] +ignore_errors = true + [mypy-apache_beam.examples.*] ignore_errors = true diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 037e5a8aed6b..4eb827297019 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -26,7 +26,7 @@ requires = [ # Avoid https://github.com/pypa/virtualenv/issues/2006 "distlib==0.3.7", # Numpy headers - "numpy>=1.14.3,<1.27", # Update setup.py as well. + "numpy>=1.14.3,<2.2.0", # Update setup.py as well. # having cython here will create wheels that are platform dependent. "cython>=3.0,<4", ## deps for generating external transform wrappers: diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini index b10acaac71cd..b62a44aa25e7 100644 --- a/sdks/python/pytest.ini +++ b/sdks/python/pytest.ini @@ -49,6 +49,7 @@ markers = sickbay_direct: run without sickbay-direct sickbay_spark: run without sickbay-spark sickbay_flink: run without sickbay-flink + sickbay_prism: run without sickbay-prism sickbay_dataflow: run without sickbay-dataflow # Tests using this marker conflict with the xdist plugin in some way, such # as enabling save_main_session. 
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 490406d579e5..4922b61169d1 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -64,6 +64,7 @@ excluded_patterns=( 'apache_beam/runners/portability/' 'apache_beam/runners/test/' 'apache_beam/runners/worker/' + 'apache_beam/runners/dask/transform_evaluator.*' 'apache_beam/testing/benchmarks/chicago_taxi/' 'apache_beam/testing/benchmarks/cloudml/' 'apache_beam/testing/benchmarks/inference/' @@ -124,7 +125,6 @@ extensions = [ ] master_doc = 'index' html_theme = 'sphinx_rtd_theme' -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] project = 'Apache Beam' version = beam_version.__version__ release = version @@ -135,7 +135,7 @@ autodoc_member_order = 'bysource' autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub", "tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", "datatable", "transformers", - "sentence_transformers", "redis", "tensorflow_text", "feast", + "sentence_transformers", "redis", "tensorflow_text", "feast", "dask", ] # Allow a special section for documenting DataFrame API @@ -143,6 +143,8 @@ napoleon_custom_sections = ['Differences from pandas'] doctest_global_setup = ''' import apache_beam as beam +import pandas as pd +import numpy as np ''' intersphinx_mapping = { @@ -242,6 +244,9 @@ ignore_identifiers = [ # IPython Magics py:class reference target not found 'IPython.core.magic.Magics', + + # Type variables. + 'apache_beam.transforms.util.T', ] ignore_references = [ 'BeamIOError', @@ -275,7 +280,7 @@ python $(type -p sphinx-build) -v -a -E -q target/docs/source \ # Fail if there are errors or warnings in docs ! grep -q "ERROR:" target/docs/sphinx-build.log || exit 1 -! grep -q "WARNING:" target/docs/sphinx-build.log || exit 1 +# ! 
grep -q "WARNING:" target/docs/sphinx-build.log || exit 1 # Run tests for code samples, these can be: # - Code blocks using '.. testsetup::', '.. testcode::' and '.. testoutput::' @@ -283,12 +288,13 @@ python $(type -p sphinx-build) -v -a -E -q target/docs/source \ python -msphinx -M doctest target/docs/source \ target/docs/_build -c target/docs/source \ 2>&1 | grep -E -v 'apache_beam\.dataframe.*WARNING:' \ + 2>&1 | grep -E -v 'apache_beam\.dataframe.*ERROR:' \ 2>&1 | grep -E -v 'apache_beam\.io\.textio\.(ReadFrom|WriteTo)(Csv|Json).*WARNING:' \ 2>&1 | tee "target/docs/sphinx-doctest.log" # Fail if there are errors or warnings in docs ! grep -q "ERROR:" target/docs/sphinx-doctest.log || exit 1 -! grep -q "WARNING:" target/docs/sphinx-doctest.log || exit 1 +# ! grep -q "WARNING:" target/docs/sphinx-doctest.log || exit 1 # Message is useful only when this script is run locally. In a remote # test environment, this path will be removed when the test completes. diff --git a/sdks/python/scripts/run_snapshot_publish.sh b/sdks/python/scripts/run_snapshot_publish.sh index bc379077349d..0d7c7764748d 100755 --- a/sdks/python/scripts/run_snapshot_publish.sh +++ b/sdks/python/scripts/run_snapshot_publish.sh @@ -31,7 +31,9 @@ DEP_SNAPSHOT_FILE_NAME="beam-py-requirements-$time.txt" cd $WORKSPACE/sdks/python/build # Rename the file to be apache-beam-{VERSION}-{datetime}.tar.gz -for file in "apache-beam-$VERSION*.tar.gz"; do +# Notice that the distribution name of beam can be "apache-beam" with +# setuptools<69.3.0 or "apache_beam" with setuptools>=69.3.0. 
+for file in "apache[-_]beam-$VERSION*.tar.gz"; do mv $file $SNAPSHOT done diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 721cb4c1a8dd..53c7a532e706 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -155,7 +155,7 @@ def cythonize(*args, **kwargs): # Exclude 1.5.0 and 1.5.1 because of # https://github.com/pandas-dev/pandas/issues/45725 dataframe_dependency = [ - 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3;python_version>="3.8"', + 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3', ] @@ -271,18 +271,13 @@ def get_portability_package_data(): return files -python_requires = '>=3.8' +python_requires = '>=3.9' -if sys.version_info.major == 3 and sys.version_info.minor >= 12: +if sys.version_info.major == 3 and sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' % (sys.version_info.major, sys.version_info.minor)) -elif sys.version_info.major == 3 and sys.version_info.minor == 8: - warnings.warn('Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') if __name__ == '__main__': # In order to find the tree of proto packages, the directory @@ -366,7 +361,7 @@ def get_portability_package_data(): 'jsonpickle>=3.0.0,<4.0.0', # numpy can have breaking changes in minor versions. # Use a strict upper bound. - 'numpy>=1.14.3,<1.27.0', # Update pyproject.toml as well. + 'numpy>=1.14.3,<2.2.0', # Update pyproject.toml as well. 'objsize>=0.6.1,<0.8.0', 'packaging>=22.0', 'pymongo>=3.8.0,<5.0.0', @@ -381,15 +376,17 @@ def get_portability_package_data(): # # 3. 
Exclude protobuf 4 versions that leak memory, see: # https://github.com/apache/beam/issues/28246 - 'protobuf>=3.20.3,<4.26.0,!=4.0.*,!=4.21.*,!=4.22.0,!=4.23.*,!=4.24.*', # pylint: disable=line-too-long + 'protobuf>=3.20.3,<6.0.0.dev0,!=4.0.*,!=4.21.*,!=4.22.0,!=4.23.*,!=4.24.*', # pylint: disable=line-too-long 'pydot>=1.2.0,<2', 'python-dateutil>=2.8.0,<3', 'pytz>=2018.3', 'redis>=5.0.0,<6', 'regex>=2020.6.8', 'requests>=2.24.0,<3.0.0', + 'sortedcontainers>=2.4.0', 'typing-extensions>=3.7.0', 'zstandard>=0.18.0,<1', + 'pyyaml>=3.12,<7.0.0', # Dynamic dependencies must be specified in a separate list, otherwise # Dependabot won't be able to parse the main list. Any dynamic # dependencies will not receive updates from Dependabot. @@ -399,11 +396,9 @@ def get_portability_package_data(): extras_require={ 'docs': [ 'jinja2>=3.0,<3.2', - 'Sphinx>=1.5.2,<2.0', + 'Sphinx>=7.0.0,<8.0', 'docstring-parser>=0.15,<1.0', - # Pinning docutils as a workaround for Sphinx issue: - # https://github.com/sphinx-doc/sphinx/issues/9727 - 'docutils==0.17.1', + 'docutils>=0.18.1', 'pandas<2.2.0', 'openai' ], @@ -416,7 +411,6 @@ def get_portability_package_data(): 'pandas<2.2.0', 'parameterized>=0.7.1,<0.10.0', 'pyhamcrest>=1.9,!=1.10.0,<3.0.0', - 'pyyaml>=3.12,<7.0.0', 'requests_mock>=1.7,<2.0', 'tenacity>=8.0.0,<9', 'pytest>=7.1.2,<8.0', @@ -425,7 +419,7 @@ def get_portability_package_data(): 'scikit-learn>=0.20.0', 'setuptools', 'sqlalchemy>=1.3,<3.0', - 'psycopg2-binary>=2.8.5,<3.0.0', + 'psycopg2-binary>=2.8.5,<3.0.0,!=2.9.10', 'testcontainers[mysql]>=3.0.3,<4.0.0', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', @@ -496,12 +490,9 @@ def get_portability_package_data(): 'sentence-transformers', 'skl2onnx', 'pillow', - # Support TF 2.16.0: https://github.com/apache/beam/issues/31294 - # Once TF version is unpinned, also don't restrict Python version. 
- 'tensorflow<2.16.0;python_version<"3.12"', + 'tensorflow', 'tensorflow-hub', - # https://github.com/tensorflow/transform/issues/313 - 'tensorflow-transform;python_version<"3.11"', + 'tensorflow-transform', 'tf2onnx', 'torch', 'transformers', @@ -510,6 +501,19 @@ def get_portability_package_data(): # https://github.com/apache/beam/issues/31285 # 'xgboost<2.0', # https://github.com/apache/beam/issues/31252 ], + 'p312_ml_test': [ + 'datatable', + 'embeddings', + 'onnxruntime', + 'sentence-transformers', + 'skl2onnx', + 'pillow', + 'tensorflow', + 'tensorflow-hub', + 'tf2onnx', + 'torch', + 'transformers', + ], 'aws': ['boto3>=1.9,<2'], 'azure': [ 'azure-storage-blob>=12.3.2,<13', @@ -518,13 +522,19 @@ def get_portability_package_data(): ], 'dataframe': dataframe_dependency, 'dask': [ - 'dask >= 2022.6', - 'distributed >= 2022.6', + 'distributed >= 2024.4.2', + 'dask >= 2024.4.2', + # For development, 'distributed >= 2023.12.1' should work with + # the above dask PR, however it can't be installed as part of + # a single `pip` call, since distributed releases are pinned to + # specific dask releases. As a workaround, distributed can be + # installed first, and then `.[dask]` installed second, with the + # `--update` / `-U` flag to replace the dask release brought in + # by distributed. 
], 'yaml': [ 'docstring-parser>=0.15,<1.0', 'jinja2>=3.0,<3.2', - 'pyyaml>=3.12,<7.0.0', 'virtualenv-clone>=0.5,<1.0', # https://github.com/PiotrDabkowski/Js2Py/issues/317 'js2py>=0.74,<1; python_version<"3.12"', @@ -536,7 +546,6 @@ def get_portability_package_data(): 'Intended Audience :: End Users/Desktop', 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 6bca904c1a64..e851a5420673 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -380,6 +380,50 @@ task validatesContainer() { } } +tasks.register('validatesDistrolessContainer', Task.class) { + dependsOn 'initializeForDataflowJob' + def repository = "us.gcr.io/apache-beam-testing/${System.getenv('USER')}" + def tag = java.time.Instant.now().getEpochSecond() + def imageURL = "${repository}/beam_python${pythonVersion}_sdk_distroless:${tag}" + doLast { + exec { + executable 'docker' + workingDir rootDir + args = [ + 'buildx', + 'build', + '-t', + imageURL, + '-f', + 'sdks/python/container/Dockerfile-distroless', + "--build-arg=BASE=gcr.io/apache-beam-testing/beam-sdk/beam_python${pythonVersion}_sdk", + "." 
+ ] + } + exec { + executable 'docker' + args = ['push', imageURL] + } + exec { + def testTarget = "apache_beam/examples/wordcount_it_test.py::WordCountIT::test_wordcount_it" + def argMap = [ + "output" : "gs://temp-storage-for-end-to-end-tests/py-it-cloud/output", + "project" : "apache-beam-testing", + "region" : "us-central1", + "runner" : "TestDataflowRunner", + "sdk_container_image": "${imageURL}", + "sdk_location" : "container", + "staging_location" : "gs://temp-storage-for-end-to-end-tests/staging-it", + "temp_location" : "gs://temp-storage-for-end-to-end-tests/temp-it", + ] + def cmdArgs = mapToArgString(argMap) + workingDir = "${rootDir}/sdks/python" + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pytest ${testTarget} --test-pipeline-options=\"${cmdArgs}\"" + } + } +} + task validatesContainerARM() { def pyversion = "${project.ext.pythonVersion.replace('.', '')}" dependsOn 'initializeForDataflowJob' @@ -543,12 +587,13 @@ task mockAPITests { } // add all RunInference E2E tests that run on DataflowRunner -// As of now, this test suite is enable in py38 suite as the base NVIDIA image used for Tensor RT -// contains Python 3.8. +// As of now, this test suite is enable in py310 suite as the base NVIDIA image used for Tensor RT +// contains Python 3.10. // TODO: https://github.com/apache/beam/issues/22651 project.tasks.register("inferencePostCommitIT") { dependsOn = [ - 'tensorRTtests', + // Temporarily disabled because of a container issue + // 'tensorRTtests', 'vertexAIInferenceTest', 'mockAPITests', ] diff --git a/sdks/python/test-suites/dataflow/py38/build.gradle b/sdks/python/test-suites/dataflow/py38/build.gradle deleted file mode 100644 index b3c3a5bfb8a6..000000000000 --- a/sdks/python/test-suites/dataflow/py38/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -apply plugin: org.apache.beam.gradle.BeamModulePlugin -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' -apply from: "../common.gradle" diff --git a/sdks/python/test-suites/direct/py38/build.gradle b/sdks/python/test-suites/direct/py38/build.gradle deleted file mode 100644 index edf86a7bf5a8..000000000000 --- a/sdks/python/test-suites/direct/py38/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -plugins { id 'org.apache.beam.module' } -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' -apply from: '../common.gradle' diff --git a/sdks/python/test-suites/gradle.properties b/sdks/python/test-suites/gradle.properties index f8c04e0f5609..d027cd3144d3 100644 --- a/sdks/python/test-suites/gradle.properties +++ b/sdks/python/test-suites/gradle.properties @@ -47,5 +47,10 @@ samza_validates_runner_postcommit_py_versions=3.9,3.12 # spark runner test-suites spark_examples_postcommit_py_versions=3.9,3.12 +# prism runner test-suites +prism_validates_runner_precommit_py_versions=3.12 +prism_validates_runner_postcommit_py_versions=3.9,3.12 +prism_examples_postcommit_py_versions=3.9,3.12 + # cross language postcommit python test suites cross_language_validates_py_versions=3.9,3.12 diff --git a/sdks/python/test-suites/portable/build.gradle b/sdks/python/test-suites/portable/build.gradle index 390c39a10899..41cd88acfb6a 100644 --- a/sdks/python/test-suites/portable/build.gradle +++ b/sdks/python/test-suites/portable/build.gradle @@ -31,12 +31,31 @@ tasks.register("samzaValidatesRunner") { } } +tasks.register("prismValidatesRunner") { + getVersionsAsList('prism_validates_runner_postcommit_py_versions').each { + dependsOn.add(":sdks:python:test-suites:portable:py${getVersionSuffix(it)}:prismValidatesRunner") + } +} + tasks.register("flinkValidatesRunnerPrecommit") { getVersionsAsList('flink_validates_runner_precommit_py_versions').each { dependsOn.add(":sdks:python:test-suites:portable:py${getVersionSuffix(it)}:flinkValidatesRunner") } } +tasks.register("prismValidatesRunnerPrecommit") { + getVersionsAsList('prism_validates_runner_precommit_py_versions').each { + dependsOn.add(":sdks:python:test-suites:portable:py${getVersionSuffix(it)}:prismValidatesRunner") + } +} + +// TODO merge with above once passing. Currently for convenience. 
+tasks.register("prismTriggerTranscript") { + getVersionsAsList('prism_validates_runner_precommit_py_versions').each { + dependsOn.add(":sdks:python:test-suites:portable:py${getVersionSuffix(it)}:prismTriggerTranscript") + } +} + tasks.register("flinkExamplesPostCommit") { getVersionsAsList('flink_examples_postcommit_py_versions').each { dependsOn.add(":sdks:python:test-suites:portable:py${getVersionSuffix(it)}:flinkExamples") @@ -48,3 +67,9 @@ tasks.register("sparkExamplesPostCommit") { dependsOn.add(":sdks:python:test-suites:portable:py${getVersionSuffix(it)}:sparkExamples") } } + +tasks.register("prismExamplesPostCommit") { + getVersionsAsList('prism_examples_postcommit_py_versions').each { + dependsOn.add(":sdks:python:test-suites:portable:py${getVersionSuffix(it)}:prismExamples") + } +} diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle index 4f232c5b104f..be87be749862 100644 --- a/sdks/python/test-suites/portable/common.gradle +++ b/sdks/python/test-suites/portable/common.gradle @@ -23,6 +23,7 @@ import org.apache.tools.ant.taskdefs.condition.Os def pythonRootDir = "${rootDir}/sdks/python" def pythonVersionSuffix = project.ext.pythonVersion.replace('.', '') def latestFlinkVersion = project.ext.latestFlinkVersion +def currentJavaVersion = project.ext.currentJavaVersion ext { pythonContainerTask = ":sdks:python:container:py${pythonVersionSuffix}:docker" @@ -201,6 +202,56 @@ tasks.register("sparkValidatesRunner") { dependsOn 'sparkCompatibilityMatrixLOOPBACK' } +def createPrismRunnerTestTask(String workerType) { + def taskName = "prismCompatibilityMatrix${workerType}" + + def prismBin = "${rootDir}/runners/prism/build/tmp/prism" + def options = "--prism_bin=${prismBin} --environment_type=${workerType}" + if (workerType == 'PROCESS') { + options += " --environment_options=process_command=${buildDir.absolutePath}/sdk_worker.sh" + } + def task = toxTask(taskName, 'prism-runner-test', options) + 
task.configure { + dependsOn ":runners:prism:build" + // The Java SDK worker is required to execute external transforms. + def suffix = getSupportedJavaVersion() + dependsOn ":sdks:java:container:${suffix}:docker" + if (workerType == 'DOCKER') { + dependsOn pythonContainerTask + } else if (workerType == 'PROCESS') { + dependsOn createProcessWorker + } + } + return task +} + +createPrismRunnerTestTask('DOCKER') +createPrismRunnerTestTask('PROCESS') +createPrismRunnerTestTask('LOOPBACK') + +tasks.register("prismValidatesRunner") { + dependsOn 'prismCompatibilityMatrixLOOPBACK' +} + +tasks.register("prismTriggerTranscript") { + dependsOn 'setupVirtualenv' + dependsOn ':runners:prism:build' + def prismBin = "${rootDir}/runners/prism/build/tmp/prism" + doLast { + exec { + executable 'sh' + args '-c', """ + . ${envdir}/bin/activate \\ + && cd ${pythonRootDir} \\ + && pip install -e .[test] \\ + && pytest \\ + apache_beam/transforms/trigger_test.py::WeakTestStreamTranscriptTest \\ + --test-pipeline-options='--runner=PrismRunner --environment_type=LOOPBACK --prism_location=${prismBin}' + """ + } + } +} + project.tasks.register("preCommitPy${pythonVersionSuffix}") { dependsOn = [":sdks:python:container:py${pythonVersionSuffix}:docker", ":runners:flink:${latestFlinkVersion}:job-server:shadowJar", @@ -227,12 +278,14 @@ project.tasks.register("flinkExamples") { def testOpts = [ "--log-cli-level=INFO", ] + def flink_conf_dir = "${rootDir}/runners/flink/src/test/resources/" def pipelineOpts = [ "--runner=FlinkRunner", "--project=apache-beam-testing", "--environment_type=LOOPBACK", "--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it", "--flink_job_server_jar=${project(":runners:flink:${latestFlinkVersion}:job-server").shadowJar.archivePath}", + "--flink_conf_dir=${flink_conf_dir}", '--sdk_harness_log_level_overrides=' + // suppress info level flink.runtime log flood '{\\"org.apache.flink.runtime\\":\\"WARN\\",' + @@ -283,12 +336,43 @@ 
project.tasks.register("sparkExamples") { } } +project.tasks.register("prismExamples") { + dependsOn = [ + 'setupVirtualenv', + 'installGcpTest', + ':runners:prism:build', + ] + def prismBin = "${rootDir}/runners/prism/build/tmp/prism" + doLast { + def testOpts = [ + "--log-cli-level=INFO", + ] + def pipelineOpts = [ + "--runner=PrismRunner", + "--project=apache-beam-testing", + "--environment_type=LOOPBACK", + "--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it", + "--prism_location=${prismBin}", + ] + def cmdArgs = mapToArgString([ + "test_opts": testOpts, + "suite": "postCommitExamples-prism-py${pythonVersionSuffix}", + "pipeline_opts": pipelineOpts.join(" "), + "collect": "examples_postcommit and not sickbay_prism" + ]) + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && ${pythonRootDir}/scripts/run_integration_test.sh $cmdArgs" + } + } +} + project.tasks.register("postCommitPy${pythonVersionSuffix}IT") { dependsOn = [ 'setupVirtualenv', 'installGcpTest', ":runners:flink:${latestFlinkVersion}:job-server:shadowJar", - ':sdks:java:container:java8:docker', + ":sdks:java:container:${currentJavaVersion}:docker", ':sdks:java:testing:kafka-service:buildTestKafkaServiceJar', ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', @@ -339,7 +423,7 @@ project.tasks.register("xlangSpannerIOIT") { 'setupVirtualenv', 'installGcpTest', ":runners:flink:${latestFlinkVersion}:job-server:shadowJar", - ':sdks:java:container:java8:docker', + ":sdks:java:container:${currentJavaVersion}:docker", ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', ':sdks:java:io:kinesis:expansion-service:shadowJar', diff --git a/sdks/python/test-suites/portable/py38/build.gradle b/sdks/python/test-suites/portable/py38/build.gradle deleted file mode 100644 index e15443fa935f..000000000000 --- a/sdks/python/test-suites/portable/py38/build.gradle +++ /dev/null 
@@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -apply plugin: org.apache.beam.gradle.BeamModulePlugin -applyPythonNature() - -addPortableWordCountTasks() - -// Required to setup a Python 3.8 virtualenv and task names. -pythonVersion = '3.8' -apply from: "../common.gradle" diff --git a/sdks/python/test-suites/tox/common.gradle b/sdks/python/test-suites/tox/common.gradle index df42a2c384c2..01265a6eeff5 100644 --- a/sdks/python/test-suites/tox/common.gradle +++ b/sdks/python/test-suites/tox/common.gradle @@ -31,7 +31,6 @@ test.dependsOn "testPy${pythonVersionSuffix}ML" // toxTask "testPy${pythonVersionSuffix}Dask", "py${pythonVersionSuffix}-dask", "${posargs}" // test.dependsOn "testPy${pythonVersionSuffix}Dask" - project.tasks.register("preCommitPy${pythonVersionSuffix}") { // Since codecoverage reports will always be generated for py38, // all tests will be exercised. 
diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle deleted file mode 100644 index 2ca82d3d9268..000000000000 --- a/sdks/python/test-suites/tox/py38/build.gradle +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Unit tests for Python 3.8 - */ - -plugins { id 'org.apache.beam.module' } -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' - -def posargs = project.findProperty("posargs") ?: "" - -apply from: "../common.gradle" - -toxTask "testPy38CloudCoverage", "py38-cloudcoverage", "${posargs}" -test.dependsOn "testPy38CloudCoverage" -project.tasks.register("preCommitPyCoverage") { - dependsOn = ["testPy38CloudCoverage"] -} - -// Dep Postcommit runs test suites that evaluate compatibility of particular -// dependencies. It is exercised on a single Python version. -// -// Should still leave at least one version in PreCommit unless the marked tests -// are also exercised by existing PreCommit -// e.g. 
pyarrow and pandas also run on PreCommit Dataframe and Coverage -project.tasks.register("postCommitPyDep") {} - -// Create a test task for supported major versions of pyarrow -// We should have a test for the lowest supported version and -// For versions that we would like to prioritize for testing, -// for example versions released in a timeframe of last 1-2 years. - -toxTask "testPy38pyarrow-3", "py38-pyarrow-3", "${posargs}" -test.dependsOn "testPy38pyarrow-3" -postCommitPyDep.dependsOn "testPy38pyarrow-3" - -toxTask "testPy38pyarrow-9", "py38-pyarrow-9", "${posargs}" -test.dependsOn "testPy38pyarrow-9" -postCommitPyDep.dependsOn "testPy38pyarrow-9" - -toxTask "testPy38pyarrow-10", "py38-pyarrow-10", "${posargs}" -test.dependsOn "testPy38pyarrow-10" -postCommitPyDep.dependsOn "testPy38pyarrow-10" - -toxTask "testPy38pyarrow-11", "py38-pyarrow-11", "${posargs}" -test.dependsOn "testPy38pyarrow-11" -postCommitPyDep.dependsOn "testPy38pyarrow-11" - -toxTask "testPy38pyarrow-12", "py38-pyarrow-12", "${posargs}" -test.dependsOn "testPy38pyarrow-12" -postCommitPyDep.dependsOn "testPy38pyarrow-12" - -toxTask "testPy38pyarrow-13", "py38-pyarrow-13", "${posargs}" -test.dependsOn "testPy38pyarrow-13" -postCommitPyDep.dependsOn "testPy38pyarrow-13" - -toxTask "testPy38pyarrow-14", "py38-pyarrow-14", "${posargs}" -test.dependsOn "testPy38pyarrow-14" -postCommitPyDep.dependsOn "testPy38pyarrow-14" - -toxTask "testPy38pyarrow-15", "py38-pyarrow-15", "${posargs}" -test.dependsOn "testPy38pyarrow-15" -postCommitPyDep.dependsOn "testPy38pyarrow-15" - -toxTask "testPy38pyarrow-16", "py38-pyarrow-16", "${posargs}" -test.dependsOn "testPy38pyarrow-16" -postCommitPyDep.dependsOn "testPy38pyarrow-16" - -// Create a test task for each supported minor version of pandas -toxTask "testPy38pandas-14", "py38-pandas-14", "${posargs}" -test.dependsOn "testPy38pandas-14" -postCommitPyDep.dependsOn "testPy38pandas-14" - -toxTask "testPy38pandas-15", "py38-pandas-15", "${posargs}" 
-test.dependsOn "testPy38pandas-15" -postCommitPyDep.dependsOn "testPy38pandas-15" - -toxTask "testPy38pandas-20", "py38-pandas-20", "${posargs}" -test.dependsOn "testPy38pandas-20" -postCommitPyDep.dependsOn "testPy38pandas-20" - -// TODO(https://github.com/apache/beam/issues/31192): Add below suites -// after dependency compat tests suite switches to Python 3.9 or we add -// Python 2.2 support. - -// toxTask "testPy39pandas-21", "py39-pandas-21", "${posargs}" -// test.dependsOn "testPy39pandas-21" -// postCommitPyDep.dependsOn "testPy39pandas-21" - -// toxTask "testPy39pandas-22", "py39-pandas-22", "${posargs}" -// test.dependsOn "testPy39pandas-22" -// postCommitPyDep.dependsOn "testPy39pandas-22" - -// TODO(https://github.com/apache/beam/issues/30908): Revise what are we testing - -// Create a test task for each minor version of pytorch -toxTask "testPy38pytorch-19", "py38-pytorch-19", "${posargs}" -test.dependsOn "testPy38pytorch-19" -postCommitPyDep.dependsOn "testPy38pytorch-19" - -toxTask "testPy38pytorch-110", "py38-pytorch-110", "${posargs}" -test.dependsOn "testPy38pytorch-110" -postCommitPyDep.dependsOn "testPy38pytorch-110" - -toxTask "testPy38pytorch-111", "py38-pytorch-111", "${posargs}" -test.dependsOn "testPy38pytorch-111" -postCommitPyDep.dependsOn "testPy38pytorch-111" - -toxTask "testPy38pytorch-112", "py38-pytorch-112", "${posargs}" -test.dependsOn "testPy38pytorch-112" -postCommitPyDep.dependsOn "testPy38pytorch-112" - -toxTask "testPy38pytorch-113", "py38-pytorch-113", "${posargs}" -test.dependsOn "testPy38pytorch-113" -postCommitPyDep.dependsOn "testPy38pytorch-113" - -// run on precommit -toxTask "testPy38pytorch-200", "py38-pytorch-200", "${posargs}" -test.dependsOn "testPy38pytorch-200" -postCommitPyDep.dependsOn "testPy38pytorch-200" - -toxTask "testPy38tft-113", "py38-tft-113", "${posargs}" -test.dependsOn "testPy38tft-113" -postCommitPyDep.dependsOn "testPy38tft-113" - -// TODO(https://github.com/apache/beam/issues/25796) - uncomment 
onnx tox task once onnx supports protobuf 4.x.x -// Create a test task for each minor version of onnx -// toxTask "testPy38onnx-113", "py38-onnx-113", "${posargs}" -// test.dependsOn "testPy38onnx-113" -// postCommitPyDep.dependsOn "testPy38onnx-113" - -// Create a test task for each minor version of tensorflow -toxTask "testPy38tensorflow-212", "py38-tensorflow-212", "${posargs}" -test.dependsOn "testPy38tensorflow-212" -postCommitPyDep.dependsOn "testPy38tensorflow-212" - -// Create a test task for each minor version of transformers -toxTask "testPy38transformers-428", "py38-transformers-428", "${posargs}" -test.dependsOn "testPy38transformers-428" -postCommitPyDep.dependsOn "testPy38transformers-428" - -toxTask "testPy38transformers-429", "py38-transformers-429", "${posargs}" -test.dependsOn "testPy38transformers-429" -postCommitPyDep.dependsOn "testPy38transformers-429" - -toxTask "testPy38transformers-430", "py38-transformers-430", "${posargs}" -test.dependsOn "testPy38transformers-430" -postCommitPyDep.dependsOn "testPy38transformers-430" - -toxTask "testPy38embeddingsMLTransform", "py38-embeddings", "${posargs}" -test.dependsOn "testPy38embeddingsMLTransform" -postCommitPyDep.dependsOn "testPy38embeddingsMLTransform" - -// Part of MLTransform embeddings test suite but requires tensorflow hub, which we need to test on -// mutliple versions so keeping this suite separate. 
-toxTask "testPy38TensorflowHubEmbeddings-014", "py38-TFHubEmbeddings-014", "${posargs}" -test.dependsOn "testPy38TensorflowHubEmbeddings-014" -postCommitPyDep.dependsOn "testPy38TensorflowHubEmbeddings-014" - -toxTask "testPy38TensorflowHubEmbeddings-015", "py38-TFHubEmbeddings-015", "${posargs}" -test.dependsOn "testPy38TensorflowHubEmbeddings-015" -postCommitPyDep.dependsOn "testPy38TensorflowHubEmbeddings-015" - -toxTask "whitespacelint", "whitespacelint", "${posargs}" - -task archiveFilesToLint(type: Zip) { - archiveFileName = "files-to-whitespacelint.zip" - destinationDirectory = file("$buildDir/dist") - - from ("$rootProject.projectDir") { - include "**/*.md" - include "**/build.gradle" - include '**/build.gradle.kts' - exclude '**/build/**' // intermediate build directory - exclude 'website/www/site/themes/docsy/**' // fork to google/docsy - exclude "**/node_modules/*" - exclude "**/.gogradle/*" - } -} - -task unpackFilesToLint(type: Copy) { - from zipTree("$buildDir/dist/files-to-whitespacelint.zip") - into "$buildDir/files-to-whitespacelint" -} - -whitespacelint.dependsOn archiveFilesToLint, unpackFilesToLint -unpackFilesToLint.dependsOn archiveFilesToLint -archiveFilesToLint.dependsOn cleanPython - -toxTask "jest", "jest", "${posargs}" - -toxTask "eslint", "eslint", "${posargs}" - -task copyTsSource(type: Copy) { - from ("$rootProject.projectDir") { - include "sdks/python/apache_beam/runners/interactive/extensions/**/*" - exclude "sdks/python/apache_beam/runners/interactive/extensions/**/lib/*" - exclude "sdks/python/apache_beam/runners/interactive/extensions/**/node_modules/*" - } - into "$buildDir/ts" -} - -jest.dependsOn copyTsSource -eslint.dependsOn copyTsSource -copyTsSource.dependsOn cleanPython diff --git a/sdks/python/test-suites/tox/py39/build.gradle b/sdks/python/test-suites/tox/py39/build.gradle index ea02e9d5b1e8..e9624f8e810e 100644 --- a/sdks/python/test-suites/tox/py39/build.gradle +++ b/sdks/python/test-suites/tox/py39/build.gradle @@ 
-168,7 +168,9 @@ postCommitPyDep.dependsOn "testPy39transformers-430" toxTask "testPy39embeddingsMLTransform", "py39-embeddings", "${posargs}" test.dependsOn "testPy39embeddingsMLTransform" -postCommitPyDep.dependsOn "testPy39embeddingsMLTransform" +// TODO(https://github.com/apache/beam/issues/32965): re-enable this suite for the dep +// postcommit once the sentence-transformers import error is debugged +// postCommitPyDep.dependsOn "testPy39embeddingsMLTransform" // Part of MLTransform embeddings test suite but requires tensorflow hub, which we need to test on // mutliple versions so keeping this suite separate. diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index d733fd17fb6b..68ac15ced70d 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -101,18 +101,42 @@ commands = python apache_beam/examples/complete/autocomplete_test.py bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" -[testenv:py{39,310,311,312}-ml] +[testenv:py{39,310,311}-ml] # Don't set TMPDIR to avoid "AF_UNIX path too long" errors in certain tests. setenv = extras = test,gcp,dataframe,ml_test commands = + # Log tensorflow version for debugging + /bin/sh -c "pip freeze | grep -E tensorflow" bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" -[testenv:py{39,310,311,312}-dask] -extras = test,dask +[testenv:py312-ml] +# many packages do not support py3.12 +# Don't set TMPDIR to avoid "AF_UNIX path too long" errors in certain tests. 
+setenv = +extras = test,gcp,dataframe,p312_ml_test commands = + # Log tensorflow version for debugging + /bin/sh -c "pip freeze | grep -E tensorflow" bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" +[testenv:py{39,310,311,312}-dask] +extras = test,dask,dataframes +commands_pre = + pip install 'distributed>=2024.4.2' 'dask>=2024.4.2' +commands = + bash {toxinidir}/scripts/run_pytest.sh {envname} {toxinidir}/apache_beam/runners/dask/ + +[testenv:py{39,310,311,312}-win-dask] +# use the tight range since the latest dask requires cloudpickle 3.0 +commands_pre = + pip install 'distributed>=2024.4.2,<2024.9.0' 'dask>=2024.4.2,<2024.9.0' +commands = + python apache_beam/examples/complete/autocomplete_test.py + bash {toxinidir}/scripts/run_pytest.sh {envname} {toxinidir}/apache_beam/runners/dask/ +install_command = {envbindir}/python.exe {envbindir}/pip.exe install --retries 10 {opts} {packages} +list_dependencies_command = {envbindir}/python.exe {envbindir}/pip.exe freeze + [testenv:py39-cloudcoverage] deps = pytest-cov==3.0.0 @@ -149,7 +173,7 @@ commands = [testenv:mypy] deps = - mypy==0.790 + mypy==1.13.0 dask==2022.01.0 distributed==2022.01.0 # make extras available in case any of these libs are typed @@ -163,10 +187,10 @@ commands = [testenv:docs] extras = test,gcp,docs,interactive,dataframe,dask deps = - Sphinx==1.8.5 - sphinx_rtd_theme==0.4.3 - docutils<0.18 - Jinja2==3.0.3 # TODO(https://github.com/apache/beam/issues/21587): Sphinx version is too old. 
+ Sphinx==7.4.7 + sphinx_rtd_theme==3.0.1 + docutils>=0.18.1 + Jinja2==3.1.0 commands = time {toxinidir}/scripts/generate_pydoc.sh @@ -285,6 +309,10 @@ extras = test commands = bash {toxinidir}/scripts/pytest_validates_runner.sh {envname} {toxinidir}/apache_beam/runners/portability/spark_runner_test.py {posargs} +[testenv:prism-runner-test] +extras = test +commands = + bash {toxinidir}/scripts/pytest_validates_runner.sh {envname} {toxinidir}/apache_beam/runners/portability/prism_runner_test.py {posargs} [testenv:py{39,310}-pyarrow-{3,9,10,11,12,13,14,15,16}] deps = @@ -293,6 +321,7 @@ deps = # Since Pandas 2 requires pyarrow>=7, downgrade pandas for this test. 3: pyarrow>=3,<4 3: pandas<2 + 3: numpy>=1.14.3,<1.27.0 # Test against versions of pyarrow released in last ~2 years. 9: pyarrow>=9,<10 10: pyarrow>=10,<11 @@ -314,10 +343,13 @@ commands = [testenv:py{39,310}-pandas-{14,15,20}] deps = 14: pandas>=1.4.3,<1.5.0 + 14: numpy>=1.14.3,<1.27.0 # Exclude 1.5.0 and 1.5.1 because of https://github.com/pandas-dev/pandas/issues/45725 15: pandas>=1.5.2,<1.6.0 + 15: numpy>=1.14.3,<1.27.0 20: pandas>=2.0.0,<2.1.0 20: pyarrow>=7 + 20: numpy>=1.14.3,<1.27.0 commands = # Log pandas and numpy version for debugging /bin/sh -c "pip freeze | grep -E '(pandas|numpy)'" @@ -386,10 +418,13 @@ commands = [testenv:py39-tensorflow-212] deps = - 212: tensorflow>=2.12rc1,<2.13 - # Help pip resolve conflict with typing-extensions for old version of TF https://github.com/apache/beam/issues/30852 - 212: pydantic<2.7 + 212: + tensorflow>=2.12rc1,<2.13 + # Help pip resolve conflict with typing-extensions for old version of TF https://github.com/apache/beam/issues/30852 + pydantic<2.7 extras = test,gcp +commands_pre = + pip install -U 'protobuf==4.25.5' commands = # Log tensorflow version for debugging /bin/sh -c "pip freeze | grep -E tensorflow" @@ -420,7 +455,8 @@ deps = 430: transformers>=4.30.0,<4.31.0 torch>=1.9.0,<1.14.0 tensorflow==2.12.0 -extras = test,gcp + protobuf==4.25.5 +extras = 
test,gcp,ml_test commands = # Log transformers and its dependencies version for debugging /bin/sh -c "pip freeze | grep -E transformers" diff --git a/sdks/standard_expansion_services.yaml b/sdks/standard_expansion_services.yaml index ad965c8a1ee3..623e33fe2e7c 100644 --- a/sdks/standard_expansion_services.yaml +++ b/sdks/standard_expansion_services.yaml @@ -48,7 +48,7 @@ # Handwritten Kafka wrappers already exist in apache_beam/io/kafka.py - 'beam:schematransform:org.apache.beam:kafka_write:v1' - 'beam:schematransform:org.apache.beam:kafka_read:v1' - # Not ready to generate + # Available through apache_beam.transforms.managed.[Read/Write] - 'beam:schematransform:org.apache.beam:iceberg_write:v1' - 'beam:schematransform:org.apache.beam:iceberg_read:v1' diff --git a/sdks/typescript/package.json b/sdks/typescript/package.json index 3dcbab684090..9ccfcaa663d1 100644 --- a/sdks/typescript/package.json +++ b/sdks/typescript/package.json @@ -1,6 +1,6 @@ { "name": "apache-beam", - "version": "2.61.0-SNAPSHOT", + "version": "2.62.0-SNAPSHOT", "devDependencies": { "@google-cloud/bigquery": "^5.12.0", "@types/mocha": "^9.0.0", diff --git a/sdks/typescript/src/apache_beam/runners/flink.ts b/sdks/typescript/src/apache_beam/runners/flink.ts index ad4339b431f5..ab2d641b3302 100644 --- a/sdks/typescript/src/apache_beam/runners/flink.ts +++ b/sdks/typescript/src/apache_beam/runners/flink.ts @@ -28,7 +28,7 @@ import { JavaJarService } from "../utils/service"; const MAGIC_HOST_NAMES = ["[local]", "[auto]"]; // These should stay in sync with gradle.properties. 
-const PUBLISHED_FLINK_VERSIONS = ["1.15", "1.16", "1.17", "1.18"]; +const PUBLISHED_FLINK_VERSIONS = ["1.17", "1.18", "1.19"]; const defaultOptions = { flinkMaster: "[local]", diff --git a/settings.gradle.kts b/settings.gradle.kts index 9701b4dbc06f..d90bb3fb5b82 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -19,7 +19,7 @@ import com.gradle.enterprise.gradleplugin.internal.extension.BuildScanExtensionW pluginManagement { plugins { - id("org.javacc.javacc") version "3.0.2" // enable the JavaCC parser generator + id("org.javacc.javacc") version "3.0.3" // enable the JavaCC parser generator } } @@ -125,14 +125,6 @@ include(":runners:extensions-java:metrics") * verify versions in website/www/site/content/en/documentation/runners/flink.md * verify version in sdks/python/apache_beam/runners/interactive/interactive_beam.py */ -// Flink 1.15 -include(":runners:flink:1.15") -include(":runners:flink:1.15:job-server") -include(":runners:flink:1.15:job-server-container") -// Flink 1.16 -include(":runners:flink:1.16") -include(":runners:flink:1.16:job-server") -include(":runners:flink:1.16:job-server-container") // Flink 1.17 include(":runners:flink:1.17") include(":runners:flink:1.17:job-server") @@ -141,6 +133,10 @@ include(":runners:flink:1.17:job-server-container") include(":runners:flink:1.18") include(":runners:flink:1.18:job-server") include(":runners:flink:1.18:job-server-container") +// Flink 1.19 +include(":runners:flink:1.19") +include(":runners:flink:1.19:job-server") +include(":runners:flink:1.19:job-server-container") /* End Flink Runner related settings */ include(":runners:twister2") include(":runners:google-cloud-dataflow-java") @@ -337,6 +333,8 @@ project(":beam-test-gha").projectDir = file(".github") include("beam-validate-runner") project(":beam-validate-runner").projectDir = file(".test-infra/validate-runner") include("com.google.api.gax.batching") +include("sdks:java:io:kafka:kafka-312") +findProject(":sdks:java:io:kafka:kafka-312")?.name 
= "kafka-312" include("sdks:java:io:kafka:kafka-251") findProject(":sdks:java:io:kafka:kafka-251")?.name = "kafka-251" include("sdks:java:io:kafka:kafka-241") @@ -349,12 +347,6 @@ include("sdks:java:io:kafka:kafka-211") findProject(":sdks:java:io:kafka:kafka-211")?.name = "kafka-211" include("sdks:java:io:kafka:kafka-201") findProject(":sdks:java:io:kafka:kafka-201")?.name = "kafka-201" -include("sdks:java:io:kafka:kafka-111") -findProject(":sdks:java:io:kafka:kafka-111")?.name = "kafka-111" -include("sdks:java:io:kafka:kafka-100") -findProject(":sdks:java:io:kafka:kafka-100")?.name = "kafka-100" -include("sdks:java:io:kafka:kafka-01103") -findProject(":sdks:java:io:kafka:kafka-01103")?.name = "kafka-01103" include("sdks:java:managed") findProject(":sdks:java:managed")?.name = "managed" include("sdks:java:io:iceberg") diff --git a/vendor/grpc-1_60_1/build.gradle b/vendor/grpc-1_60_1/build.gradle index 834c496d9ca4..da152ef10f75 100644 --- a/vendor/grpc-1_60_1/build.gradle +++ b/vendor/grpc-1_60_1/build.gradle @@ -23,7 +23,7 @@ plugins { id 'org.apache.beam.vendor-java' } description = "Apache Beam :: Vendored Dependencies :: gRPC :: 1.60.1" group = "org.apache.beam" -version = "0.2" +version = "0.3" vendorJava( dependencies: GrpcVendoring_1_60_1.dependencies(), diff --git a/website/www/site/config.toml b/website/www/site/config.toml index e937289fbde7..0ed8aef2a906 100644 --- a/website/www/site/config.toml +++ b/website/www/site/config.toml @@ -104,7 +104,7 @@ github_project_repo = "https://github.com/apache/beam" [params] description = "Apache Beam is an open source, unified model and set of language-specific SDKs for defining and executing data processing workflows, and also data ingestion and integration flows, supporting Enterprise Integration Patterns (EIPs) and Domain Specific Languages (DSLs). 
Dataflow pipelines simplify the mechanics of large-scale batch and streaming data processing and can run on a number of runtimes like Apache Flink, Apache Spark, and Google Cloud Dataflow (a cloud service). Beam also brings DSL in different languages, allowing users to easily implement their data integration processes." -release_latest = "2.59.0" +release_latest = "2.61.0" # The repository and branch where the files live in Github or Colab. This is used # to serve and stage from your local branch, but publish to the master branch. # e.g. https://github.com/{{< param branch_repo >}}/path/to/notebook.ipynb diff --git a/website/www/site/content/en/blog/beam-2.52.0.md b/website/www/site/content/en/blog/beam-2.52.0.md index 2e604c8fabf8..f468eb4811e3 100644 --- a/website/www/site/content/en/blog/beam-2.52.0.md +++ b/website/www/site/content/en/blog/beam-2.52.0.md @@ -52,6 +52,7 @@ classes finally moved to `extensions/avro`. In case if it's still required to us as a workaround, a copy of "old" `CountingSource` class should be placed into a project code and used directly ([#25252](https://github.com/apache/beam/issues/25252)). * Renamed `host` to `firestoreHost` in `FirestoreOptions` to avoid potential conflict of command line arguments (Java) ([#29201](https://github.com/apache/beam/pull/29201)). +* Transforms which use `SnappyCoder` are update incompatible with previous versions of the same transform (Java) on some runners. This includes PubSubIO's read ([#28655](https://github.com/apache/beam/pull/28655#issuecomment-2407839769)). ## Bugfixes @@ -64,6 +65,12 @@ as a workaround, a copy of "old" `CountingSource` class should be placed into a * Fixed [CVE-2023-39325](https://www.cve.org/CVERecord?id=CVE-2023-39325) (Java/Python/Go) ([#29118](https://github.com/apache/beam/issues/29118)). * Mitigated [CVE-2023-47248](https://nvd.nist.gov/vuln/detail/CVE-2023-47248) (Python) [#29392](https://github.com/apache/beam/issues/29392). 
+## Known issues +* MLTransform drops the identical elements in the output PCollection. For any duplicate elements, a single element will be emitted downstream. ([#29600](https://github.com/apache/beam/issues/29600)). +* Some Python pipelines that run with 2.52.0-2.54.0 SDKs and use large materialized side inputs might be affected by a performance regression. To restore the prior behavior on these SDK versions, supply the `--max_cache_memory_usage_mb=0` pipeline option. (Python) ([#30360](https://github.com/apache/beam/issues/30360)). +* Users who launch Python pipelines in an environment without internet access and use the `--setup_file` pipeline option might experience an increase in pipeline submission time. This has been fixed in 2.56.0 ([#31070](https://github.com/apache/beam/pull/31070)). +* Transforms which use `SnappyCoder` are update incompatible with previous versions of the same transform (Java) on some runners. This includes PubSubIO's read ([#28655](https://github.com/apache/beam/pull/28655#issuecomment-2407839769)). + ## List of Contributors According to git shortlog, the following people contributed to the 2.52.0 release. Thank you to all contributors! diff --git a/website/www/site/content/en/blog/beam-2.57.0.md b/website/www/site/content/en/blog/beam-2.57.0.md index b583b4ee3c51..7be75a7891c5 100644 --- a/website/www/site/content/en/blog/beam-2.57.0.md +++ b/website/www/site/content/en/blog/beam-2.57.0.md @@ -79,6 +79,10 @@ For more information on changes in 2.57.0, check out the [detailed release notes ## Known Issues * Python pipelines that run with 2.53.0-2.58.0 SDKs and read data from GCS might be affected by a data corruption issue ([#32169](https://github.com/apache/beam/issues/32169)). The issue will be fixed in 2.59.0 ([#32135](https://github.com/apache/beam/pull/32135)). To work around this, update the google-cloud-storage package to version 2.18.2 or newer. 
+* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md diff --git a/website/www/site/content/en/blog/beam-2.58.0.md b/website/www/site/content/en/blog/beam-2.58.0.md index cfdf23c725e0..0d944fd419d4 100644 --- a/website/www/site/content/en/blog/beam-2.58.0.md +++ b/website/www/site/content/en/blog/beam-2.58.0.md @@ -53,6 +53,10 @@ For more information about changes in 2.58.0, check out the [detailed release no * Python pipelines that run with 2.53.0-2.58.0 SDKs and read data from GCS might be affected by a data corruption issue ([#32169](https://github.com/apache/beam/issues/32169)). The issue will be fixed in 2.59.0 ([#32135](https://github.com/apache/beam/pull/32135)). To work around this, update the google-cloud-storage package to version 2.18.2 or newer. * [KafkaIO] Records read with `ReadFromKafkaViaSDF` are redistributed and may contain duplicates regardless of the configuration. This affects Java pipelines with Dataflow v2 runner and xlang pipelines reading from Kafka, ([#32196](https://github.com/apache/beam/issues/32196)) +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. 
+ * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md diff --git a/website/www/site/content/en/blog/beam-2.59.0.md b/website/www/site/content/en/blog/beam-2.59.0.md index 6ce81e7c48eb..846e45916d66 100644 --- a/website/www/site/content/en/blog/beam-2.59.0.md +++ b/website/www/site/content/en/blog/beam-2.59.0.md @@ -66,8 +66,11 @@ For more information on changes in 2.59.0, check out the [detailed release notes * In the 2.59.0 release, Prism passes most runner validations tests with the exceptions of pipelines using the following features: OrderedListState, OnWindowExpiry (eg. GroupIntoBatches), CustomWindows, MergingWindowFns, Trigger and WindowingStrategy associated features, Bundle Finalization, Looping Timers, and some Coder related issues such as with Python combiner packing, and Java Schema transforms, and heterogenous flatten coders. Processing Time timers do not yet have real time support. * If your pipeline is having difficulty with the Python or Java direct runners, but runs well on Prism, please let us know. - * Java file-based IOs read or write lots (100k+) files could experience slowness and/or broken metrics visualization on Dataflow UI [#32649](https://github.com/apache/beam/issues/32649). +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. 
For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md diff --git a/website/www/site/content/en/blog/beam-2.60.0.md b/website/www/site/content/en/blog/beam-2.60.0.md new file mode 100644 index 000000000000..e5767cff5114 --- /dev/null +++ b/website/www/site/content/en/blog/beam-2.60.0.md @@ -0,0 +1,85 @@ +--- +title: "Apache Beam 2.60.0" +date: 2024-10-17 15:00:00 -0500 +categories: + - blog + - release +authors: + - yhu +--- + + +We are happy to present the new 2.60.0 release of Beam. +This release includes both improvements and new functionality. +See the [download page](/get-started/downloads/#2600-2024-10-17) for this release. + + + +For more information on changes in 2.60.0, check out the [detailed release notes](https://github.com/apache/beam/milestone/24). + +## Highlights + +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) +* [Managed Iceberg] Added support for streaming writes ([#32451](https://github.com/apache/beam/pull/32451)) +* [Managed Iceberg] Added auto-sharding for streaming writes ([#32612](https://github.com/apache/beam/pull/32612)) +* [Managed Iceberg] Added support for writing to dynamic destinations ([#32565](https://github.com/apache/beam/pull/32565)) + +## New Features / Improvements + +* Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). +* Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) +* Prism release binaries and container bootloaders are now being built with the latest Go 1.23 patch. ([#32575](https://github.com/apache/beam/pull/32575)) +* Prism + * Prism now supports Bundle Finalization. 
([#32425](https://github.com/apache/beam/pull/32425)) +* Significantly improved performance of Kafka IO reads that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) by removing the data reshuffle from SDF implementation. ([#31682](https://github.com/apache/beam/pull/31682)). +* Added support for dynamic writing in MqttIO (Java) ([#19376](https://github.com/apache/beam/issues/19376)) +* Optimized Spark Runner parDo transform evaluator (Java) ([#32537](https://github.com/apache/beam/issues/32537)) +* [Managed Iceberg] More efficient manifest file writes/commits ([#32666](https://github.com/apache/beam/issues/32666)) + +## Breaking Changes + +* In Python, assert_that now throws if it is not in a pipeline context instead of silently succeeding ([#30771](https://github.com/apache/beam/pull/30771)) +* In Python and YAML, ReadFromJson now overrides the dtype from None to + an explicit False. Most notably, string values like `"123"` are preserved + as strings rather than silently coerced (and possibly truncated) to numeric + values. To retain the old behavior, pass `dtype=True` (or any other value + accepted by `pandas.read_json`). +* Users of KafkaIO Read transform that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) might encounter pipeline graph compatibility issues when updating the pipeline. To mitigate, set the `updateCompatibilityVersion` option to the SDK version used for the original pipeline, for example `--updateCompatibilityVersion=2.58.1` + +## Deprecations + +* Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users +when running on 3.8.
([#31192](https://github.com/apache/beam/issues/31192)) + +## Bugfixes + +* (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)). +* (Java, Python, Go) Fixed PeriodicSequence backlog bytes reporting, which was preventing Dataflow Runner autoscaling from functioning properly ([#32506](https://github.com/apache/beam/issues/32506)). +* (Java) Fix improper decoding of rows with schemas containing nullable fields when encoded with a schema with equal encoding positions but modified field order. ([#32388](https://github.com/apache/beam/issues/32388)). +* (Java) Skip close on bundles in BigtableIO.Read ([#32661](https://github.com/apache/beam/pull/32661), [#32759](https://github.com/apache/beam/pull/32759)). + +## Known Issues + +* BigQuery Enrichment (Python): The following issues are present when using the BigQuery enrichment transform ([#32780](https://github.com/apache/beam/pull/32780)): + * Duplicate Rows: Multiple conditions may be applied incorrectly, leading to the duplication of rows in the output. + * Incorrect Results with Batched Requests: Conditions may not be correctly scoped to individual rows within the batch, potentially causing inaccurate results. + * Fixed in 2.61.0. + +For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md + +## List of Contributors + +According to git shortlog, the following people contributed to the 2.60.0 release. Thank you to all contributors! 
+ +Ahmed Abualsaud, Aiden Grossman, Arun Pandian, Bartosz Zablocki, Chamikara Jayalath, Claire McGinty, DKPHUONG, Damon Douglass, Danny McCormick, Dip Patel, Ferran Fernández Garrido, Hai Joey Tran, Hyeonho Kim, Igor Bernstein, Israel Herraiz, Jack McCluskey, Jaehyeon Kim, Jeff Kinard, Jeffrey Kinard, Joey Tran, Kenneth Knowles, Kirill Berezin, Michel Davit, Minbo Bae, Naireen Hussain, Niel Markwick, Nito Buendia, Reeba Qureshi, Reuven Lax, Robert Bradshaw, Robert Burke, Rohit Sinha, Ryan Fu, Sam Whittle, Shunping Huang, Svetak Sundhar, Udaya Chathuranga, Vitaly Terentyev, Vlado Djerek, Yi Hu, Claude van der Merwe, XQ Hu, Martin Trieu, Valentyn Tymofieiev, twosom diff --git a/website/www/site/content/en/blog/beam-2.61.0.md b/website/www/site/content/en/blog/beam-2.61.0.md new file mode 100644 index 000000000000..a4c7ac0cefbd --- /dev/null +++ b/website/www/site/content/en/blog/beam-2.61.0.md @@ -0,0 +1,74 @@ +--- +title: "Apache Beam 2.61.0" +date: 2024-11-25 15:00:00 -0500 +categories: + - blog + - release +authors: + - damccorm +--- + + +We are happy to present the new 2.61.0 release of Beam. +This release includes both improvements and new functionality. +See the [download page](/get-started/downloads/#2610-2024-11-25) for this release. + + + +For more information on changes in 2.61.0, check out the [detailed release notes](https://github.com/apache/beam/milestone/25). 
+ +## Highlights + +* [Python] Introduce Managed Transforms API ([#31495](https://github.com/apache/beam/pull/31495)) +* Flink 1.19 support added ([#32648](https://github.com/apache/beam/pull/32648)) + +## I/Os + +* [Managed Iceberg] Support creating tables if needed ([#32686](https://github.com/apache/beam/pull/32686)) +* [Managed Iceberg] Now available in Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) +* [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) +* BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) +* [Managed Iceberg] Allow updating table partition specs during pipeline runtime ([#32879](https://github.com/apache/beam/pull/32879)) +* Added BigQueryIO as a Managed IO ([#31486](https://github.com/apache/beam/pull/31486)) +* Support for writing to [Solace messages queues](https://solace.com/) (`SolaceIO.Write`) added (Java) ([#31905](https://github.com/apache/beam/issues/31905)). + +## New Features / Improvements + +* Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195)) +* Added support for processing events which use a global sequence to "ordered" extension (Java) ([#32540](https://github.com/apache/beam/pull/32540)) +* Add new meta-transform FlattenWith and Tee that allow one to introduce branching + without breaking the linear/chaining style of pipeline construction. 
+* Use Prism as a fallback to the Python Portable runner when running a pipeline with the Python Direct runner ([#32876](https://github.com/apache/beam/pull/32876)) + +## Deprecations + +* Removed support for Flink 1.15 and 1.16 +* Removed support for Python 3.8 + +## Bugfixes + +* (Java) Fixed tearDown not invoked when DoFn throws on Portable Runners ([#18592](https://github.com/apache/beam/issues/18592), [#31381](https://github.com/apache/beam/issues/31381)). +* (Java) Fixed protobuf error with MapState.remove() in Dataflow Streaming Java Legacy Runner without Streaming Engine ([#32892](https://github.com/apache/beam/issues/32892)). +* Adding flag to support conditionally disabling auto-commit in JdbcIO ReadFn ([#31111](https://github.com/apache/beam/issues/31111)) + +## Known Issues + +N/A + +For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md + +## List of Contributors + +According to git shortlog, the following people contributed to the 2.61.0 release. Thank you to all contributors! + +Ahmed Abualsaud, Ahmet Altay, Arun Pandian, Ayush Pandey, Chamikara Jayalath, Chris Ashcraft, Christoph Grotz, DKPHUONG, Damon, Danny Mccormick, Dmitry Ulyumdzhiev, Ferran Fernández Garrido, Hai Joey Tran, Hyeonho Kim, Idan Attias, Israel Herraiz, Jack McCluskey, Jan Lukavský, Jeff Kinard, Jeremy Edwards, Joey Tran, Kenneth Knowles, Maciej Szwaja, Manit Gupta, Mattie Fu, Michel Davit, Minbo Bae, Mohamed Awnallah, Naireen Hussain, Rebecca Szper, Reeba Qureshi, Reuven Lax, Robert Bradshaw, Robert Burke, S.
Veyrié, Sam Whittle, Sergei Lilichenko, Shunping Huang, Steven van Rossum, Tan Le, Thiago Nunes, Vitaly Terentyev, Vlado Djerek, Yi Hu, claudevdm, fozzie15, johnjcasey, kushmiD, liferoad, martin trieu, pablo rodriguez defino, razvanculea, s21lee, tvalentyn, twosom diff --git a/website/www/site/content/en/blog/beam-summit-2024-overview.md b/website/www/site/content/en/blog/beam-summit-2024-overview.md new file mode 100644 index 000000000000..5cf922d69544 --- /dev/null +++ b/website/www/site/content/en/blog/beam-summit-2024-overview.md @@ -0,0 +1,59 @@ +--- +title: "Apache Beam Summit 2024: Unlocking the power of ML for data processing" +date: 2024-10-16 00:00:01 -0800 +categories: + - blog +aliases: + - /blog/2024/10/16/beam-summit-2024-overview.html +authors: + - liferoad + - damccorm + - rez +--- + + +At the recently concluded [Beam Summit 2024](https://beamsummit.org/), a two-day event held from September 4 to 5, numerous captivating presentations showcased the potential of Beam to address a wide range of challenges, with an emphasis on machine learning (ML). These challenges included feature engineering, data enrichment, and model inference for large-scale distributed data. In all, the summit included [47 talks](https://beamsummit.org/sessions/2024/), with 16 focused specifically on ML use cases or features and many more touching on these topics. + +The talks displayed the breadth and diversity of the Beam community. Among the speakers and attendees, [23 countries](https://docs.google.com/presentation/d/1IJ1sExHzrzIFF5QXKWlcAuPdp7lKOepRQKl9BnfHxJw/edit#slide=id.g3058d3e2f5f_0_10) were represented. Attendees included Beam users, committers in the Beam project, Beam Google Summer of Code contributors, and data processing/machine learning experts. + +## User-friendly turnkey transforms for ML + +With the features recently added to Beam, Beam now offers a set of rich turn-key transforms for ML users that handle a wide range of ML-Ops tasks. 
These transforms include: + +* [RunInference](https://beam.apache.org/documentation/ml/overview/#prediction-and-inference): deploy ML models on CPUs and GPUs +* [Enrichment](https://beam.apache.org/documentation/ml/overview/#data-processing): enrich data for ML feature enhancements +* [MLTransform](https://beam.apache.org/documentation/ml/overview/#data-processing): transform data into ML features + +The Summit talks covered both how to use these features and how people are already using them. Highlights included: + +* A talk about [scaling autonomous driving at Cruise](https://beamsummit.org/slides/2024/ScalingAutonomousDrivingwithApacheBeam.pdf) +* Multiple talks about deploying LLMs for batch and streaming inference +* Three different talks about streaming processing for [RAG](https://cloud.google.com/use-cases/retrieval-augmented-generation) (including [a talk](https://www.youtube.com/watch?v=X_VzKQOcpC4) from one of Beam's Google Summer of Code contributors\!) + +## Beam YAML: Simplifying ML data processing + +Beam pipeline creation can be challenging and often requires learning concepts, managing dependencies, debugging, and maintaining code for ML tasks. To simplify the entry point, [Beam YAML](https://beam.apache.org/blog/beam-yaml-release/) introduces a declarative approach that uses YAML configuration files to create data processing pipelines. No coding is required. + +Beam Summit was the first opportunity that the Beam community had to show off some of the use cases of Beam YAML. It featured several talks about how Beam YAML is already a core part of many users' workflows at companies like [MavenCode](https://beamsummit.org/slides/2024/ALowCodeStructuredApproachtoDeployingApacheBeamMLWorkloadsonKubernetesusingBeamStack.pdf) and [ChartBoost](https://youtu.be/avSXvbScbW0). With Beam YAML, these companies are able to build configuration-based data processing systems, significantly lowering the bar for entry at their companies.
+ +## Prism: Provide a unified ML pipeline development framework for local and remote runner environments + +Beam provides a variety of support for portable runners, but developing a local pipeline has traditionally been painful. Local runners are often incomplete and incompatible with remote runners, such as DataflowRunner and FlinkRunner. + +At Beam Summit, Beam contributors introduced [the Prism local runner](https://youtu.be/R4iNwLBa3VQ) to the community. Prism greatly improves the local developer experience and reduces the gap between local and remote execution. In particular, when handling complicated ML tasks, Prism guarantees consistent runner behavior across these runners, a task that had previously lacked consistent support. + +## Summary + +Beam Summit 2024 showcased the tremendous potential of Apache Beam for addressing a wide range of data processing and machine learning challenges. We look forward to seeing even more innovative use cases and contributions in the future. + +To stay updated on the latest Beam developments and events, visit [the Apache Beam website](https://beam.apache.org/get-started/) and follow us on [social media](https://www.linkedin.com/company/apache-beam/). We encourage you to join [the Beam community](https://beam.apache.org/community/contact-us/) and [contribute to the project](https://beam.apache.org/contribute/). Together, let's unlock the full potential of Beam and shape the future of data processing and machine learning. 
diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md new file mode 100644 index 000000000000..995b59b978c1 --- /dev/null +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -0,0 +1,277 @@ +--- +title: "Efficient Streaming Data Processing with Beam YAML and Protobuf" +date: "2024-09-20T11:53:38+02:00" +categories: + - blog +authors: + - ffernandez92 +--- + + +# Efficient Streaming Data Processing with Beam YAML and Protobuf + +As streaming data processing grows, so do its maintenance, complexity, and costs. +This post explains how to efficiently scale pipelines by using [Protobuf](https://protobuf.dev/), +which ensures that pipelines are reusable and quick to deploy. The goal is to keep this process simple +for engineers to implement using [Beam YAML](https://beam.apache.org/documentation/sdks/yaml/). + + + +## Simplify pipelines with Beam YAML + +Creating a pipeline in Beam can be somewhat difficult, especially for new Apache Beam users. +Setting up the project, managing dependencies, and so on can be challenging. +Beam YAML eliminates most of the boilerplate code, +which allows you to focus on the most important part of the work: data transformation. + +Some of the key benefits of Beam YAML include: + +* **Readability:** By using a declarative language ([YAML](https://yaml.org/)), the pipeline configuration is more human readable. +* **Reusability:** Reusing the same components across different pipelines is simplified. +* **Maintainability:** Pipeline maintenance and updates are easier. + +The following template shows an example of reading events from a [Kafka](https://kafka.apache.org/intro) topic and +writing them into [BigQuery](https://cloud.google.com/bigquery?hl=en). 
+ +```yaml +pipeline: + transforms: + - type: ReadFromKafka + name: ReadProtoMovieEvents + config: + topic: 'TOPIC_NAME' + format: RAW/AVRO/JSON/PROTO + bootstrap_servers: 'BOOTSTRAP_SERVERS' + schema: 'SCHEMA' + - type: WriteToBigQuery + name: WriteMovieEvents + input: ReadProtoMovieEvents + config: + table: 'PROJECT_ID.DATASET.MOVIE_EVENTS_TABLE' + useAtLeastOnceSemantics: true + +options: + streaming: true + dataflow_service_options: [streaming_mode_at_least_once] +``` + +## The complete workflow + +This section demonstrates the complete workflow for this pipeline. + +### Create a simple proto event + +The following code creates a simple movie event. + +```protobuf +// events/v1/movie_event.proto + +syntax = "proto3"; + +package event.v1; + +import "bq_field.proto"; +import "bq_table.proto"; +import "buf/validate/validate.proto"; +import "google/protobuf/wrappers.proto"; + +message MovieEvent { + option (gen_bq_schema.bigquery_opts).table_name = "movie_table"; + google.protobuf.StringValue event_id = 1 [(gen_bq_schema.bigquery).description = "Unique Event ID"]; + google.protobuf.StringValue user_id = 2 [(gen_bq_schema.bigquery).description = "Unique User ID"]; + google.protobuf.StringValue movie_id = 3 [(gen_bq_schema.bigquery).description = "Unique Movie ID"]; + google.protobuf.Int32Value rating = 4 [(buf.validate.field).int32 = { + // validates the average rating is at least 0 + gte: 0, + // validates the average rating is at most 100 + lte: 100 + }, (gen_bq_schema.bigquery).description = "Movie rating"]; + string event_dt = 5 [ + (gen_bq_schema.bigquery).type_override = "DATETIME", + (gen_bq_schema.bigquery).description = "UTC Datetime representing when we received this event. 
Format: YYYY-MM-DDTHH:MM:SS", + (buf.validate.field) = { + string: { + pattern: "^\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}$" + }, + ignore_empty: false, + } + ]; +} +``` + +Because these events are written to BigQuery, +the [`bq_field`](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_field.proto) proto +and the [`bq_table`](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_table.proto) proto are imported. +These proto files help generate the BigQuery JSON schema. +This example also demonstrates a shift-left approach, which moves testing, quality, +and performance as early as possible in the development process. For example, to ensure that only valid events are generated from the source, the `buf.validate` elements are included. + +After you create the `movie_event.proto` proto in the `events/v1` folder, you can generate +the necessary [file descriptor](https://buf.build/docs/reference/descriptors). +A file descriptor is a compiled representation of the schema that allows various tools and systems +to understand and work with protobuf data dynamically. To simplify the process, this example uses Buf, +which requires the following configuration files. 
+ + +Buf configuration: + +```yaml +# buf.yaml + +version: v2 +deps: + - buf.build/googlecloudplatform/bq-schema-api + - buf.build/bufbuild/protovalidate +breaking: + use: + - FILE +lint: + use: + - DEFAULT +``` + +```yaml +# buf.gen.yaml + +version: v2 +managed: + enabled: true +plugins: + # Python Plugins + - remote: buf.build/protocolbuffers/python + out: gen/python + - remote: buf.build/grpc/python + out: gen/python + + # Java Plugins + - remote: buf.build/protocolbuffers/java:v25.2 + out: gen/maven/src/main/java + - remote: buf.build/grpc/java + out: gen/maven/src/main/java + + # BQ Schemas + - remote: buf.build/googlecloudplatform/bq-schema:v1.1.0 + out: protoc-gen/bq_schema + +``` + +Run the following three commands to generate the necessary Java, Python, BigQuery schema, and Descriptor file: + +```bash +# Generate the buf.lock file +buf deps update + +# It generates the descriptor in descriptor.binp. +buf build . -o descriptor.binp --exclude-imports + +# It generates the Java, Python and BigQuery schema as described in buf.gen.yaml +buf generate --include-imports +``` + +### Make the Beam YAML read proto + +Make the following modifications to the YAML file: + +```yaml +# movie_events_pipeline.yml + +pipeline: + transforms: + - type: ReadFromKafka + name: ReadProtoMovieEvents + config: + topic: 'movie_proto' + format: PROTO + bootstrap_servers: '' + file_descriptor_path: 'gs://my_proto_bucket/movie/v1.0.0/descriptor.binp' + message_name: 'event.v1.MovieEvent' + - type: WriteToBigQuery + name: WriteMovieEvents + input: ReadProtoMovieEvents + config: + table: '.raw.movie_table' + useAtLeastOnceSemantics: true +options: + streaming: true + dataflow_service_options: [streaming_mode_at_least_once] +``` + +This step changes the format to `PROTO` and adds the `file_descriptor_path` and the `message_name`.
+ +### Deploy the pipeline with Terraform + +You can use [Terraform](https://www.terraform.io/) to deploy the Beam YAML pipeline +with [Dataflow](https://cloud.google.com/products/dataflow?hl=en) as the runner. +The following Terraform code example demonstrates how to achieve this: + +```hcl +// Enable Dataflow API. +resource "google_project_service" "enable_dataflow_api" { + project = var.gcp_project_id + service = "dataflow.googleapis.com" +} + +// DF Beam YAML +resource "google_dataflow_flex_template_job" "data_movie_job" { + provider = google-beta + project = var.gcp_project_id + name = "movie-proto-events" + container_spec_gcs_path = "gs://dataflow-templates-${var.gcp_region}/latest/flex/Yaml_Template" + region = var.gcp_region + on_delete = "drain" + machine_type = "n2d-standard-4" + enable_streaming_engine = true + subnetwork = var.subnetwork + skip_wait_on_job_termination = true + parameters = { + yaml_pipeline_file = "gs://${var.bucket_name}/yamls/${var.package_version}/movie_events_pipeline.yml" + max_num_workers = 40 + worker_zone = var.gcp_zone + } + depends_on = [google_project_service.enable_dataflow_api] +} +``` + +Assuming the BigQuery table exists (you can create it by using Terraform and Proto), +this code creates a Dataflow job by using the Beam YAML code that reads Proto events from +Kafka and writes them into BigQuery. + +## Improvements and conclusions + +The following community contributions could improve the Beam YAML code in this example: + +* **Support schema registries:** Integrate with schema registries such as Buf Registry or Apicurio for +better schema management. The current workflow generates the descriptors by using Buf and stores them in Google Cloud Storage. +The descriptors could be stored in a schema registry instead. + + +* **Enhanced Monitoring:** Implement advanced monitoring and alerting mechanisms to quickly identify and address +issues in the data pipeline.
+ +Leveraging Beam YAML and Protobuf lets us streamline the creation and maintenance of +data processing pipelines, significantly reducing complexity. This approach ensures that engineers can more +efficiently implement and scale robust, reusable pipelines without needing to manually write Beam code. + +## Contribute + +Developers who want to help build out and add functionalities are welcome to start contributing to the effort in the +[Beam YAML module](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml). + +There is also a list of open [bugs](https://github.com/apache/beam/issues?q=is%3Aopen+is%3Aissue+label%3Ayaml) found +on the GitHub repo - now marked with the `yaml` tag. + +Although Beam YAML is marked stable as of Beam 2.52, it is still under heavy development, with new features being +added with each release. Those who want to be part of the design decisions and give insights into how the framework is +being used are highly encouraged to join the [dev mailing list](https://beam.apache.org/community/contact-us/), where those discussions are occurring. diff --git a/website/www/site/content/en/case-studies/accenture_baltics.md b/website/www/site/content/en/case-studies/accenture_baltics.md new file mode 100644 index 000000000000..98f9d9a8a687 --- /dev/null +++ b/website/www/site/content/en/case-studies/accenture_baltics.md @@ -0,0 +1,105 @@ +--- +title: "Accenture Baltics' Journey with Apache Beam to Streamlined Data Workflows for a Sustainable Energy Leader" +name: "Accenture Baltics" +icon: /images/logos/powered-by/accenture.png +hasNav: true +category: study +cardTitle: "Accenture Baltics' Journey with Apache Beam" +cardDescription: "Accenture Baltics uses Apache Beam on Google Cloud to build a robust data processing infrastructure for a sustainable energy leader. They use Beam to democratize data access, process data in real-time, and handle complex ETL tasks."
+authorName: "Jana Polianskaja" +authorPosition: "Data Engineer @ Accenture Baltics" +authorImg: /images/case-study/accenture/Jana_Polianskaja_sm.jpg +publishDate: 2024-11-25T00:12:00+00:00 +--- + + +
    +
    + +
    +
    +

    + “Apache Beam empowers team members who don’t have data engineering backgrounds to directly access and analyze BigQuery data by using SQL. The data scientists, the finance department, and production optimization teams all benefit from improved data accessibility, which gives them immediate access to critical information for faster analysis and decision-making.” +

    +
    +
    + +
    +
    +
    + Jana Polianskaja +
    +
    + Data Engineer @ Accenture Baltics +
    +
    +
    +
    +
    + + +
    + +# Accenture Baltics' Journey with Apache Beam to Streamlined Data Workflows for a Sustainable Energy Leader + +## Background + +Accenture Baltics, a branch of the global professional services company Accenture, leverages its expertise across various industries to provide consulting, technology, and outsourcing solutions to clients worldwide. A specific project at Accenture Baltics highlights the effective implementation of Apache Beam to support a client who is a global leader in sustainable energy and uses Google Cloud. + +## Journey to Apache Beam + +The team responsible for transforming, curating, and preparing data, including transactional, analytics, and sensor data, for data scientists and other teams has been using Dataflow with Apache Beam for about five years. Dataflow with Beam is a natural choice for both streaming and batch data processing. For our workloads, we typically use the following configurations: worker machine types are `n1-standard-2` or `n1-standard-4`, and the maximum number of workers varies up to five, using the Dataflow runner. + +As an example, a streaming pipeline ingests transaction data from Pub/Sub, performs basic ETL and data cleaning, and outputs the results to BigQuery. A separate batch Dataflow pipeline evaluates a binary classification model, reading input and writing results to Google Cloud Storage. The following diagram shows a workflow that uses Pub/Sub to feed Dataflow pipelines across three Google Cloud projects. It also shows how Dataflow, Composer, Cloud Storage, BigQuery, and Grafana integrate into the architecture. + +
    + + Diagram of Accenture Baltics' Dataflow pipeline architecture + +
    + +## Use Cases + +Apache Beam is an invaluable tool for our use cases, particularly in the following areas: + +* **Democratizing data access:** Beam empowers team members without data engineering backgrounds to directly access and analyze BigQuery data using their SQL skills. The data scientists, the finance department, and production optimization teams all benefit from improved data accessibility, gaining immediate access to critical information for faster analysis and decision-making. +* **Real-time data processing:** Beam excels at ingesting and processing data in real time from sources like Pub/Sub. +* **ETL (extract, transform, load):** Beam effectively manages the full spectrum of data transformation and cleaning tasks, even when dealing with complex data structures. +* **Data routing and partitioning:** Beam enables sophisticated data routing and partitioning strategies. For example, it can automatically route failed transactions to a separate BigQuery table for further analysis. +* **Data deduplication and error handling:** Beam has been instrumental in tackling challenging tasks like deduplicating Pub/Sub messages and implementing robust error handling, such as for JSON parsing, that are crucial for maintaining data integrity and pipeline reliability. + +We also utilize Grafana (shown below) with custom notification emails and tickets for comprehensive monitoring of our Beam pipelines. Notifications are generated from Google’s Cloud Logging and Cloud Monitoring services to ensure we stay informed about the performance and health of our pipelines. The seamless integration of Airflow with Dataflow and Beam further enhances our workflow, allowing us to effortlessly use operators such as `DataflowCreatePythonJobOperator` and `BeamRunPythonPipelineOperator` in [Airflow 2](https://airflow.apache.org/docs/apache-airflow-providers-google/stable/_api/airflow/providers/google/cloud/operators/dataflow/index.html).
    + + scheme + +
    + +## Results + +Our data processing infrastructure uses 12 distinct pipelines to manage and transform data across various projects within the organization. These pipelines are divided into two primary categories: + +* **Streaming pipelines:** These pipelines are designed to handle real-time or near real-time data streams. In our current setup, these pipelines process an average of 10,000 messages per second from Pub/Sub and about 200,000 rows per hour to BigQuery, ensuring that time-sensitive data is ingested and processed with minimal latency. +* **Batch pipelines:** These pipelines are optimized for processing large volumes of data in scheduled batches. Our current batch pipelines handle approximately two gigabytes of data per month, transforming and loading this data into our data warehouse for further analysis and reporting. + +Apache Beam has proven to be a highly effective solution for orchestrating and managing the complex data pipelines required by the client. By leveraging the capabilities of Dataflow, a fully managed service for executing Beam pipelines, we have successfully addressed and fulfilled the client's specific data processing needs. This powerful combination has enabled us to achieve scalability, reliability, and efficiency in handling large volumes of data, ensuring timely and accurate delivery of insights to the client. + +*Check out [my Medium blog](https://medium.com/@jana_om)\! I usually post about using Beam/Dataflow as an ETL tool with Python and how it works with other data engineering tools. My focus is on building projects that are easy to understand and learn from, especially if you want to get some hands-on experience with Beam.* + + +{{< case_study_feedback "AccentureBalticsStreaming" >}} +
    +
    diff --git a/website/www/site/content/en/case-studies/behalf.md b/website/www/site/content/en/case-studies/behalf.md deleted file mode 100644 index e5a240a03d4f..000000000000 --- a/website/www/site/content/en/case-studies/behalf.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -title: "Behalf" -icon: /images/logos/powered-by/behalf.png -hasLink: "https://www.behalf.com/" ---- - - diff --git a/website/www/site/content/en/case-studies/yelp_streaming.md b/website/www/site/content/en/case-studies/yelp_streaming.md new file mode 100644 index 000000000000..529070b30727 --- /dev/null +++ b/website/www/site/content/en/case-studies/yelp_streaming.md @@ -0,0 +1,149 @@ +--- +title: "Building data abstractions with streaming at Yelp" +name: "Yelp" +icon: /images/logos/powered-by/yelp.png +hasNav: true +category: study +cardTitle: "Building data abstractions with streaming at Yelp" +cardDescription: "At Yelp, Apache Beam allows teams to create custom streaming pipelines using Python, eliminating the need to switch to Scala or Java. This reduces the learning curve for Python developers and minimizes friction, while providing the flexibility to utilize existing Python libraries." +authorName: "Hakampreet Singh Pandher" +authorPosition: "Software Engineer @ Yelp" +authorImg: /images/case-study/yelp/hpandher.png +publishDate: 2024-11-10T00:12:00+00:00 +--- + + +
    +
    + +
    +
    +

    + “At Yelp, Apache Beam provides an option for different teams to create custom streaming pipelines using Python, eliminating the need to switch to Scala or Java. This reduces the learning curve for Python developers and minimizes friction, while providing the flexibility to utilize existing Python libraries.” +

    +
    +
    + +
    +
    +
    + Hakampreet Singh Pandher +
    +
    + Software Engineer @ Yelp +
    +
    +
    +
    +
    + + +
    + +# Building data abstractions with streaming at Yelp + +## Background + +Yelp relies heavily on streaming to synchronize enormous volumes of data in real time. This is facilitated by Yelp's underlying [data pipeline infrastructure](https://engineeringblog.yelp.com/2016/07/billions-of-messages-a-day-yelps-real-time-data-pipeline.html), which manages the real-time flow of millions of messages originating from a plethora of services. This blog post covers how we leverage Yelp’s extensive streaming infrastructure to build robust data abstractions for our offline and streaming data consumers. We will use Yelp’s Business Properties ecosystem (explained in the upcoming sections) as an example. + + +## Key terminology + +Let’s start by covering certain key terms used throughout the post: + +* **Offline systems** - data warehousing platforms such as AWS Redshift or [Yelp’s Data Lake](https://engineeringblog.yelp.com/2021/04/powering-messaging-enabledness-with-yelps-data-infrastructure.html), which are intended for large-scale data analysis + +* **Online systems** - systems designed around high-performance SQL and NoSQL database solutions like MySQL or Cassandra DB, specifically built to handle and serve live traffic in real time, typically via REST APIs over HTTP. These databases are optimized for swiftly processing and delivering data as it's generated or requested, making them crucial for applications and services that require immediate access to up-to-date information + +## Status Quo + +### Introduction to business properties ecosystem + +Generally speaking, ‘Business Property’ can be any piece of data that is associated with a Yelp business. For example, if we're talking about a restaurant, its business properties could include things like what payment methods it accepts, what amenities it provides, and when it is open for business. + +There are two types of business properties: Business Attributes and Business Features. 
You may notice that the terms “attributes” and “features” are synonymous with each other, and that’s no accident. The primary distinction is that Business Attributes belong to the legacy system, **yelp-main**, while Business Features are in a dedicated microservice, aligning with Yelp's transition to Service Oriented Architecture. + +We also gather additional metadata about business properties themselves, such as when they were last modified, how confident we are in their accuracy, and where they originated from. This additional information is referred to as “properties metadata.” We store this metadata in a separate table, which contains data about both Business Features and Business Attributes. + +Business properties data is accessed via two primary methods: HTTP APIs for real-time online applications and streaming for offline data synchronization. This post mainly focuses on the streaming aspect. + + +### Existing Business Properties' streaming architecture + +
    + + scheme + +
    + +1. In yelp-main’s MySQL database, data for Business Attributes is scattered across more than a dozen tables. To share this data efficiently, we employ the [MySQL Replication Handler](https://engineeringblog.yelp.com/2016/08/streaming-mysql-tables-in-real-time-to-kafka.html) to push it to [Kafka](https://kafka.apache.org/intro) + +2. Business Features and metadata for business properties are stored in their respective tables in Cassandra db and we use [Cassandra Source Connector](https://engineeringblog.yelp.com/2019/12/cassandra-source-connector-part-1.html) to publish their data into Kafka + +3. Ultimately, we use [Redshift Connector](https://engineeringblog.yelp.com/2016/10/redshift-connector.html) to synchronize data from all these tables with their corresponding tables in Redshift. This process allows us to maintain an up-to-date dataset in Redshift for analysis and reporting + + +### Challenges with the existing workflow + +* **Weak Encapsulation**: Storing data in offline systems exactly as it is stored in source databases forces our clients to understand the inner workings of the source data, which weakens data encapsulation. Ideally, we wanted to abstract away distinctions like 'Business Features' and 'Business Attributes' and hide implementation details from clients to simplify their interactions. Furthermore, exposing raw data to offline consumers can lead to the disclosure of outdated or incorrect information. Transformation layers via REST APIs prevented online users from facing data discrepancies. However, offline users analyzing raw data still had to grapple with data accuracy issues, such as managing soft-deleted entries. + +* **Discovery and consumption**: The lack of proper abstractions also made data analysis and consumption challenging as it meant that consumers, whether they are Product Managers, Data Analysts, or batch processing systems, must create multiple workflows to collect data from various sources. 
Not to mention, dealing with edge cases and transforming data into a consistent schema added significant effort and cost, leading to an increase in the friction for consumption and a reduction in the general utility of the data. + +* **Maintenance challenges**: It also posed certain maintenance challenges as any alteration in the source schema necessitated corresponding changes in the destination store. Ideally, we would prefer the destination store's schema to be more flexible, dynamic, and less susceptible to changes. This minimizes disruptions for users and mitigates the risk of infrastructure problems due to frequent schema upgrades. It also underscores the fact that a storage schema suitable for one database system might not be ideal for another. + + +## Improved Implementation: + +We did explore various alternatives, including a non-streaming solution that involved using Apache Spark for routine batch executions to generate data dumps in diverse formats. However, as some of the data consumer use cases required relatively real-time updates, we had to lean towards a streaming approach. + +### Building robust data abstractions for both offline and streaming data consumers + +We tackled the aforementioned challenges by treating both streaming and offline data consumption as just additional channels for accessing and utilizing data, much like online HTTP clients. Similar to how we simplify complexities for online data consumers through REST APIs, we aimed to provide a consistent experience for streamed data by abstracting away internal implementation details. This means that if a client service transitions from consuming data directly through REST APIs to an asynchronous streaming approach, it will encounter similar data abstractions. For example, just as online consumers won't see stale or invalid data, the same principle applies to streamed data consumers. 
+ +To achieve this, we implemented a unified stream that delivers all relevant business property data in a consistent and user-friendly format. This approach ensures that Business Property consumers are spared from navigating the nuances between Business Attributes and Features or understanding the intricacies of data storage in their respective online source databases. + +### New consolidated business properties streaming architecture + +
    + + scheme + +
    + +1. **Business Attributes data collection and transformation**: we utilize [Apache Beam](https://beam.apache.org/) with [Apache Flink](https://flink.apache.org/) as the distributed processing backend for data transformation and formatting Business attribute data. Apache Beam transformation jobs process data originating from various input streams generated by the MySQL replication handler. These streams contain replicated data from their corresponding MySQL tables. The transformation jobs are responsible for standardizing the incoming streaming data, transforming it into a consistent format across all business properties. The transformed data is then published into a single unified stream. + +2. **Streaming Business Features**: in a similar fashion, the output stream for Business Features, sourced from Cassandra using a [source connector](https://engineeringblog.yelp.com/2019/12/cassandra-source-connector-part-1.html), also has its dedicated Apache Beam transformer job. This job formats the data to match the unified format used for Business Attributes, and the resulting data is published into the same unified output stream + +3. **Enrich data with properties metadata**: we employed a [Joinery Flink](https://engineeringblog.yelp.com/2018/12/joinery-a-tale-of-unwindowed-joins.html) job - a homegrown solution at Yelp commonly used for joining data across multiple Kafka topics - to amalgamate the business data for both Business Attributes and Features with the corresponding metadata. As a result, the data stream not only contains the business properties data but also the relevant metadata linked to each property. + +4. **Final data formatting**: transformation job to address issues related to data inconsistencies, remove invalid data entries, and add any necessary supplementary fields, before the final business properties with metadata consolidated stream is exposed for consumption + +5. 
**Offline data storage**: the processed business properties data, complete with metadata, is made available for offline consumption and ends up in Redshift, through Redshift Connector. Additionally, it is ingested into Yelp's Data Lake using a Data Lake connector, making it available for a broader range of analytics and data processing tasks + +6. **Real-time consumption and Integration**: the same consolidated data stream can cater to real-time consumption by other services within the organization. We use the same stream to sync business property data with Marketing systems, as they require timely syncs for their campaigns + +To summarize, with the architecture described above, we have created a unified business properties stream that addresses the challenges with the existing workflow discussed earlier. This stream is utilized to sync business properties data into offline systems, enabling users to access all business properties through a singular schema, thereby facilitating data discovery, consumption, and overall ease of use. + +Additionally, this approach allowed us to enrich business property data with associated metadata and resolve data inconsistencies, for example by removing duplicate business properties. We used the [entity–attribute–value (EAV) model](https://en.wikipedia.org/wiki/Entity%E2%80%93attribute%E2%80%93value_model), which accommodates the frequent introduction of new business properties without requiring modifications to the destination store schemas, hence reducing some of the maintenance overhead. + +## Final words +This post shows how Yelp's robust data pipeline infrastructure can be leveraged to create sophisticated data pipelines that provide data in formats that are better suited to both offline and streaming users.
This doesn't imply that streaming and exposing raw data is never appropriate; however, in such situations, it may be more effective to offer multiple streams: one with the raw data and others with processed data that is better suited to data analysis and consumption. + +*This blog post was originally published in March 2024 on Yelp's Engineering Blog: [Building Data Abstractions with Streaming at Yelp](https://engineeringblog.yelp.com/2024/03/building-data-abstractions-with-streaming-at-yelp.html).* + + +{{< case_study_feedback "YelpStreaming" >}} +
    +
    diff --git a/website/www/site/content/en/documentation/ml/large-language-modeling.md b/website/www/site/content/en/documentation/ml/large-language-modeling.md index 90bbd43383c0..b8bd0704d20e 100644 --- a/website/www/site/content/en/documentation/ml/large-language-modeling.md +++ b/website/www/site/content/en/documentation/ml/large-language-modeling.md @@ -170,3 +170,32 @@ class MyModelHandler(): def run_inference(self, batch: Sequence[str], model: MyWrapper, inference_args): return model.predict(unpickleable_object) ``` + +## RAG and Prompt Engineering in Beam + +Beam is also an excellent tool for improving the quality of your LLM prompts using Retrieval Augmented Generation (RAG). +Retrieval augmented generation is a technique that enhances large language models (LLMs) by connecting them to external knowledge sources. +This allows the LLM to access and process real-time information, improving the accuracy, relevance, and factuality of its responses. + +Beam has several mechanisms to make this process simpler: + +1. Beam's [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) provides an embeddings package to generate the embeddings used for RAG. You can also use RunInference to generate embeddings if you have a model without an embeddings handler. +2. Beam's [Enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) makes it easy to look up embeddings or other information in an external storage system like a [vector database](https://www.pinecone.io/learn/vector-database/). + +Collectively, you can use these to perform RAG using the following steps: + +**Pipeline 1 - generate knowledge base:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. 
Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Write those embeddings to a vector DB using a [ParDo](https://beam.apache.org/documentation/programming-guide/#pardo) + +**Pipeline 2 - use knowledge base to perform RAG:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Enrich that data with additional embeddings from your vector DB using [Enrichment](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) +4. Use that enriched data to prompt your LLM with [RunInference](https://beam.apache.org/documentation/transforms/python/elementwise/runinference/) +5. Write that data to your desired sink using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) + +To view an example pipeline performing RAG, see https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/rag_usecase/beam_rag_notebook.ipynb diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index c716c7554db4..955c2b8797d1 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -35,7 +35,7 @@ programming guide, take a look at the {{< language-switcher java py go typescript yaml >}} {{< paragraph class="language-py" >}} -The Python SDK supports Python 3.8, 3.9, 3.10, and 3.11. +The Python SDK supports Python 3.8, 3.9, 3.10, 3.11, and 3.12. {{< /paragraph >}} {{< paragraph class="language-go">}} @@ -2024,7 +2024,7 @@ playerAccuracies := ... // PCollection #### 4.2.5. 
Flatten {#flatten} [`Flatten`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) -[`Flatten`](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/transforms/core.py) +[`Flatten`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.Flatten) [`Flatten`](https://github.com/apache/beam/blob/master/sdks/go/pkg/beam/flatten.go) `Flatten` is a Beam transform for `PCollection` objects that store the same data type. @@ -2045,6 +2045,22 @@ PCollectionList collections = PCollectionList.of(pc1).and(pc2).and(pc3); PCollection merged = collections.apply(Flatten.pCollections()); {{< /highlight >}} +{{< paragraph class="language-java" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. +{{< /paragraph >}} + +{{< highlight java >}} +PCollection merged = pc1 + .apply(...) + // Merges the elements of pc2 in at this point... + .apply(FlattenWith.of(pc2)) + .apply(...) + // and the elements of pc3 at this point. + .apply(FlattenWith.of(pc3)) + .apply(...); +{{< /highlight >}} + {{< highlight py >}} # Flatten takes a tuple of PCollection objects. @@ -2052,6 +2068,26 @@ PCollection merged = collections.apply(Flatten.pCollections()); {{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten >}} {{< /highlight >}} +{{< paragraph class="language-py" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.FlattenWith) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. 
+{{< /paragraph >}} + +{{< highlight py >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with >}} +{{< /highlight >}} + +{{< paragraph class="language-py" >}} +`FlattenWith` can take root `PCollection`-producing transforms +(such as `Create` and `Read`) as well as already constructed PCollections, +and will apply them and flatten their outputs into the resulting output +PCollection. +{{< /paragraph >}} + +{{< highlight py >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with_transform >}} +{{< /highlight >}} + {{< highlight go >}} // Flatten accepts any number of PCollections of the same element type. // Returns a single PCollection that contains all of the elements in input PCollections. @@ -6173,7 +6209,7 @@ class MyDoFn(beam.DoFn): self.gauge = metrics.Metrics.gauge("namespace", "gauge1") def process(self, element): - self.gaguge.set(element) + self.gauge.set(element) yield element {{< /highlight >}} diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 7325c480955c..8356b9067a6b 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -93,7 +93,7 @@ from the [compatibility table](#flink-version-compatibility) below. For example: {{< highlight java >}} org.apache.beam - beam-runners-flink-1.17 + beam-runners-flink-1.18 {{< param release_latest >}} {{< /highlight >}} @@ -166,7 +166,7 @@ If you have a Flink `JobManager` running on your local machine you can provide ` To run a pipeline on Flink, set the runner to `FlinkRunner` and `flink_master` to the master URL of a Flink cluster. In addition, optionally set `environment_type` set to `LOOPBACK`. 
For example, -after starting up a [local flink cluster](https://ci.apache.org/projects/flink/flink-docs-release-1.10/getting-started/tutorials/local_setup.html), +after starting up a [local flink cluster](https://ci.apache.org/projects/flink/flink-docs-release-1.18/getting-started/tutorials/local_setup.html), one could run: {{< /paragraph >}} @@ -196,7 +196,6 @@ The optional `flink_version` option may be required as well for older versions o {{< paragraph class="language-portable" >}} Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: -[Flink 1.15](https://hub.docker.com/r/apache/beam_flink1.15_job_server). [Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). [Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). @@ -208,12 +207,17 @@ To run a pipeline on an embedded Flink cluster: {{< /paragraph >}} {{< paragraph class="language-portable" >}} -(1) Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +(1) Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` {{< /paragraph >}} {{< paragraph class="language-portable" >}} The JobService is the central instance where you submit your Beam pipeline to. -The JobService will create a Flink job for the pipeline and execute the job. +It creates a Flink job from your pipeline and executes it. +You might encounter an error message like `Caused by: java.io.IOException: Insufficient number of network buffers:...`. +This can be resolved by providing a Flink configuration file to override the default settings. +You can find an example configuration file [here](https://github.com/apache/beam/blob/master/runners/flink/src/test/resources/flink-conf.yaml). 
+To start the Job Service endpoint with your custom configuration, mount a local directory containing your Flink configuration to the `/flink-conf` path in the Docker container and pass this as `--flink-conf-dir`: +`docker run --net=host -v :/flink-conf beam-flink-runner apache/beam_flink1.18_job_server:latest --flink-conf-dir /flink-conf` {{< /paragraph >}} {{< paragraph class="language-portable" >}} @@ -236,7 +240,7 @@ with beam.Pipeline(options) as p: {{< paragraph class="language-portable" >}} -To run on a separate [Flink cluster](https://ci.apache.org/projects/flink/flink-docs-release-1.10/getting-started/tutorials/local_setup.html): +To run on a separate [Flink cluster](https://ci.apache.org/projects/flink/flink-docs-release-1.18/getting-started/tutorials/local_setup.html): {{< /paragraph >}} {{< paragraph class="language-portable" >}} @@ -244,7 +248,7 @@ To run on a separate [Flink cluster](https://ci.apache.org/projects/flink/flink- {{< /paragraph >}} {{< paragraph class="language-portable" >}} -(2) Start JobService with Flink Rest endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest --flink-master=localhost:8081`. +(2) Start JobService with Flink Rest endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest --flink-master=localhost:8081`. {{< /paragraph >}} {{< paragraph class="language-portable" >}} @@ -312,8 +316,8 @@ reference. ## Flink Version Compatibility The Flink cluster version has to match the minor version used by the FlinkRunner. -The minor version is the first two numbers in the version string, e.g. in `1.16.0` the -minor version is `1.16`. +The minor version is the first two numbers in the version string, e.g. in `1.18.0` the +minor version is `1.18`. We try to track the latest version of Apache Flink at the time of the Beam release. A Flink version is supported by Beam for the time it is supported by the Flink community. 
@@ -326,6 +330,11 @@ To find out which version of Flink is compatible with Beam please see the table Artifact Id Supported Beam Versions + + 1.19.x + beam-runners-flink-1.19 + ≥ 2.61.0 + 1.18.x beam-runners-flink-1.18 @@ -339,12 +348,12 @@ To find out which version of Flink is compatible with Beam please see the table 1.16.x beam-runners-flink-1.16 - ≥ 2.47.0 + 2.47.0 - 2.60.0 1.15.x beam-runners-flink-1.15 - ≥ 2.40.0 + 2.40.0 - 2.60.0 1.14.x diff --git a/website/www/site/content/en/documentation/runtime/environments.md b/website/www/site/content/en/documentation/runtime/environments.md index d9a42db29e24..48039d50a10b 100644 --- a/website/www/site/content/en/documentation/runtime/environments.md +++ b/website/www/site/content/en/documentation/runtime/environments.md @@ -105,20 +105,22 @@ This method requires building image artifacts from Beam source. For additional i 2. Customize the `Dockerfile` for a given language, typically `sdks//container/Dockerfile` directory (e.g. the [Dockerfile for Python](https://github.com/apache/beam/blob/master/sdks/python/container/Dockerfile). -3. Return to the root Beam directory and run the Gradle `docker` target for your image. +3. Return to the root Beam directory and run the Gradle `docker` target for your + image. For self-contained instructions on building a container image, + follow [this guide](/documentation/sdks/python-sdk-image-build). 
``` cd $BEAM_WORKDIR # The default repository of each SDK - ./gradlew :sdks:java:container:java8:docker ./gradlew :sdks:java:container:java11:docker ./gradlew :sdks:java:container:java17:docker + ./gradlew :sdks:java:container:java21:docker ./gradlew :sdks:go:container:docker - ./gradlew :sdks:python:container:py38:docker ./gradlew :sdks:python:container:py39:docker ./gradlew :sdks:python:container:py310:docker ./gradlew :sdks:python:container:py311:docker + ./gradlew :sdks:python:container:py312:docker # Shortcut for building all Python SDKs ./gradlew :sdks:python:container:buildAll @@ -168,9 +170,9 @@ builds the Python 3.6 container and tags it as `example-repo/beam_python3.6_sdk: From Beam 2.21.0 and later, a `docker-pull-licenses` flag was introduced to add licenses/notices for third party dependencies to the docker images. For example: ``` -./gradlew :sdks:java:container:java8:docker -Pdocker-pull-licenses +./gradlew :sdks:java:container:java11:docker -Pdocker-pull-licenses ``` -creates a Java 8 SDK image with appropriate licenses in `/opt/apache/beam/third_party_licenses/`. +creates a Java 11 SDK image with appropriate licenses in `/opt/apache/beam/third_party_licenses/`. By default, no licenses/notices are added to the docker images. diff --git a/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md b/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md new file mode 100644 index 000000000000..f456a686afea --- /dev/null +++ b/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md @@ -0,0 +1,306 @@ + + +# Building Beam Python SDK Image Guide + +There are two options to build Beam Python SDK image. If you only need to modify +[the Python SDK boot entrypoint binary](https://github.com/apache/beam/blob/master/sdks/python/container/boot.go), +read [Update Boot Entrypoint Application Only](#update-boot-entrypoint-application-only). 
+If you need to build a Beam Python SDK image fully, +read [Build Beam Python SDK Image Fully](#build-beam-python-sdk-image-fully). + + +## Update Boot Entrypoint Application Only + +If you only need to make a change to [the Python SDK boot entrypoint binary](https://github.com/apache/beam/blob/master/sdks/python/container/boot.go), you +can rebuild the boot application only and include the updated boot application +in the preexisting image. +Read [the Python container Dockerfile](https://github.com/apache/beam/blob/master/sdks/python/container/Dockerfile) +for reference. + +```shell +# From beam repo root, make changes to boot.go. +your_editor sdks/python/container/boot.go + +# Rebuild the entrypoint +./gradlew :sdks:python:container:gobuild + +cd sdks/python/container/build/target/launcher/linux_amd64 + +# Create a simple Dockerfile to use custom boot entrypoint. +cat >Dockerfile <<EOF +FROM apache/beam_python3.10_sdk:2.60.0 +COPY boot /opt/apache/beam/boot +EOF + +# Build and push the image. +docker build . --tag us-central1-docker.pkg.dev/PROJECT_ID/REPOSITORY/beam_python3.10_sdk:2.60.0-custom-boot +docker push us-central1-docker.pkg.dev/PROJECT_ID/REPOSITORY/beam_python3.10_sdk:2.60.0-custom-boot +``` + +You can build a Docker image if your local environment has Java, Python, Golang, +and Docker installed. Try +`./gradlew :sdks:python:container:pyXY:docker`, where `XY` is the Python version without the dot. For example, +`:sdks:python:container:py310:docker` builds `apache/beam_python3.10_sdk` +locally if successful. You can follow this guide to build a custom image on +a VM if the build fails in your local environment. + +## Build Beam Python SDK Image Fully + +This section introduces a way to build everything from scratch. + +### Prepare VM + +Prepare a VM with Debian 11. This guide was tested on Debian 11. + +#### Google Compute Engine + +An option to create a Debian 11 VM is using a GCE instance. + +```shell +gcloud compute instances create beam-builder \ + --zone=us-central1-a \ + --image-project=debian-cloud \ + --image-family=debian-11 \ + --machine-type=n1-standard-8 \ + --boot-disk-size=20GB \ + --scopes=cloud-platform +``` + +Log in to the VM.
All the following steps are executed inside the VM. + +```shell +gcloud compute ssh beam-builder --zone=us-central1-a --tunnel-through-iap +``` + +Update the apt package list. + +```shell +sudo apt-get update +``` + +> [!NOTE] +> * A high CPU machine is recommended to reduce the compile time. +> * The image build needs a large disk. The build will fail with "no space left + on device" with the default disk size 10GB. +> * The `cloud-platform` scope is recommended to avoid permission issues with Google + Cloud Artifact Registry. You can use the default scopes if you don't push + the image to Google Cloud Artifact Registry. +> * Use a zone in the region of your docker repository of Artifact Registry if + you push the image to Artifact Registry. + +### Prerequisite Packages + +#### Java + +You need Java to run Gradle tasks. + +```shell +sudo apt-get install -y openjdk-11-jdk +``` + +#### Golang + +Download and install. Reference: https://go.dev/doc/install. + +```shell +# Download and install +curl -OL https://go.dev/dl/go1.23.2.linux-amd64.tar.gz +sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go1.23.2.linux-amd64.tar.gz + +# Add go to PATH. +export PATH=/usr/local/go/bin:$PATH +``` + +Confirm the Golang version. + +```shell +go version +``` + +Expected output: + +```text +go version go1.23.2 linux/amd64 +``` + +> [!NOTE] +> An old Go version (e.g. 1.16) will fail at `:sdks:python:container:goBuild`. + +#### Python + +This guide uses Pyenv to manage multiple Python versions. +Reference: https://realpython.com/intro-to-pyenv/#build-dependencies + +```shell +# Install dependencies +sudo apt-get install -y make build-essential libssl-dev zlib1g-dev \ +libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev \ +libncursesw5-dev xz-utils tk-dev libffi-dev liblzma-dev + +# Install Pyenv +curl https://pyenv.run | bash + +# Add pyenv to PATH.
+export PATH="$HOME/.pyenv/bin:$PATH" +eval "$(pyenv init -)" +eval "$(pyenv virtualenv-init -)" +``` + +Install Python 3.9 and set the Python version. This will take several minutes. + +```shell +pyenv install 3.9 +pyenv global 3.9 +``` + +Confirm the python version. + +```shell +python --version +``` + +Expected output example: + +```text +Python 3.9.17 +``` + +> [!NOTE] +> You can use a different Python version for building with [ +`-PpythonVersion` option](https://github.com/apache/beam/blob/v2.60.0/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy#L2956-L2961) +> to Gradle task run. Otherwise, you should have `python3.9` in the build +> environment for Apache Beam 2.60.0 or later (python3.8 for older Apache Beam +> versions). If you use the wrong version, the Gradle task +`:sdks:python:setupVirtualenv` fails. + +#### Docker + +Install Docker +following [the reference](https://docs.docker.com/engine/install/debian/#install-using-the-repository). + +```shell +# Add GPG keys. +sudo apt-get update +sudo apt-get install ca-certificates curl +sudo install -m 0755 -d /etc/apt/keyrings +sudo curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc +sudo chmod a+r /etc/apt/keyrings/docker.asc + +# Add the Apt repository. +echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian \ + $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null +sudo apt-get update + +# Install docker packages. +sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin +``` + +You need to run `docker` command without the root privilege in Beam Python SDK +image build. You can do this +by [adding your account to the docker group](https://docs.docker.com/engine/install/linux-postinstall/). 
+ +```shell +sudo usermod -aG docker $USER +newgrp docker +``` + +Confirm that you can run a container without the root privilege. + +```shell +docker run hello-world +``` + +#### Git + +Git is not necessary for building the Python SDK image. Git is just used to download +the Apache Beam code in this guide. + +```shell +sudo apt-get install -y git +``` + +### Build Beam Python SDK Image + +Download Apache Beam +from [the GitHub repository](https://github.com/apache/beam). + +```shell +git clone https://github.com/apache/beam beam +cd beam +``` + +Make changes to the Apache Beam code. + +Run the Gradle task to start the Docker image build. This will take several minutes. +You can run `:sdks:python:container:pyXY:docker` to build an image +for a different Python version. +See [the supported Python version list](https://github.com/apache/beam/tree/master/sdks/python/container). +For example, `py310` is for Python 3.10. + +```shell +./gradlew :sdks:python:container:py310:docker +``` + +If the build is successful, you can see the built image locally. + +```shell +docker images +``` + +Expected output: + +```text +REPOSITORY TAG IMAGE ID CREATED SIZE +apache/beam_python3.10_sdk 2.60.0 33db45f57f25 About a minute ago 2.79GB +``` + +> [!NOTE] +> If you run the build in your local environment and the Gradle task +`:sdks:python:setupVirtualenv` fails because of an incompatible Python version, please +> try with `-PpythonVersion` with the Python version installed in your local +> environment (e.g. `-PpythonVersion=3.10`). + +### Push to Repository + +You may push the custom image to an image repository. The image can be used +for a [Dataflow custom container](https://cloud.google.com/dataflow/docs/guides/run-custom-container#usage). + +#### Google Cloud Artifact Registry + +You can push the image to Artifact Registry. No additional authentication is +necessary if you use Google Compute Engine.
+ +```shell +docker tag apache/beam_python3.10_sdk:2.60.0 us-central1-docker.pkg.dev/PROJECT_ID/REPOSITORY/beam_python3.10_sdk:2.60.0-custom +docker push us-central1-docker.pkg.dev/PROJECT_ID/REPOSITORY/beam_python3.10_sdk:2.60.0-custom +``` + +If you push an image in an environment other than a VM in Google Cloud, you +should configure [docker authentication with +`gcloud`](https://cloud.google.com/artifact-registry/docs/docker/authentication#gcloud-helper) +before `docker push`. + +#### Docker Hub + +You can push the image to your Docker Hub repository +after [docker login](https://docs.docker.com/reference/cli/docker/login/). + +```shell +docker tag apache/beam_python3.10_sdk:2.60.0 DOCKERHUB_USER/beam_python3.10_sdk:2.60.0-custom +docker push DOCKERHUB_USER/beam_python3.10_sdk:2.60.0-custom +``` + diff --git a/website/www/site/content/en/documentation/sdks/python-streaming.md b/website/www/site/content/en/documentation/sdks/python-streaming.md index 2d0bdfa9500b..7122cc8b9ae5 100644 --- a/website/www/site/content/en/documentation/sdks/python-streaming.md +++ b/website/www/site/content/en/documentation/sdks/python-streaming.md @@ -155,11 +155,3 @@ about executing streaming pipelines: - [DirectRunner streaming execution](/documentation/runners/direct/#streaming-execution) - [DataflowRunner streaming execution](/documentation/runners/dataflow/#streaming-execution) - [Portable Flink runner](/documentation/runners/flink/) - -## Unsupported features - -Python streaming execution does not currently support the following features: - -- Custom source API -- User-defined custom merging `WindowFn` (with fnapi) -- For portable runners, see [portability support table](https://s.apache.org/apache-beam-portability-support-table).
diff --git a/website/www/site/content/en/documentation/sdks/yaml-errors.md b/website/www/site/content/en/documentation/sdks/yaml-errors.md index 8c0d9f06ade3..6edd1751a65b 100644 --- a/website/www/site/content/en/documentation/sdks/yaml-errors.md +++ b/website/www/site/content/en/documentation/sdks/yaml-errors.md @@ -37,7 +37,8 @@ The `output` parameter is a name that must referenced as an input to another transform that will process the errors (e.g. by writing them out). For example, the following code will write all "good" processed records to one file and -any "bad" records to a separate file. +any "bad" records, along with metadata about what error was encountered, +to a separate file. ``` pipeline: @@ -77,6 +78,10 @@ for a robust pipeline). Note also that the exact format of the error outputs is still being finalized. They can be safely printed and written to outputs, but their precise schema may change in a future version of Beam and should not yet be depended on. +It generally contains the failed record itself as well as information about +the error that was encountered (e.g. error messages and tracebacks). +To recover the bad record alone one can process the error output with the +`StripErrorMetadata` transformation. Some transforms allow for extra arguments in their error_handling config, e.g. for Python functions one can give a `threshold` which limits the relative number @@ -136,9 +141,16 @@ pipeline: error_handling: output: my_error_output + - type: StripErrorMetadata + name: FailedRecordsWithoutMetadata + # Takes the error information from ComputeRatio and returns just the + # failing records themselves for another attempt with a different + # transform. 
+ input: ComputeRatio.my_error_output + - type: MapToFields name: ComputeRatioForBadRecords - input: ComputeRatio.my_error_output + input: FailedRecordsWithoutMetadata config: language: python fields: diff --git a/website/www/site/content/en/documentation/sdks/yaml.md b/website/www/site/content/en/documentation/sdks/yaml.md index d4829c163522..3559a18076ba 100644 --- a/website/www/site/content/en/documentation/sdks/yaml.md +++ b/website/www/site/content/en/documentation/sdks/yaml.md @@ -662,6 +662,9 @@ providers: MyCustomTransform: "urn:registered:in:expansion:service" ``` +A full example of how to build a Java provider can be found +[here](https://github.com/Polber/beam-yaml-xlang). + Arbitrary Python transforms can be provided as well, using the syntax ``` @@ -705,6 +708,49 @@ options: streaming: true ``` +## Jinja Templatization + +It is common to want to run a single Beam pipeline in different contexts +and/or with different configurations. +When running a YAML pipeline using `apache_beam.yaml.main` or via gcloud, +the YAML file can be parameterized with externally provided variables using +the [Jinja variable syntax](https://jinja.palletsprojects.com/en/stable/templates/#variables). +The values are then passed via a `--jinja_variables` command line flag. + +For example, one could start a pipeline with + +``` +pipeline: + transforms: + - type: ReadFromCsv + config: + path: {{input_pattern}} +``` + +and then run it with + +```sh +python -m apache_beam.yaml.main \ + --yaml_pipeline_file=pipeline.yaml \ + --jinja_variables='{"input_pattern": "gs://path/to/this/runs/files*.csv"}' +``` + +Arbitrary [Jinja control structures](https://jinja.palletsprojects.com/en/stable/templates/#list-of-control-structures), +such as looping and conditionals, can be used as well if desired as long as the +output results in a valid Beam YAML pipeline.
+ +We also expose the [`datetime`](https://docs.python.org/3/library/datetime.html) +module as a variable by default, which can be particularly useful in reading +or writing dated sources and sinks, e.g. + +``` +- type: WriteToJson + config: + path: "gs://path/to/{{ datetime.datetime.now().strftime('%Y/%m/%d') }}/dated-output.json" +``` + +would write to files like `gs://path/to/2016/08/04/dated-output*.json`. + ## Other Resources * [Example pipeline](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml/examples) diff --git a/website/www/site/content/en/get-started/downloads.md b/website/www/site/content/en/get-started/downloads.md index 08614b8835c1..dea5dc314b17 100644 --- a/website/www/site/content/en/get-started/downloads.md +++ b/website/www/site/content/en/get-started/downloads.md @@ -96,10 +96,26 @@ versions denoted `0.x.y`. ## Releases +### 2.61.0 (2024-11-25) + +Official [source code download](https://downloads.apache.org/beam/2.61.0/apache-beam-2.61.0-source-release.zip). +[SHA-512](https://downloads.apache.org/beam/2.61.0/apache-beam-2.61.0-source-release.zip.sha512). +[signature](https://downloads.apache.org/beam/2.61.0/apache-beam-2.61.0-source-release.zip.asc). + +[Release notes](https://github.com/apache/beam/releases/tag/v2.61.0) + +### 2.60.0 (2024-10-17) + +Official [source code download](https://archive.apache.org/dist/beam/2.60.0/apache-beam-2.60.0-source-release.zip). +[SHA-512](https://archive.apache.org/dist/beam/2.60.0/apache-beam-2.60.0-source-release.zip.sha512). +[signature](https://archive.apache.org/dist/beam/2.60.0/apache-beam-2.60.0-source-release.zip.asc). + +[Release notes](https://github.com/apache/beam/releases/tag/v2.60.0) + ### 2.59.0 (2024-09-11) -Official [source code download](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip). -[SHA-512](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip.sha512). 
-[signature](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip.asc). +Official [source code download](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip). +[SHA-512](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip.sha512). +[signature](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip.asc). [Release notes](https://github.com/apache/beam/releases/tag/v2.59.0) diff --git a/website/www/site/content/en/get-started/from-spark.md b/website/www/site/content/en/get-started/from-spark.md index 36d1fb2a0450..e4c9b532f423 100644 --- a/website/www/site/content/en/get-started/from-spark.md +++ b/website/www/site/content/en/get-started/from-spark.md @@ -23,7 +23,7 @@ If you already know [_Apache Spark_](https://spark.apache.org/), using Beam should be easy. The basic concepts are the same, and the APIs are similar as well. -Spark stores data _Spark DataFrames_ for structured data, +Spark stores data in _Spark DataFrames_ for structured data, and in _Resilient Distributed Datasets_ (RDD) for unstructured data. We are using RDDs for this guide. diff --git a/website/www/site/content/en/get-started/quickstart-java.md b/website/www/site/content/en/get-started/quickstart-java.md index b3ff77174e0b..46c687b54236 100644 --- a/website/www/site/content/en/get-started/quickstart-java.md +++ b/website/www/site/content/en/get-started/quickstart-java.md @@ -139,9 +139,9 @@ if (project.hasProperty("dataflow-runner")) { {{< /highlight >}} 4. At the end of the build script, add the following task: {{< highlight >}} -task execute (type: JavaExec) { - classpath = sourceSets["main"].runtimeClasspath - mainClass.set(System.getProperty("mainClass")) +tasks.register("execute") { + mainClass.set(System.getProperty("mainClass")) + classpath = sourceSets.main.get().runtimeClasspath } {{< /highlight >}} 4. 
Build your project: diff --git a/website/www/site/content/en/get-started/quickstart-py.md b/website/www/site/content/en/get-started/quickstart-py.md index 3428f5346e02..d7f896153483 100644 --- a/website/www/site/content/en/get-started/quickstart-py.md +++ b/website/www/site/content/en/get-started/quickstart-py.md @@ -23,7 +23,7 @@ If you're interested in contributing to the Apache Beam Python codebase, see the {{< toc >}} -The Python SDK supports Python 3.8, 3.9, 3.10 and 3.11. Beam 2.48.0 was the last release with support for Python 3.7. +The Python SDK supports Python 3.8, 3.9, 3.10, 3.11 and 3.12. Beam 2.48.0 was the last release with support for Python 3.7. ## Set up your environment diff --git a/website/www/site/data/authors.yml b/website/www/site/data/authors.yml index 6f16ae12f01a..f3645f7475bf 100644 --- a/website/www/site/data/authors.yml +++ b/website/www/site/data/authors.yml @@ -107,6 +107,9 @@ lkuligin: name: Leonid Kuligin email: kuligin@google.com twitter: lkulighin +liferoad: + name: XQ Hu + email: xqhu@google.com markliu: name: Mark Liu email: markliu@apache.org @@ -283,4 +286,7 @@ jkinard: email: jkinard@google.com jkim: name: Jaehyeon Kim - email: dottami@gmail.com \ No newline at end of file + email: dottami@gmail.com +ffernandez92: + name: Ferran Fernandez + email: ffernandez.upc@gmail.com diff --git a/website/www/site/data/capability_matrix.yaml b/website/www/site/data/capability_matrix.yaml index dcbbca438b6e..e6fd51a9bb17 100644 --- a/website/www/site/data/capability_matrix.yaml +++ b/website/www/site/data/capability_matrix.yaml @@ -393,7 +393,7 @@ capability-matrix: - class: dataflow l1: "Partially" l2: non-merging windows - l3: State is supported for non-merging windows. SetState and MapState are not yet supported. + l3: "State is supported for non-merging windows. 
The MapState, SetState, and MultimapState state types are supported in the following scenarios: Java pipelines that don't use Streaming Engine; Java pipelines that use Streaming Engine and version 2.58.0 or later of the Java SDK. SetState, MapState, and MultimapState are not supported for pipelines that use Runner v2." - class: flink l1: "Partially" l2: non-merging windows diff --git a/website/www/site/data/en/quotes.yaml b/website/www/site/data/en/quotes.yaml index 3a5225b3f29a..db45c4344346 100644 --- a/website/www/site/data/en/quotes.yaml +++ b/website/www/site/data/en/quotes.yaml @@ -71,6 +71,16 @@ logoUrl: images/logos/powered-by/hop.png linkUrl: case-studies/hop/index.html linkText: Learn more +- text: At Yelp, Apache Beam allows teams to create custom streaming pipelines using Python, eliminating the need to switch to Scala or Java. + icon: icons/quote-icon.svg + logoUrl: /images/logos/powered-by/yelp.png + linkUrl: case-studies/yelp_streaming/index.html + linkText: Learn more +- text: Accenture Baltics uses Apache Beam on Google Cloud to build a robust data processing infrastructure for a sustainable energy leader. They use Beam to democratize data access, process data in real time, and handle complex ETL tasks. + icon: icons/quote-icon.svg + logoUrl: /images/logos/powered-by/accenture.png + linkUrl: case-studies/accenture_baltics/index.html + linkText: Learn more - text: Have a story to share? Your logo could be here. icon: icons/quote-icon.svg logoUrl: images/logos/powered-by/blank.jpg diff --git a/website/www/site/layouts/partials/section-menu/en/sdks.html b/website/www/site/layouts/partials/section-menu/en/sdks.html index ea48eb6f40d9..243bbd92a465 100644 --- a/website/www/site/layouts/partials/section-menu/en/sdks.html +++ b/website/www/site/layouts/partials/section-menu/en/sdks.html @@ -44,6 +44,7 @@
  • Managing pipeline dependencies
  • Python multi-language pipelines quickstart
  • Python Unrecoverable Errors
  • +
  • Python SDK image build
  • diff --git a/website/www/site/static/images/case-study/accenture/Jana_Polianskaja_sm.jpg b/website/www/site/static/images/case-study/accenture/Jana_Polianskaja_sm.jpg new file mode 100644 index 000000000000..49ce23e5d7cd Binary files /dev/null and b/website/www/site/static/images/case-study/accenture/Jana_Polianskaja_sm.jpg differ diff --git a/website/www/site/static/images/case-study/accenture/dataflow_grafana.jpg b/website/www/site/static/images/case-study/accenture/dataflow_grafana.jpg new file mode 100644 index 000000000000..8a7471c1798f Binary files /dev/null and b/website/www/site/static/images/case-study/accenture/dataflow_grafana.jpg differ diff --git a/website/www/site/static/images/case-study/accenture/dataflow_pipelines.png b/website/www/site/static/images/case-study/accenture/dataflow_pipelines.png new file mode 100644 index 000000000000..423044910e36 Binary files /dev/null and b/website/www/site/static/images/case-study/accenture/dataflow_pipelines.png differ diff --git a/website/www/site/static/images/case-study/yelp/existing_streaming_architecture.png b/website/www/site/static/images/case-study/yelp/existing_streaming_architecture.png new file mode 100644 index 000000000000..f344368d0dbe Binary files /dev/null and b/website/www/site/static/images/case-study/yelp/existing_streaming_architecture.png differ diff --git a/website/www/site/static/images/case-study/yelp/hpandher.png b/website/www/site/static/images/case-study/yelp/hpandher.png new file mode 100644 index 000000000000..dba9a90b5ceb Binary files /dev/null and b/website/www/site/static/images/case-study/yelp/hpandher.png differ diff --git a/website/www/site/static/images/case-study/yelp/new_streaming_architecture.png b/website/www/site/static/images/case-study/yelp/new_streaming_architecture.png new file mode 100644 index 000000000000..9bfbb3fe4621 Binary files /dev/null and b/website/www/site/static/images/case-study/yelp/new_streaming_architecture.png differ diff --git 
a/website/www/site/static/images/logos/powered-by/behalf.png b/website/www/site/static/images/logos/powered-by/behalf.png deleted file mode 100644 index 346ec880d764..000000000000 Binary files a/website/www/site/static/images/logos/powered-by/behalf.png and /dev/null differ