diff --git a/.github/actions/setup-default-test-properties/test-properties.json b/.github/actions/setup-default-test-properties/test-properties.json index 098e4ca1935c..6439492ba5a2 100644 --- a/.github/actions/setup-default-test-properties/test-properties.json +++ b/.github/actions/setup-default-test-properties/test-properties.json @@ -14,7 +14,7 @@ }, "JavaTestProperties": { "SUPPORTED_VERSIONS": ["8", "11", "17", "21"], - "FLINK_VERSIONS": ["1.15", "1.16", "1.17", "1.18"], + "FLINK_VERSIONS": ["1.17", "1.18", "1.19"], "SPARK_VERSIONS": ["2", "3"] }, "GoTestProperties": { diff --git a/.github/trigger_files/IO_Iceberg_Integration_Tests.json b/.github/trigger_files/IO_Iceberg_Integration_Tests.json index 62ae7886c573..bbdc3a3910ef 100644 --- a/.github/trigger_files/IO_Iceberg_Integration_Tests.json +++ b/.github/trigger_files/IO_Iceberg_Integration_Tests.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 4 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json index e3d6056a5de9..b98aece75634 100644 --- a/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Go_VR_Flink.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 1, + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java.json b/.github/trigger_files/beam_PostCommit_Java.json new file mode 100644 index 000000000000..9e26dfeeb6e6 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json index a03c067d2c4e..1efc8e9e4405 100644 --- a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json +++ b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 } diff --git a/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json new file mode 100644 index 000000000000..dd9afb90e638 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Examples_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json index 08c2e40784a9..920c8d132e4a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json +++ b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 } \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json b/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json new file mode 100644 index 000000000000..dd9afb90e638 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Jpms_Flink_Java11.json @@ -0,0 +1,3 @@ +{ + 
"https://github.com/apache/beam/pull/32648": "testing flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json index e3d6056a5de9..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json index b970762c8397..bdd2197e534a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "modification": "1" } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json index e3d6056a5de9..c537844dc84a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json index e3d6056a5de9..f1ba03a243ee 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 5 } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json index b970762c8397..9200c368abbe 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json index b970762c8397..9200c368abbe 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json @@ -1,4 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" } diff --git 
a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json new file mode 100644 index 000000000000..b07a3c47e196 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.json @@ -0,0 +1,4 @@ +{ + + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json index e69de29bb2d1..0b34d452d42c 100644 --- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json index c4edaa85a89d..7663aee09101 100644 --- a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json +++ b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "revision: "1" } diff --git a/.github/trigger_files/beam_PostCommit_XVR_Direct.json b/.github/trigger_files/beam_PostCommit_XVR_Direct.json new file mode 100644 index 000000000000..236b7bee8af8 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Direct.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_XVR_Flink.json b/.github/trigger_files/beam_PostCommit_XVR_Flink.json new file mode 100644 index 000000000000..236b7bee8af8 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Flink.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing Flink 1.19 support" +} diff --git a/.github/workflows/README.md b/.github/workflows/README.md index d386f4dc40f9..971bfd857b27 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -285,6 +285,7 @@ Additional PreCommit jobs running basic SDK unit test on a matrices of operating | [Java Tests](https://github.com/apache/beam/actions/workflows/java_tests.yml) | [![.github/workflows/java_tests.yml](https://github.com/apache/beam/actions/workflows/java_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/java_tests.yml?query=event%3Aschedule) | | [Python Tests](https://github.com/apache/beam/actions/workflows/python_tests.yml) | [![.github/workflows/python_tests.yml](https://github.com/apache/beam/actions/workflows/python_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/python_tests.yml?query=event%3Aschedule) | | [TypeScript Tests](https://github.com/apache/beam/actions/workflows/typescript_tests.yml) | [![.github/workflows/typescript_tests.yml](https://github.com/apache/beam/actions/workflows/typescript_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/typescript_tests.yml?query=event%3Aschedule) | +| [Build Wheels](https://github.com/apache/beam/actions/workflows/build_wheels.yml) | 
[![.github/workflows/build_wheels.yml](https://github.com/apache/beam/actions/workflows/build_wheels.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/build_wheels.yml?query=event%3Aschedule) | ### PostCommit Jobs diff --git a/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml b/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml index 6dea32d4f5a0..d3b6c38ce7ae 100644 --- a/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml +++ b/.github/workflows/beam_LoadTests_Java_GBK_Smoke.yml @@ -106,7 +106,7 @@ jobs: arguments: | --info \ -PloadTest.mainClass=org.apache.beam.sdk.loadtests.GroupByKeyLoadTest \ - -Prunner=:runners:flink:1.15 \ + -Prunner=:runners:flink:1.19 \ '-PloadTest.args=${{ env.beam_LoadTests_Java_GBK_Smoke_test_arguments_3 }}' \ - name: run GroupByKey load test Spark uses: ./.github/actions/gradle-command-self-hosted-action @@ -115,4 +115,4 @@ jobs: arguments: | -PloadTest.mainClass=org.apache.beam.sdk.loadtests.GroupByKeyLoadTest \ -Prunner=:runners:spark:3 \ - '-PloadTest.args=${{ env.beam_LoadTests_Java_GBK_Smoke_test_arguments_4 }}' \ No newline at end of file + '-PloadTest.args=${{ env.beam_LoadTests_Java_GBK_Smoke_test_arguments_4 }}' diff --git a/.github/workflows/playground_backend_precommit.yml b/.github/workflows/beam_Playground_Precommit.yml similarity index 75% rename from .github/workflows/playground_backend_precommit.yml rename to .github/workflows/beam_Playground_Precommit.yml index 9ba6cf20534f..edb50661b1ee 100644 --- a/.github/workflows/playground_backend_precommit.yml +++ b/.github/workflows/beam_Playground_Precommit.yml @@ -17,10 +17,12 @@ name: Playground PreCommit on: workflow_dispatch: - pull_request: + pull_request_target: paths: - .github/workflows/playground_backend_precommit.yml - playground/backend/** + issue_comment: + types: [created] env: DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} @@ -28,17 +30,30 @@ env: GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - precommit_check: - name: precommit-check - runs-on: ubuntu-latest + beam_Playground_PreCommit: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + github.event.comment.body == 'Run Playground PreCommit' + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + runs-on: [self-hosted, ubuntu-20.04, main] + strategy: + fail-fast: false + matrix: + job_name: [beam_Playground_PreCommit] + job_phrase: [Run Playground PreCommit] env: DATASTORE_EMULATOR_VERSION: '423.0.0' PYTHON_VERSION: '3.9' JAVA_VERSION: '11' steps: - - name: Check out the repo - uses: actions/checkout@v4 - + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action with: @@ -58,7 +73,7 @@ jobs: sudo chmod 644 /etc/apt/trusted.gpg.d/scalasbt-release.gpg sudo apt-get update --yes sudo apt-get install sbt --yes - sudo wget https://codeload.github.com/spotify/scio.g8/zip/7c1ba7c1651dfd70976028842e721da4107c0d6d -O scio.g8.zip && unzip scio.g8.zip && mv scio.g8-7c1ba7c1651dfd70976028842e721da4107c0d6d /opt/scio.g8 + sudo wget https://codeload.github.com/spotify/scio.g8/zip/7c1ba7c1651dfd70976028842e721da4107c0d6d -O scio.g8.zip && unzip scio.g8.zip && sudo mv scio.g8-7c1ba7c1651dfd70976028842e721da4107c0d6d /opt/scio.g8 - name: Set up Cloud SDK and its 
components uses: google-github-actions/setup-gcloud@v2 with: diff --git a/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml b/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml index 4077b7be68fe..e42d6a88b8df 100644 --- a/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_Examples_Flink.yml @@ -80,7 +80,7 @@ jobs: - name: run examplesIntegrationTest script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:examplesIntegrationTest + gradle-command: :runners:flink:1.19:examplesIntegrationTest - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} @@ -93,4 +93,4 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' diff --git a/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml b/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml index 58b12b4e4530..ef69a2918196 100644 --- a/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_Nexmark_Flink.yml @@ -102,7 +102,7 @@ jobs: with: gradle-command: :sdks:java:testing:nexmark:run arguments: | - -Pnexmark.runner=:runners:flink:1.15 \ + -Pnexmark.runner=:runners:flink:1.19 \ "${{ env.GRADLE_COMMAND_ARGUMENTS }} --streaming=${{ matrix.streaming }} --queryLanguage=${{ matrix.queryLanguage }}" \ - name: run PostCommit Java Nexmark Flink (${{ matrix.streaming }}) script if: matrix.queryLanguage == 'none' @@ -110,5 +110,5 @@ jobs: with: gradle-command: :sdks:java:testing:nexmark:run arguments: | - -Pnexmark.runner=:runners:flink:1.15 \ - "${{ env.GRADLE_COMMAND_ARGUMENTS }}--streaming=${{ matrix.streaming }}" \ No newline at end of file + -Pnexmark.runner=:runners:flink:1.19 \ + "${{ env.GRADLE_COMMAND_ARGUMENTS }}--streaming=${{ matrix.streaming }}" diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml b/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml index ff3cce441069..987be7789b29 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_PVR_Flink_Streaming.yml @@ -77,7 +77,7 @@ jobs: - name: run PostCommit Java Flink PortableValidatesRunner Streaming script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: runners:flink:1.15:job-server:validatesPortableRunnerStreaming + gradle-command: runners:flink:1.19:job-server:validatesPortableRunnerStreaming - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} @@ -90,4 +90,4 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' diff --git a/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml b/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml index 19329026c034..cf85d9563122 100644 --- a/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml @@ -101,5 +101,5 @@ jobs: with: gradle-command: :sdks:java:testing:tpcds:run arguments: | - -Ptpcds.runner=:runners:flink:1.18 \ + -Ptpcds.runner=:runners:flink:1.19 \ "-Ptpcds.args=${{env.tpcdsBigQueryArgs}} ${{env.tpcdsInfluxDBArgs}} ${{ 
env.GRADLE_COMMAND_ARGUMENTS }} --queries=${{env.tpcdsQueriesArg}}" \ diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml index b6334d8e9858..f79ca8747828 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink.yml @@ -78,7 +78,7 @@ jobs: - name: run validatesRunner script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.18:validatesRunner + gradle-command: :runners:flink:1.19:validatesRunner - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml index 15c99d7bfb37..c51c39987236 100644 --- a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Flink_Java8.yml @@ -80,11 +80,11 @@ jobs: 11 - name: run jar Java8 script run: | - ./gradlew :runners:flink:1.15:jar :runners:flink:1.15:testJar + ./gradlew :runners:flink:1.19:jar :runners:flink:1.19:testJar - name: run validatesRunner Java8 script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:validatesRunner + gradle-command: :runners:flink:1.19:validatesRunner arguments: | -x shadowJar \ -x shadowTestJar \ @@ -109,4 +109,4 @@ jobs: large_files: true commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' diff --git a/.github/workflows/beam_PostCommit_XVR_Flink.yml b/.github/workflows/beam_PostCommit_XVR_Flink.yml index 5cde38d24244..1f1d7d863b7e 100644 --- a/.github/workflows/beam_PostCommit_XVR_Flink.yml +++ b/.github/workflows/beam_PostCommit_XVR_Flink.yml @@ -47,7 +47,7 @@ env: DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} - FlinkVersion: 1.15 + FlinkVersion: 1.19 jobs: beam_PostCommit_XVR_Flink: @@ -110,4 +110,4 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' diff --git a/.github/workflows/beam_PostRelease_NightlySnapshot.yml b/.github/workflows/beam_PostRelease_NightlySnapshot.yml index 9b7fb2af2f2d..0c4144af1a4f 100644 --- a/.github/workflows/beam_PostRelease_NightlySnapshot.yml +++ b/.github/workflows/beam_PostRelease_NightlySnapshot.yml @@ -59,6 +59,9 @@ jobs: uses: ./.github/actions/setup-environment-action with: java-version: default + - name: Setup temp local maven + id: setup_local_maven + run: echo "NEW_TEMP_DIR=$(mktemp -d)" >> $GITHUB_OUTPUT - name: run PostRelease validation script uses: ./.github/actions/gradle-command-self-hosted-action with: @@ -66,3 +69,7 @@ jobs: arguments: | -Pver='${{ github.event.inputs.RELEASE }}' \ -Prepourl='${{ github.event.inputs.SNAPSHOT_URL }}' \ + -PmavenLocalPath='${{ steps.setup_local_maven.outputs.NEW_TEMP_DIR }}' + - name: Clean up local maven + if: steps.setup_local_maven.outcome == 'success' + run: rm -rf '${{ 
steps.setup_local_maven.outputs.NEW_TEMP_DIR }}' diff --git a/.github/workflows/beam_PreCommit_Java.yml b/.github/workflows/beam_PreCommit_Java.yml index 772eab98c343..20dafca72a57 100644 --- a/.github/workflows/beam_PreCommit_Java.yml +++ b/.github/workflows/beam_PreCommit_Java.yml @@ -19,6 +19,7 @@ on: tags: ['v*'] branches: ['master', 'release-*'] paths: + - "buildSrc/**" - 'model/**' - 'sdks/java/**' - 'runners/**' diff --git a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml index 7019e4799ace..b459c4625547 100644 --- a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml +++ b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Batch.yml @@ -94,7 +94,7 @@ jobs: - name: run validatesPortableRunnerBatch script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:job-server:validatesPortableRunnerBatch + gradle-command: :runners:flink:1.19:job-server:validatesPortableRunnerBatch env: CLOUDSDK_CONFIG: ${{ env.KUBELET_GCLOUD_CONFIG_PATH }} - name: Archive JUnit Test Results @@ -108,4 +108,4 @@ jobs: with: name: java-code-coverage-report path: "**/build/test-results/**/*.xml" -# TODO: Investigate 'Max retries exceeded' issue with EnricoMi/publish-unit-test-result-action@v2. \ No newline at end of file +# TODO: Investigate 'Max retries exceeded' issue with EnricoMi/publish-unit-test-result-action@v2. diff --git a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml index 3c16b57c9bce..5feb0270c68c 100644 --- a/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml +++ b/.github/workflows/beam_PreCommit_Java_PVR_Flink_Docker.yml @@ -99,7 +99,7 @@ jobs: - name: run PreCommit Java PVR Flink Docker script uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :runners:flink:1.15:job-server:validatesPortableRunnerDocker + gradle-command: :runners:flink:1.19:job-server:validatesPortableRunnerDocker env: CLOUDSDK_CONFIG: ${{ env.KUBELET_GCLOUD_CONFIG_PATH}} - name: Archive JUnit Test Results @@ -114,4 +114,4 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' diff --git a/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml b/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml new file mode 100644 index 000000000000..ea5cf9b5578e --- /dev/null +++ b/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: PreCommit Java PVR Prism Loopback + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/go/cmd/prism/**' + - 'runners/prism/**' + - 'runners/java-fn-execution/**' + - 'sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/**' + - '.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml' + pull_request_target: + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/go/cmd/prism/**' + - 'runners/prism/**' + - 'runners/java-fn-execution/**' + - 'sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Java_PVR_Prism_Loopback.json' + issue_comment: + types: [created] + schedule: + - cron: '22 2/6 * * *' + workflow_dispatch: + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PreCommit_Java_PVR_Prism_Loopback: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_PreCommit_Java_PVR_Prism_Loopback"] + job_phrase: ["Run Java_PVR_Prism_Loopback PreCommit"] + timeout-minutes: 240 + runs-on: [self-hosted, ubuntu-20.04] + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event_name == 'workflow_dispatch' || + github.event.comment.body == 'Run Java_PVR_Prism_Loopback PreCommit' + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + - name: run prismLoopbackValidatesRunnerTests script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :runners:prism:java:prismLoopbackValidatesRunnerTests + - name: Archive JUnit Test Results + uses: actions/upload-artifact@v4 + if: ${{ !success() }} + with: + name: JUnit Test Results + path: "**/build/reports/tests/" + - name: Upload test report + uses: actions/upload-artifact@v4 + with: + name: java-code-coverage-report + path: "**/build/test-results/**/*.xml" diff --git a/.github/workflows/beam_PreCommit_Prism_Python.yml b/.github/workflows/beam_PreCommit_Prism_Python.yml new file mode 100644 index 000000000000..5eb26d139ef5 --- /dev/null +++ b/.github/workflows/beam_PreCommit_Prism_Python.yml @@ -0,0 +1,109 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: PreCommit Prism Python + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - '.github/workflows/beam_PreCommit_Prism_Python.yml' + pull_request_target: + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Prism_Python.json' + issue_comment: + types: [created] + schedule: + - cron: '30 2/6 * * *' + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PreCommit_Prism_Python: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + timeout-minutes: 120 + runs-on: ['self-hosted', ubuntu-20.04, main] + strategy: + fail-fast: false + matrix: + job_name: ['beam_PreCommit_Prism_Python'] + job_phrase: ['Run Prism_Python PreCommit'] + python_version: ['3.9', '3.12'] + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + startsWith(github.event.comment.body, 'Run Prism_Python PreCommit') + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} ${{ matrix.python_version }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: default + python-version: | + ${{ matrix.python_version }} + 3.9 + - name: Set PY_VER_CLEAN + id: set_py_ver_clean + run: | + PY_VER=${{ matrix.python_version }} + PY_VER_CLEAN=${PY_VER//.} + 
echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT + - name: run Prism Python Validates Runner script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:test-suites:portable:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:prismValidatesRunner \ No newline at end of file diff --git a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml index 7107385c1722..e3791119be90 100644 --- a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml +++ b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml @@ -66,7 +66,6 @@ jobs: - "java:container:java11" - "java:container:java17" - "java:container:java21" - - "python:container:py38" - "python:container:py39" - "python:container:py310" - "python:container:py311" diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 0a15ba9d150c..828a6328c0cd 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -49,7 +49,7 @@ jobs: env: EVENT_NAME: ${{ github.event_name }} # Keep in sync with py_version matrix value below - if changed, change that as well. - PY_VERSIONS_FULL: "cp38-* cp39-* cp310-* cp311-* cp312-*" + PY_VERSIONS_FULL: "cp39-* cp310-* cp311-* cp312-*" outputs: gcp-variables-set: ${{ steps.check_gcp_variables.outputs.gcp-variables-set }} py-versions-full: ${{ steps.set-py-versions.outputs.py-versions-full }} @@ -229,7 +229,7 @@ jobs: {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "aarch64" } ] # Keep in sync (remove asterisks) with PY_VERSIONS_FULL env var above - if changed, change that as well. - py_version: ["cp38-", "cp39-", "cp310-", "cp311-", "cp312-"] + py_version: ["cp39-", "cp310-", "cp311-", "cp312-"] steps: - name: Download python source distribution from artifacts uses: actions/download-artifact@v4.1.8 diff --git a/.github/workflows/self-assign.yml b/.github/workflows/self-assign.yml index 084581db7340..6c2f2219b4e3 100644 --- a/.github/workflows/self-assign.yml +++ b/.github/workflows/self-assign.yml @@ -40,12 +40,16 @@ jobs: repo: context.repo.repo, assignees: [context.payload.comment.user.login] }); - github.rest.issues.removeLabel({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - name: 'awaiting triage' - }); + try { + github.rest.issues.removeLabel({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + name: 'awaiting triage' + }); + } catch (error) { + console.log(`Failed to remove awaiting triage label. It may not exist on this issue. 
Error ${error}`); + } } else if (bodyString == '.close-issue') { console.log('Closing issue'); if (i + 1 < body.length && body[i+1].toLowerCase() == 'not_planned') { diff --git a/.github/workflows/update_python_dependencies.yml b/.github/workflows/update_python_dependencies.yml index a91aff39f29a..0ab52e97b9f0 100644 --- a/.github/workflows/update_python_dependencies.yml +++ b/.github/workflows/update_python_dependencies.yml @@ -56,7 +56,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: python-version: | - 3.8 3.9 3.10 3.11 diff --git a/.test-infra/jenkins/metrics_report/tox.ini b/.test-infra/jenkins/metrics_report/tox.ini index 026db5dc4860..d143a0dcf59c 100644 --- a/.test-infra/jenkins/metrics_report/tox.ini +++ b/.test-infra/jenkins/metrics_report/tox.ini @@ -17,7 +17,7 @@ ; TODO(https://github.com/apache/beam/issues/20209): Don't hardcode Py3.8 version. [tox] skipsdist = True -envlist = py38-test,py38-generate-report +envlist = py39-test,py39-generate-report [testenv] commands_pre = diff --git a/.test-infra/mock-apis/pyproject.toml b/.test-infra/mock-apis/pyproject.toml index 680bf489ba13..c98d9152cfb9 100644 --- a/.test-infra/mock-apis/pyproject.toml +++ b/.test-infra/mock-apis/pyproject.toml @@ -27,7 +27,7 @@ packages = [ ] [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" google = "^3.0.0" grpcio = "^1.53.0" grpcio-tools = "^1.53.0" diff --git a/.test-infra/tools/python_installer.sh b/.test-infra/tools/python_installer.sh index b1b05e597cb3..04e10555243a 100644 --- a/.test-infra/tools/python_installer.sh +++ b/.test-infra/tools/python_installer.sh @@ -20,7 +20,7 @@ set -euo pipefail # Variable containing the python versions to install -python_versions_arr=("3.8.16" "3.9.16" "3.10.10" "3.11.4") +python_versions_arr=("3.9.16" "3.10.10" "3.11.4", "3.12.6") # Install pyenv dependencies. pyenv_dep(){ diff --git a/CHANGES.md b/CHANGES.md index f2b865cec236..261fafc024f3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -60,6 +60,7 @@ * New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). * New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). * [Python] Introduce Managed Transforms API ([#31495](https://github.com/apache/beam/pull/31495)) +* Flink 1.19 support added ([#32648](https://github.com/apache/beam/pull/32648)) ## I/Os @@ -68,12 +69,15 @@ * [Managed Iceberg] Now available in Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) * [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) * BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) +* [Managed Iceberg] Allow updating table partition specs during pipeline runtime ([#32879](https://github.com/apache/beam/pull/32879)) ## New Features / Improvements * Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195)) * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Added support for processing events which use a global sequence to "ordered" extension (Java) [#32540](https://github.com/apache/beam/pull/32540) +* Add new meta-transform FlattenWith and Tee that allow one to introduce branching + without breaking the linear/chaining style of pipeline construction. 
## Breaking Changes @@ -81,12 +85,15 @@ ## Deprecations +* Removed support for Flink 1.15 and 1.16 +* Removed support for Python 3.8 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). ## Bugfixes * Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * (Java) Fixed tearDown not invoked when DoFn throws on Portable Runners ([#18592](https://github.com/apache/beam/issues/18592), [#31381](https://github.com/apache/beam/issues/31381)). +* (Java) Fixed protobuf error with MapState.remove() in Dataflow Streaming Java Legacy Runner without Streaming Engine ([#32892](https://github.com/apache/beam/issues/32892)). ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). @@ -95,7 +102,7 @@ * ([#X](https://github.com/apache/beam/issues/X)). -# [2.60.0] - Unreleased +# [2.60.0] - 2024-10-17 ## Highlights @@ -104,6 +111,10 @@ * [Managed Iceberg] Added auto-sharding for streaming writes ([#32612](https://github.com/apache/beam/pull/32612)) * [Managed Iceberg] Added support for writing to dynamic destinations ([#32565](https://github.com/apache/beam/pull/32565)) +## I/Os + +* PubsubIO can validate that the Pub/Sub topic exists before running the Read/Write pipeline (Java) ([#32465](https://github.com/apache/beam/pull/32465)) + ## New Features / Improvements * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). diff --git a/build.gradle.kts b/build.gradle.kts index 38b58b6979ee..d96e77a4c78c 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -501,23 +501,11 @@ tasks.register("pythonFormatterPreCommit") { dependsOn("sdks:python:test-suites:tox:pycommon:formatter") } -tasks.register("python38PostCommit") { - dependsOn(":sdks:python:test-suites:dataflow:py38:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:hdfsIntegrationTest") - dependsOn(":sdks:python:test-suites:direct:py38:azureIntegrationTest") - dependsOn(":sdks:python:test-suites:portable:py38:postCommitPy38") - // TODO: https://github.com/apache/beam/issues/22651 - // The default container uses Python 3.8. The goal here is to - // duild Docker images for TensorRT tests during run time for python versions - // other than 3.8 and add these tests in other python postcommit suites. - dependsOn(":sdks:python:test-suites:dataflow:py38:inferencePostCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:inferencePostCommitIT") -} - tasks.register("python39PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py39:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py39:postCommitIT") + dependsOn(":sdks:python:test-suites:direct:py39:hdfsIntegrationTest") + dependsOn(":sdks:python:test-suites:direct:py39:azureIntegrationTest") dependsOn(":sdks:python:test-suites:portable:py39:postCommitPy39") // TODO (https://github.com/apache/beam/issues/23966) // Move this to Python 3.10 test suite once tfx-bsl has python 3.10 wheel. @@ -528,6 +516,11 @@ tasks.register("python310PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py310:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py310:postCommitIT") dependsOn(":sdks:python:test-suites:portable:py310:postCommitPy310") + // TODO: https://github.com/apache/beam/issues/22651 + // The default container uses Python 3.10. 
The goal here is to + // duild Docker images for TensorRT tests during run time for python versions + // other than 3.10 and add these tests in other python postcommit suites. + dependsOn(":sdks:python:test-suites:dataflow:py310:inferencePostCommitIT") } tasks.register("python311PostCommit") { diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy index cd46c1270f83..b3949223f074 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy @@ -59,6 +59,7 @@ class BeamDockerPlugin implements Plugin { boolean load = false boolean push = false String builder = null + String target = null File resolvedDockerfile = null File resolvedDockerComposeTemplate = null @@ -130,6 +131,7 @@ class BeamDockerPlugin implements Plugin { group = 'Docker' description = 'Builds Docker image.' dependsOn prepare + environment 'DOCKER_BUILDKIT', '1' }) Task tag = project.tasks.create('dockerTag', { @@ -288,6 +290,9 @@ class BeamDockerPlugin implements Plugin { } else { buildCommandLine.addAll(['-t', "${-> ext.name}", '.']) } + if (ext.target != null && ext.target != "") { + buildCommandLine.addAll '--target', ext.target + } logger.debug("${buildCommandLine}" as String) return buildCommandLine } diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index a7e129211757..5af91ec2f056 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -637,6 +637,7 @@ class BeamModulePlugin implements Plugin { def sbe_tool_version = "1.25.1" def singlestore_jdbc_version = "1.1.4" def slf4j_version = "1.7.30" + def snakeyaml_engine_version = "2.6" def snakeyaml_version = "2.2" def solace_version = "10.21.0" def spark2_version = "2.4.8" @@ -714,7 +715,7 @@ class BeamModulePlugin implements Plugin { cdap_plugin_zendesk : "io.cdap.plugin:zendesk-plugins:1.0.0", checker_qual : "org.checkerframework:checker-qual:$checkerframework_version", classgraph : "io.github.classgraph:classgraph:$classgraph_version", - commons_codec : "commons-codec:commons-codec:1.17.0", + commons_codec : "commons-codec:commons-codec:1.17.1", commons_collections : "commons-collections:commons-collections:3.2.2", commons_compress : "org.apache.commons:commons-compress:1.26.2", commons_csv : "org.apache.commons:commons-csv:1.8", @@ -870,6 +871,7 @@ class BeamModulePlugin implements Plugin { singlestore_jdbc : "com.singlestore:singlestore-jdbc-client:$singlestore_jdbc_version", slf4j_api : "org.slf4j:slf4j-api:$slf4j_version", snake_yaml : "org.yaml:snakeyaml:$snakeyaml_version", + snakeyaml_engine : "org.snakeyaml:snakeyaml-engine:$snakeyaml_engine_version", slf4j_android : "org.slf4j:slf4j-android:$slf4j_version", slf4j_ext : "org.slf4j:slf4j-ext:$slf4j_version", slf4j_jdk14 : "org.slf4j:slf4j-jdk14:$slf4j_version", @@ -907,7 +909,7 @@ class BeamModulePlugin implements Plugin { testcontainers_solace : "org.testcontainers:solace:$testcontainers_version", truth : "com.google.truth:truth:1.1.5", threetenbp : "org.threeten:threetenbp:1.6.8", - vendored_grpc_1_60_1 : "org.apache.beam:beam-vendor-grpc-1_60_1:0.2", + vendored_grpc_1_60_1 : "org.apache.beam:beam-vendor-grpc-1_60_1:0.3", vendored_guava_32_1_2_jre : 
"org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1", vendored_calcite_1_28_0 : "org.apache.beam:beam-vendor-calcite-1_28_0:0.2", woodstox_core_asl : "org.codehaus.woodstox:woodstox-core-asl:4.4.1", @@ -2511,6 +2513,8 @@ class BeamModulePlugin implements Plugin { def taskName = "run${config.type}Java${config.runner}" def releaseVersion = project.findProperty('ver') ?: project.version def releaseRepo = project.findProperty('repourl') ?: 'https://repository.apache.org/content/repositories/snapshots' + // shared maven local path for maven archetype projects + def sharedMavenLocal = project.findProperty('mavenLocalPath') ?: '' def argsNeeded = [ "--ver=${releaseVersion}", "--repourl=${releaseRepo}" @@ -2530,6 +2534,9 @@ class BeamModulePlugin implements Plugin { if (config.pubsubTopic) { argsNeeded.add("--pubsubTopic=${config.pubsubTopic}") } + if (sharedMavenLocal) { + argsNeeded.add("--mavenLocalPath=${sharedMavenLocal}") + } project.evaluationDependsOn(':release') project.task(taskName, dependsOn: ':release:classes', type: JavaExec) { group = "Verification" @@ -3145,7 +3152,6 @@ class BeamModulePlugin implements Plugin { mustRunAfter = [ ":runners:flink:${project.ext.latestFlinkVersion}:job-server:shadowJar", ':runners:spark:3:job-server:shadowJar', - ':sdks:python:container:py38:docker', ':sdks:python:container:py39:docker', ':sdks:python:container:py310:docker', ':sdks:python:container:py311:docker', diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy index b2c7053dfb60..62733efb507c 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/GrpcVendoring_1_60_1.groovy @@ -189,6 +189,9 @@ class GrpcVendoring_1_60_1 { "org/junit/**", "org/mockito/**", "org/objenesis/**", + // proto source files + "google/**/*.proto", + "grpc/**/*.proto", ] } diff --git a/contributor-docs/code-change-guide.md b/contributor-docs/code-change-guide.md index f0785d3509d0..b4300103454c 100644 --- a/contributor-docs/code-change-guide.md +++ b/contributor-docs/code-change-guide.md @@ -145,7 +145,7 @@ in the Google Cloud documentation. Depending on the languages involved, your `PATH` file needs to have the following elements configured. -* A Java environment that uses a supported Java version, preferably Java 8. +* A Java environment that uses a supported Java version, preferably Java 11. * This environment is needed for all development, because Beam is a Gradle project that uses JVM. * Recommended: To manage Java versions, use [sdkman](https://sdkman.io/install). @@ -624,6 +624,11 @@ Tips for using the Dataflow runner: ## Appendix +### Common Issues + +* If you run into some strange errors such as `java.lang.NoClassDefFoundError`, run `./gradlew clean` first +* To run one single Java test with gradle, use `--tests` to filter, for example, `./gradlew :it:google-cloud-platform:WordCountIntegrationTest --tests "org.apache.beam.it.gcp.WordCountIT.testWordCountDataflow"` + ### Directories of snapshot builds * https://repository.apache.org/content/groups/snapshots/org/apache/beam/ Java SDK build (nightly) diff --git a/contributor-docs/discussion-docs/2024.md b/contributor-docs/discussion-docs/2024.md index baea7c9fc462..124fe8ef9bb7 100644 --- a/contributor-docs/discussion-docs/2024.md +++ b/contributor-docs/discussion-docs/2024.md @@ -35,11 +35,12 @@ limitations under the License. 
| 18 | Danny McCormick | [GSoC Proposal : Implement RAG Pipelines using Beam](https://docs.google.com/document/d/1M_8fvqKVBi68hQo_x1AMQ8iEkzeXTcSl0CwTH00cr80) | 2024-05-01 16:12:23 | | 19 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - June 2024](https://s.apache.org/beam-draft-report-2024-06) | 2024-05-23 14:57:16 | | 20 | Jack McCluskey | [Embeddings in MLTransform](https://docs.google.com/document/d/1En4bfbTu4rvu7LWJIKV3G33jO-xJfTdbaSFSURmQw_s) | 2024-05-29 10:26:47 | -| 21 | Bartosz Zab��ocki | [[External] Solace IO - Read Connector](https://docs.google.com/document/d/1Gvq67VrcHCnlO8f_NzMM1Y4c7wCNSdvo6qqLWg8upfw) | 2024-05-29 12:00:23 | +| 21 | Bartosz Zabłocki | [[External] Solace IO - Read Connector](https://docs.google.com/document/d/1Gvq67VrcHCnlO8f_NzMM1Y4c7wCNSdvo6qqLWg8upfw) | 2024-05-29 12:00:23 | | 22 | Danny McCormick | [RunInference Timeouts](https://docs.google.com/document/d/19ves6iv-m_6DFmePJZqYpLm-bCooPu6wQ-Ti6kAl2Jo) | 2024-08-07 07:11:38 | | 23 | Jack McCluskey | [BatchElements in Beam Python](https://docs.google.com/document/d/1fOjIjIUH5dxllOGp5Z4ZmpM7BJhAJc2-hNjTnyChvgc) | 2024-08-15 14:56:26 | | 24 | XQ Hu | [[Public] Beam 3.0: a discussion doc](https://docs.google.com/document/d/13r4NvuvFdysqjCTzMHLuUUXjKTIEY3d7oDNIHT6guww) | 2024-08-19 17:17:26 | | 25 | Danny McCormick | [Beam Patch Release Process](https://docs.google.com/document/d/1o4UK444hCm1t5KZ9ufEu33e_o400ONAehXUR9A34qc8) | 2024-08-23 04:51:48 | | 26 | Jack McCluskey | [Beam Python Type Hinting](https://s.apache.org/beam-python-type-hinting-overview) | 2024-08-26 14:16:42 | | 27 | Ahmed Abualsaud | [Python Multi-language with SchemaTransforms](https://docs.google.com/document/d/1_embA3pGwoYG7sbHaYzAkg3hNxjTughhFCY8ThcoK_Q) | 2024-08-26 19:53:10 | -| 28 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | \ No newline at end of file +| 28 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | +| 29 | Jeff Kinard | [Beam YA(ML)^2](https://docs.google.com/document/d/1z9lNlSBfqDVdOP1frJNv_NJoMR1F1VBI29wn788x6IE/) | 2024-09-11 15:01:55 | diff --git a/contributor-docs/release-guide.md b/contributor-docs/release-guide.md index 4fe35aa4aac2..df7f45cc8179 100644 --- a/contributor-docs/release-guide.md +++ b/contributor-docs/release-guide.md @@ -507,7 +507,7 @@ with tags: `${RELEASE_VERSION}rc${RC_NUM}` Verify that third party licenses are included in Docker. You can do this with a simple script: RC_TAG=${RELEASE_VERSION}rc${RC_NUM} - for pyver in 3.8 3.9 3.10 3.11; do + for pyver in 3.9 3.10 3.11 3.12; do docker run --rm --entrypoint sh \ apache/beam_python${pyver}_sdk:${RC_TAG} \ -c 'ls -al /opt/apache/beam/third_party_licenses/ | wc -l' @@ -554,10 +554,10 @@ to PyPI with an `rc` suffix. 
__Attention:__ Verify that: - [ ] The File names version include ``rc-#`` suffix -- [ ] [Download Files](https://pypi.org/project/apache-beam/#files) have: - - [ ] All wheels uploaded as artifacts - - [ ] Release source's zip published - - [ ] Signatures and hashes do not need to be uploaded +- [Download Files](https://pypi.org/project/apache-beam/#files) have: +- [ ] All wheels uploaded as artifacts +- [ ] Release source's zip published +- [ ] Signatures and hashes do not need to be uploaded ### Propose pull requests for website updates @@ -887,7 +887,7 @@ write to BigQuery, and create a cluster of machines for running containers (for ``` **Flink Local Runner** ``` - ./gradlew :runners:flink:1.18:runQuickstartJavaFlinkLocal \ + ./gradlew :runners:flink:1.19:runQuickstartJavaFlinkLocal \ -Prepourl=https://repository.apache.org/content/repositories/orgapachebeam-${KEY} \ -Pver=${RELEASE_VERSION} ``` @@ -1148,7 +1148,7 @@ All wheels should be published, in addition to the zip of the release source. ### Merge Website pull requests Merge all of the website pull requests -- [listing the release](/get-started/downloads/) +- [listing the release](https://beam.apache.org/get-started/downloads/) - publishing the [Python API reference manual](https://beam.apache.org/releases/pydoc/) and the [Java API reference manual](https://beam.apache.org/releases/javadoc/), and - adding the release blog post. diff --git a/examples/java/build.gradle b/examples/java/build.gradle index af91fa83fe91..4f1902cf1679 100644 --- a/examples/java/build.gradle +++ b/examples/java/build.gradle @@ -66,6 +66,8 @@ dependencies { implementation project(":sdks:java:extensions:python") implementation project(":sdks:java:io:google-cloud-platform") implementation project(":sdks:java:io:kafka") + runtimeOnly project(":sdks:java:io:iceberg") + implementation project(":sdks:java:managed") implementation project(":sdks:java:extensions:ml") implementation library.java.avro implementation library.java.bigdataoss_util @@ -100,6 +102,8 @@ dependencies { implementation "org.apache.httpcomponents:httpcore:4.4.13" implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.1" implementation "com.fasterxml.jackson.core:jackson-core:2.14.1" + runtimeOnly library.java.hadoop_client + runtimeOnly library.java.bigdataoss_gcs_connector testImplementation project(path: ":runners:direct-java", configuration: "shadow") testImplementation project(":sdks:java:io:google-cloud-platform") testImplementation project(":sdks:java:extensions:ml") diff --git a/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java b/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java new file mode 100644 index 000000000000..446d11d03be4 --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.examples.cookbook; + +import java.util.Arrays; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.transforms.Filter; +import org.apache.beam.sdk.transforms.JsonToRow; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +/** + * Reads real-time NYC taxi ride information from {@code + * projects/pubsub-public-data/topics/taxirides-realtime} and writes to Iceberg tables using Beam's + * {@link Managed} IcebergIO sink. + * + *
<p>
This is a streaming pipeline that writes records to Iceberg tables dynamically, depending on + * each record's passenger count. New tables are created as needed. We set a triggering frequency of + * 10s; at around this interval, the sink will accumulate records and write them to the appropriate + * table, creating a new snapshot each time. + */ +public class IcebergTaxiExamples { + private static final String TAXI_RIDES_TOPIC = + "projects/pubsub-public-data/topics/taxirides-realtime"; + private static final Schema TAXI_RIDE_INFO_SCHEMA = + Schema.builder() + .addStringField("ride_id") + .addInt32Field("point_idx") + .addDoubleField("latitude") + .addDoubleField("longitude") + .addStringField("timestamp") + .addDoubleField("meter_reading") + .addDoubleField("meter_increment") + .addStringField("ride_status") + .addInt32Field("passenger_count") + .build(); + + public static void main(String[] args) { + IcebergPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).as(IcebergPipelineOptions.class); + options.setProject("apache-beam-testing"); + + // each record's 'passenger_count' value will be substituted in to determine + // its final table destination + // e.g. an event with 3 passengers will be written to 'iceberg_taxi.3_passengers' + String tableIdentifierTemplate = "iceberg_taxi.{passenger_count}_passengers"; + + Map catalogProps = + ImmutableMap.builder() + .put("catalog-impl", options.getCatalogImpl()) + .put("warehouse", options.getWarehouse()) + .build(); + Map icebergWriteConfig = + ImmutableMap.builder() + .put("table", tableIdentifierTemplate) + .put("catalog_name", options.getCatalogName()) + .put("catalog_properties", catalogProps) + .put("triggering_frequency_seconds", 10) + // perform a final filter to only write these two columns + .put("keep", Arrays.asList("ride_id", "meter_reading")) + .build(); + + Pipeline p = Pipeline.create(options); + p + // Read taxi ride data + .apply(PubsubIO.readStrings().fromTopic(TAXI_RIDES_TOPIC)) + // Convert JSON strings to Beam Rows + .apply(JsonToRow.withSchema(TAXI_RIDE_INFO_SCHEMA)) + // Filter to only include drop-offs + .apply(Filter.create().whereFieldName("ride_status", "dropoff"::equals)) + // Write to Iceberg tables + .apply(Managed.write(Managed.ICEBERG).withConfig(icebergWriteConfig)); + p.run(); + } + + public interface IcebergPipelineOptions extends GcpOptions { + @Description("Warehouse location where the table's data will be written to.") + @Default.String("gs://apache-beam-samples/iceberg-examples") + String getWarehouse(); + + void setWarehouse(String warehouse); + + @Description("Fully-qualified name of the catalog class to use.") + @Default.String("org.apache.iceberg.hadoop.HadoopCatalog") + String getCatalogImpl(); + + void setCatalogImpl(String catalogName); + + @Validation.Required + @Default.String("example-catalog") + String getCatalogName(); + + void setCatalogName(String catalogName); + } +} diff --git a/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java b/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java index 6b9916e87271..43b06347a39a 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java @@ -25,7 +25,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import 
org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.MapElements; @@ -33,7 +32,6 @@ import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -66,7 +64,6 @@ public void testExtractWordsFn() throws Exception { /** Example test that tests a PTransform by using an in-memory input and inspecting the output. */ @Test - @Category(ValidatesRunner.class) public void testCountWords() throws Exception { PCollection input = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java index 96d10a1f72ed..614d289d2d60 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java @@ -21,12 +21,10 @@ import java.util.Arrays; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,7 +35,6 @@ public class TopWikipediaSessionsTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testComputeTopUsers() { PCollection output = diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java index 33d3c5699477..9c99c3aafdcc 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java @@ -24,13 +24,11 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -69,7 +67,6 @@ public class GameStatsTest implements Serializable { /** Test the calculation of 'spammy users'. 
*/ @Test - @Category(ValidatesRunner.class) public void testCalculateSpammyUsers() throws Exception { PCollection> input = p.apply(Create.of(USER_SCORES)); PCollection> output = input.apply(new CalculateSpammyUsers()); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java index 1d89351adcf8..46d7b41746ab 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java @@ -26,7 +26,6 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.MapElements; @@ -37,7 +36,6 @@ import org.joda.time.Instant; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -90,7 +88,6 @@ public class HourlyTeamScoreTest implements Serializable { /** Test the filtering. */ @Test - @Category(ValidatesRunner.class) public void testUserScoresFilter() throws Exception { final Instant startMinTimestamp = new Instant(1447965680000L); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java index 04aa122054bd..22fe98a50304 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java @@ -26,7 +26,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; @@ -36,7 +35,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -114,7 +112,6 @@ public void testParseEventFn() throws Exception { /** Tests ExtractAndSumScore("user"). */ @Test - @Category(ValidatesRunner.class) public void testUserScoreSums() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS)); @@ -133,7 +130,6 @@ public void testUserScoreSums() throws Exception { /** Tests ExtractAndSumScore("team"). */ @Test - @Category(ValidatesRunner.class) public void testTeamScoreSums() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS)); @@ -152,7 +148,6 @@ public void testTeamScoreSums() throws Exception { /** Test that bad input data is dropped appropriately. 
*/ @Test - @Category(ValidatesRunner.class) public void testUserScoresBadInput() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS2).withCoder(StringUtf8Coder.of())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java index 2bd37a3caa52..110349c353e8 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java @@ -22,7 +22,6 @@ import org.apache.beam.examples.cookbook.BigQueryTornadoes.FormatCountsFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,7 +39,6 @@ public class BigQueryTornadoesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTornadoes() { TableRow row = new TableRow().set("month", "6").set("tornado", true); PCollection input = p.apply(Create.of(ImmutableList.of(row))); @@ -51,7 +48,6 @@ public void testExtractTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testNoTornadoes() { TableRow row = new TableRow().set("month", 6).set("tornado", false); PCollection inputs = p.apply(Create.of(ImmutableList.of(row))); @@ -61,7 +57,6 @@ public void testNoTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testEmpty() { PCollection> inputs = p.apply(Create.empty(new TypeDescriptor>() {})); @@ -71,7 +66,6 @@ public void testEmpty() { } @Test - @Category(ValidatesRunner.class) public void testFormatCounts() { PCollection> inputs = p.apply(Create.of(KV.of(3, 0L), KV.of(4, Long.MAX_VALUE), KV.of(5, Long.MIN_VALUE))); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java index 988a492ad4a9..7ec889f0d2b4 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java @@ -22,13 +22,11 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Distinct; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -39,7 +37,6 @@ public class DistinctExampleTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testDistinct() { List strings = Arrays.asList("k1", "k5", "k5", "k2", "k1", "k2", "k3"); @@ -52,7 +49,6 @@ public void testDistinct() { } @Test - @Category(ValidatesRunner.class) public void testDistinctEmpty() 
{ List strings = Arrays.asList(); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java index dedc0e313350..4eeb5c4b7dd0 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java @@ -22,13 +22,11 @@ import org.apache.beam.examples.cookbook.FilterExamples.ProjectionFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -68,7 +66,6 @@ public class FilterExamplesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testProjectionFn() { PCollection input = p.apply(Create.of(row1, row2, row3)); @@ -79,7 +76,6 @@ public void testProjectionFn() { } @Test - @Category(ValidatesRunner.class) public void testFilterSingleMonthDataFn() { PCollection input = p.apply(Create.of(outRow1, outRow2, outRow3)); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java index b91ca985ddcd..d27572667752 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java @@ -24,14 +24,12 @@ import org.apache.beam.examples.cookbook.JoinExamples.ExtractEventDataFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -107,7 +105,6 @@ public void testExtractCountryInfoFn() throws Exception { } @Test - @Category(ValidatesRunner.class) public void testJoin() throws java.lang.Exception { PCollection input1 = p.apply("CreateEvent", Create.of(EVENT_ARRAY)); PCollection input2 = p.apply("CreateCC", Create.of(CC_ARRAY)); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java index 5c32f36660d6..410b151ed32f 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java @@ -23,7 +23,6 @@ import org.apache.beam.examples.cookbook.MaxPerKeyExamples.FormatMaxesFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -78,7 +76,6 @@ public class MaxPerKeyExamplesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTempFn() { PCollection> results = p.apply(Create.of(TEST_ROWS)).apply(ParDo.of(new ExtractTempFn())); @@ -87,7 +84,6 @@ public void testExtractTempFn() { } @Test - @Category(ValidatesRunner.class) public void testFormatMaxesFn() { PCollection results = p.apply(Create.of(TEST_KVS)).apply(ParDo.of(new FormatMaxesFn())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java index 7e922bc87965..fb08730a9f54 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java @@ -22,7 +22,6 @@ import org.apache.beam.examples.cookbook.MinimalBigQueryTornadoes.FormatCountsFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,7 +39,6 @@ public class MinimalBigQueryTornadoesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTornadoes() { TableRow row = new TableRow().set("month", "6").set("tornado", true); PCollection input = p.apply(Create.of(ImmutableList.of(row))); @@ -51,7 +48,6 @@ public void testExtractTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testNoTornadoes() { TableRow row = new TableRow().set("month", 6).set("tornado", false); PCollection inputs = p.apply(Create.of(ImmutableList.of(row))); @@ -61,7 +57,6 @@ public void testNoTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testEmpty() { PCollection> inputs = p.apply(Create.empty(new TypeDescriptor>() {})); @@ -71,7 +66,6 @@ public void testEmpty() { } @Test - @Category(ValidatesRunner.class) public void testFormatCounts() { PCollection> inputs = p.apply(Create.of(KV.of(3, 0L), KV.of(4, Long.MAX_VALUE), KV.of(5, Long.MIN_VALUE))); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java index 8f076a9d8d89..19c83c6eb73c 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java @@ -27,7 +27,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import 
org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; @@ -42,7 +41,6 @@ import org.joda.time.Instant; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -118,7 +116,6 @@ public void testExtractTotalFlow() { } @Test - @Category(ValidatesRunner.class) public void testTotalFlow() { PCollection> flow = pipeline diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt index 56702a3a1746..514727878a44 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt @@ -39,7 +39,6 @@ class DistinctExampleTest { fun pipeline(): TestPipeline = pipeline @Test - @Category(ValidatesRunner::class) fun testDistinct() { val strings = listOf("k1", "k5", "k5", "k2", "k1", "k2", "k3") val input = pipeline.apply(Create.of(strings).withCoder(StringUtf8Coder.of())) @@ -49,7 +48,6 @@ class DistinctExampleTest { } @Test - @Category(ValidatesRunner::class) fun testDistinctEmpty() { val strings = listOf() val input = pipeline.apply(Create.of(strings).withCoder(StringUtf8Coder.of())) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt index d5cb544a7606..8cfffe15fc65 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt @@ -64,7 +64,6 @@ class FilterExamplesTest { fun pipeline(): TestPipeline = pipeline @Test - @Category(ValidatesRunner::class) fun testProjectionFn() { val input = pipeline.apply(Create.of(row1, row2, row3)) val results = input.apply(ParDo.of(ProjectionFn())) @@ -73,7 +72,6 @@ class FilterExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testFilterSingleMonthDataFn() { val input = pipeline.apply(Create.of(outRow1, outRow2, outRow3)) val results = input.apply(ParDo.of(FilterSingleMonthDataFn(7))) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt index 8728a827229b..6bb818f5efa9 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt @@ -93,7 +93,6 @@ class JoinExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testJoin() { val input1 = pipeline.apply("CreateEvent", Create.of(EVENT_ARRAY)) val input2 = pipeline.apply("CreateCC", Create.of(CC_ARRAY)) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt index 7995d9c1c795..409434e02686 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt @@ -69,7 +69,6 @@ class MaxPerKeyExamplesTest { fun pipeline(): TestPipeline = 
pipeline @Test - @Category(ValidatesRunner::class) fun testExtractTempFn() { val results = pipeline.apply(Create.of(testRows)).apply(ParDo.of>(MaxPerKeyExamples.ExtractTempFn())) PAssert.that(results).containsInAnyOrder(ImmutableList.of(kv1, kv2, kv3)) @@ -77,7 +76,6 @@ class MaxPerKeyExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testFormatMaxesFn() { val results = pipeline.apply(Create.of(testKvs)).apply(ParDo.of, TableRow>(MaxPerKeyExamples.FormatMaxesFn())) PAssert.that(results).containsInAnyOrder(resultRow1, resultRow2, resultRow3) diff --git a/bigquery_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb similarity index 98% rename from bigquery_enrichment_transform.ipynb rename to examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb index 331ecb9ba93d..182b88b9c72a 100644 --- a/bigquery_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb @@ -15,16 +15,6 @@ } }, "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, { "cell_type": "code", "source": [ @@ -573,7 +563,7 @@ "In this example, you create two handlers:\n", "\n", "* One for customer data that specifies `table_name` and `row_restriction_template`\n", - "* One for for usage data that uses a custom aggregation query by using the `query_fn` function\n", + "* One for usage data that uses a custom aggregation query by using the `query_fn` function\n", "\n", "These handlers are used in the Enrichment transforms in this pipeline to fetch and join data from BigQuery with the streaming data." ], @@ -778,4 +768,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb index d7b2b157f613..686c19da7f66 100644 --- a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb +++ b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb @@ -174,7 +174,8 @@ "WORKDIR /workspace\n", "\n", "COPY gemma2 gemma2\n", - "RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim" + "RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim\n", + "```" ] }, { @@ -208,7 +209,8 @@ "apache_beam[gcp]==2.54.0\n", "keras_nlp==0.14.3\n", "keras==3.4.1\n", - "jax[cuda12]" + "jax[cuda12]\n", + "```" ] }, { @@ -261,7 +263,8 @@ "\n", "\n", "# Set the entrypoint to the Apache Beam SDK launcher.\n", - "ENTRYPOINT [\"/opt/apache/beam/boot\"]" + "ENTRYPOINT [\"/opt/apache/beam/boot\"]\n", + "```" ] }, { diff --git a/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb b/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb index cc31ff678fe4..aae86e31aa44 100644 --- a/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb +++ b/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb @@ -209,7 +209,7 @@ "\n", "3. Create the index.\n", "\n", - "4. Index creation is neeeded only once." + "4. Index creation is needed only once." 
] }, { diff --git a/examples/notebooks/beam-ml/run_inference_vllm.ipynb b/examples/notebooks/beam-ml/run_inference_vllm.ipynb new file mode 100644 index 000000000000..e9f1e53a452b --- /dev/null +++ b/examples/notebooks/beam-ml/run_inference_vllm.ipynb @@ -0,0 +1,616 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License" + ], + "metadata": { + "id": "OsFaZscKSPvo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Run ML inference by using vLLM on GPUs\n", + "\n", + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ], + "metadata": { + "id": "NrHRIznKp3nS" + } + }, + { + "cell_type": "markdown", + "source": [ + "[vLLM](https://github.com/vllm-project/vllm) is a fast and user-frienly library for LLM inference and serving. vLLM optimizes LLM inference with mechanisms like PagedAttention for memory management and continuous batching for increasing throughput. For popular models, vLLM has been shown to increase throughput by a multiple of 2 to 4. With Apache Beam, you can serve models with vLLM and scale that serving with just a few lines of code.\n", + "\n", + "This notebook demonstrates how to run machine learning inference by using vLLM and GPUs in three ways:\n", + "\n", + "* locally without Apache Beam\n", + "* locally with the Apache Beam local runner\n", + "* remotely with the Dataflow runner\n", + "\n", + "It also shows how to swap in a different model without modifying your pipeline structure by changing the configuration." + ], + "metadata": { + "id": "H0ZFs9rDvtJm" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Requirements\n", + "\n", + "This notebook assumes that a GPU is enabled in Colab. If this setting isn't enabled, the locally executed sections of this notebook might not work. To enable a GPU, in the Colab menu, click **Runtime** > **Change runtime type**. For **Hardware accelerator**, choose a GPU accelerator. If you can't access a GPU in Colab, you can run the Dataflow section of this notebook.\n", + "\n", + "To run the Dataflow section, you need access to the following resources:\n", + "\n", + "- a computer with Docker installed\n", + "- a [Google Cloud](https://cloud.google.com/) account" + ], + "metadata": { + "id": "6x41tnbTvQM1" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Install dependencies\n", + "\n", + "Before creating your pipeline, download and install the dependencies required to develop with Apache Beam and vLLM. vLLM is supported in Apache Beam versions 2.60.0 and later." + ], + "metadata": { + "id": "8PSjyDIavRcn" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "irCKNe42p22r" + }, + "outputs": [], + "source": [ + "!pip install openai>=1.52.2\n", + "!pip install vllm>=0.6.3\n", + "!pip install apache-beam[gcp]==2.60.0\n", + "!pip check" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Run locally without Apache Beam\n", + "\n", + "In this section, you run a vLLM server without using Apache Beam. Use the `facebook/opt-125m` model. This model is small enough to fit in Colab memory and doesn't require any extra authentication.\n", + "\n", + "First, start the vLLM server. This step might take a minute or two, because the model needs to download before vLLM starts running inference." + ], + "metadata": { + "id": "3xz8zuA7vcS4" + } + }, + { + "cell_type": "code", + "source": [ + "! python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m" + ], + "metadata": { + "id": "GbJGzINNt5sG" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, while the vLLM server is running, open a separate terminal to communicate with the vLLM serving process. To open a terminal in Colab, in the sidebar, click **Terminal**. 
In the terminal, run the following commands.\n", + "\n", + "```\n", + "pip install openai\n", + "python\n", + "\n", + "from openai import OpenAI\n", + "\n", + "# Modify OpenAI's API key and API base to use vLLM's API server.\n", + "openai_api_key = \"EMPTY\"\n", + "openai_api_base = \"http://localhost:8000/v1\"\n", + "client = OpenAI(\n", + " api_key=openai_api_key,\n", + " base_url=openai_api_base,\n", + ")\n", + "completion = client.completions.create(model=\"facebook/opt-125m\",\n", + " prompt=\"San Francisco is a\")\n", + "print(\"Completion result:\", completion)\n", + "```\n", + "\n", + "This code runs against the server running in the cell. You can experiment with different prompts." + ], + "metadata": { + "id": "n35LXTS3uzIC" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Run locally with Apache Beam\n", + "\n", + "In this section, you set up an Apache Beam pipeline to run a job with an embedded vLLM instance.\n", + "\n", + "First, define the `VllmCompletionsModelHandler` object. This configuration object gives Apache Beam the information that it needs to create a dedicated vLLM process in the middle of the pipeline. Apache Beam then provides examples to the pipeline. No additional code is needed." + ], + "metadata": { + "id": "Hbxi83BfwbBa" + } + }, + { + "cell_type": "code", + "source": [ + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n", + "from apache_beam.ml.inference.base import PredictionResult\n", + "import apache_beam as beam\n", + "\n", + "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')" + ], + "metadata": { + "id": "sUqjOzw3wpI4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, define examples to run inference against, and define a helper function to print out the inference results." + ], + "metadata": { + "id": "N06lXRKRxCz5" + } + }, + { + "cell_type": "code", + "source": [ + "class FormatOutput(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n", + "\n", + "prompts = [\n", + " \"Hello, my name is\",\n", + " \"The president of the United States is\",\n", + " \"The capital of France is\",\n", + " \"The future of AI is\",\n", + " \"Emperor penguins are\",\n", + "]" + ], + "metadata": { + "id": "3a1PznmtxNR_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Finally, run the pipeline.\n", + "\n", + "This step might take a minute or two, because the model needs to download before Apache Beam can start running inference." + ], + "metadata": { + "id": "Njl0QfrLxQ0m" + } + }, + { + "cell_type": "code", + "source": [ + "with beam.Pipeline() as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(print) # Print the formatted output.\n", + " )" + ], + "metadata": { + "id": "9yXbzV0ZmZcJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run remotely on Dataflow\n", + "\n", + "After you validate that the pipeline can run against a vLLM locally, you can productionalize the workflow on a remote runner. This notebook runs the pipeline on the Dataflow runner." 
+ ], + "metadata": { + "id": "Jv7be6Pk9Hlx" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Build a Docker image\n", + "\n", + "To run a pipeline with vLLM on Dataflow, you must create a Docker image that contains your dependencies and is compatible with a GPU runtime. For more information about building GPU-compatible Dataflow containers, see [Build a custom container image](https://cloud.google.com/dataflow/docs/gpu/use-gpus#custom-container) in the Dataflow documentation.\n", + "\n", + "First, define and save your Dockerfile. This file uses an Nvidia GPU-compatible base image. In the Dockerfile, install the Python dependencies needed to run the job.\n", + "\n", + "Before proceeding, make sure that your configuration meets the following requirements:\n", + "\n", + "- The Python version in the following cell matches the Python version defined in the Dockerfile.\n", + "- The Apache Beam version defined in your dependencies matches the Apache Beam version defined in the Dockerfile." + ], + "metadata": { + "id": "J1LMrl1Yy6QB" + } + }, + { + "cell_type": "code", + "source": [ + "!python --version" + ], + "metadata": { + "id": "jCQ6-D55gqfl" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "cell_str='''\n", + "FROM nvidia/cuda:12.4.1-devel-ubuntu22.04\n", + "\n", + "RUN apt update\n", + "RUN apt install software-properties-common -y\n", + "RUN add-apt-repository ppa:deadsnakes/ppa\n", + "RUN apt update\n", + "RUN apt-get update\n", + "\n", + "ARG DEBIAN_FRONTEND=noninteractive\n", + "\n", + "RUN apt install python3.10-full -y\n", + "# RUN apt install python3.10-venv -y\n", + "# RUN apt install python3.10-dev -y\n", + "RUN rm /usr/bin/python3\n", + "RUN ln -s python3.10 /usr/bin/python3\n", + "RUN python3 --version\n", + "RUN apt-get install -y curl\n", + "RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && pip install --upgrade pip\n", + "\n", + "# Copy the Apache Beam worker dependencies from the Beam Python 3.10 SDK image.\n", + "COPY --from=apache/beam_python3.10_sdk:2.60.0 /opt/apache/beam /opt/apache/beam\n", + "\n", + "RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.60.0\n", + "RUN pip install openai>=1.52.2 vllm>=0.6.3\n", + "\n", + "RUN apt install libcairo2-dev pkg-config python3-dev -y\n", + "RUN pip install pycairo\n", + "\n", + "# Set the entrypoint to Apache Beam SDK worker launcher.\n", + "ENTRYPOINT [ \"/opt/apache/beam/boot\" ]\n", + "'''\n", + "\n", + "with open('VllmDockerfile', 'w') as f:\n", + " f.write(cell_str)" + ], + "metadata": { + "id": "7QyNq_gygHLO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "After you save the Dockerfile, build and push your Docker image. Because Docker is not accessible from Colab, you need to complete this step in a separate environment.\n", + "\n", + "1. In the sidebar, click **Files** to open the **Files** pane.\n", + "2. In an environment with Docker installed, download the **VllmDockerfile** file to an empty folder.\n", + "3. Run the following commands. 
Replace `:` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository and tag.\n", + "\n", + " ```\n", + " docker build -t \":\" -f VllmDockerfile ./\n", + " docker image push \":\"\n", + " ```" + ], + "metadata": { + "id": "zWma0YetiEn5" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Define and run the pipeline\n", + "\n", + "When you have a working Docker image, define and run your pipeline.\n", + "\n", + "First, define the pipeline options that you want to use to launch the Dataflow job. Before running the next cell, replace the following variables:\n", + "\n", + "- ``: the name of a valid [Google Cloud Storage](https://cloud.google.com/storage?e=48754805&hl=en) bucket. Don't include a `gs://` prefix or trailing slashes.\n", + "- ``: the name of the Google Artifact Registry repository that you used in the previous step.\n", + "- ``: the image tag used in the previous step. Prefer a versioned tag or SHA instead of the :latest tag or other mutable tags.\n", + "- ``: the name of the Google Cloud project that you created your bucket and Artifact Registry repository in.\n", + "\n", + "This workflow uses the following Dataflow service option: `worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx`. When you use this service option, Dataflow installs a T4 GPU that uses a `5xx` series Nvidia driver on each worker machine. The 5xx driver is required to run vLLM jobs." + ], + "metadata": { + "id": "NjZyRjte0g0Q" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "from apache_beam.options.pipeline_options import GoogleCloudOptions\n", + "from apache_beam.options.pipeline_options import PipelineOptions\n", + "from apache_beam.options.pipeline_options import SetupOptions\n", + "from apache_beam.options.pipeline_options import StandardOptions\n", + "from apache_beam.options.pipeline_options import WorkerOptions\n", + "\n", + "\n", + "options = PipelineOptions()\n", + "\n", + "BUCKET_NAME = '' # Replace with your bucket name.\n", + "CONTAINER_IMAGE = ':' # Replace with the image repository and tag from the previous step.\n", + "PROJECT_NAME = '' # Replace with your GCP project\n", + "\n", + "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n", + "\n", + "# Provide required pipeline options for the Dataflow Runner.\n", + "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n", + "\n", + "# Set the Google Cloud region that you want to run Dataflow in.\n", + "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", + "\n", + "# IMPORTANT: Replace BUCKET_NAME with the name of your Cloud Storage bucket.\n", + "dataflow_gcs_location = \"gs://%s/dataflow\" % BUCKET_NAME\n", + "\n", + "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", + "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", + "\n", + "# The Dataflow temp location. This location is used to store temporary files or intermediate results before outputting to the sink.\n", + "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n", + "\n", + "# Enable GPU runtime. 
Make sure to enable a 5xx driver, because vLLM works only with 5xx drivers, not 4xx drivers.\n", + "options.view_as(GoogleCloudOptions).dataflow_service_options = [\"worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx\"]\n", + "\n", + "options.view_as(SetupOptions).save_main_session = True\n", + "\n", + "# Choose a machine type compatible with the GPU type.\n", + "options.view_as(WorkerOptions).machine_type = \"n1-standard-4\"\n", + "\n", + "options.view_as(WorkerOptions).sdk_container_image = CONTAINER_IMAGE" + ], + "metadata": { + "id": "kXy9FRYVCSjq" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, authenticate Colab so that it can submit a job on your behalf." + ], + "metadata": { + "id": "xPhe597P1-QJ" + } + }, + { + "cell_type": "code", + "source": [ + "def auth_to_colab():\n", + " from google.colab import auth\n", + " auth.authenticate_user()\n", + "\n", + "auth_to_colab()" + ], + "metadata": { + "id": "Xkf6yIVlFB8-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Finally, run the pipeline on Dataflow. The pipeline definition is almost exactly the same as the definition used for local execution. The only change is the pipeline options.\n", + "\n", + "The following code creates a Dataflow job in your project. You can view the results in Colab or in the Google Cloud console. Creating a Dataflow job and downloading the model might take a few minutes. After the job starts performing inference, it quickly runs through the inputs." + ], + "metadata": { + "id": "MJtEI6Ux2eza" + } + }, + { + "cell_type": "code", + "source": [ + "import logging\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n", + "from apache_beam.ml.inference.base import PredictionResult\n", + "import apache_beam as beam\n", + "\n", + "class FormatOutput(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n", + "\n", + "logging.getLogger().setLevel(logging.INFO) # Output additional Dataflow Job metadata and launch logs.\n", + "prompts = [\n", + " \"Hello, my name is\",\n", + " \"The president of the United States is\",\n", + " \"The capital of France is\",\n", + " \"The future of AI is\",\n", + " \"John Cena is\",\n", + "]\n", + "\n", + "# Specify the model handler, providing the model path.\n", + "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')\n", + "\n", + "with beam.Pipeline(options=options) as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(logging.info) # Print the formatted output.\n", + " )" + ], + "metadata": { + "id": "8gjDdru_9Dii" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run vLLM with a Gemma model\n", + "\n", + "After you configure your pipeline, switching the model used by the pipeline is relatively straightforward. You can run the same pipeline, but switch the model name defined in the model handler. 
This example runs the pipeline created previously but uses a [Gemma](https://ai.google.dev/gemma) model.\n", + "\n", + "Before you start, sign in to HuggingFace, and make sure that you can access the Gemma models. To access Gemma models, you must accept the terms and conditions.\n", + "\n", + "1. Navigate to the [Gemma Model Card](https://huggingface.co/google/gemma-2b).\n", + "2. Sign in, or sign up for a free HuggingFace account.\n", + "3. Follow the prompts to agree to the conditions\n", + "\n", + "When you complete these steps, the following message appears on the model card page: `You have been granted access to this model`.\n", + "\n", + "Next, sign in to your account from this notebook by running the following code and then following the prompts." + ], + "metadata": { + "id": "22cEHPCc28fH" + } + }, + { + "cell_type": "code", + "source": [ + "! huggingface-cli login" + ], + "metadata": { + "id": "JHwIsFI9kd9j" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Verify that the notebook can now access the Gemma model. Run the following code, which starts a vLLM server to serve the Gemma 2b model. Because the default T4 Colab runtime doesn't support the full data type precision needed to run Gemma models, the `--dtype=half` parameter is required.\n", + "\n", + "When successful, the following cell runs indefinitely. After it starts the server process, you can shut it down. When the server process starts, the Gemma 2b model is successfully downloaded, and the server is ready to serve traffic." + ], + "metadata": { + "id": "IjX2If8rnCol" + } + }, + { + "cell_type": "code", + "source": [ + "! python -m vllm.entrypoints.openai.api_server --model google/gemma-2b --dtype=half" + ], + "metadata": { + "id": "LH_oCFWMiwFs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "To run the pipeline in Apache Beam, run the following code. Update the `VLLMCompletionsModelHandler` object with the new parameters, which match the command from the previous cell. Reuse all of the pipeline logic from the previous pipelines." + ], + "metadata": { + "id": "31BmdDUAn-SW" + } + }, + { + "cell_type": "code", + "source": [ + "model_handler = VLLMCompletionsModelHandler('google/gemma-2b', vllm_server_kwargs={'dtype': 'half'})\n", + "\n", + "with beam.Pipeline() as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(print) # Print the formatted output.\n", + " )" + ], + "metadata": { + "id": "DyC2ikXg237p" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Run Gemma on Dataflow\n", + "\n", + "As a next step, run this pipeline on Dataflow. Follow the same steps described in the \"Run remotely on Dataflow\" section of this page:\n", + "\n", + "1. Construct a Dockerfile and push a new Docker image. You can use the same Dockerfile that you created previously, but you need to add a step to set your HuggingFace authentication key. In your Dockerfile, add the following line before the entrypoint:\n", + "\n", + " ```\n", + " RUN python3 -c 'from huggingface_hub import HfFolder; HfFolder.save_token(\"\")'\n", + " ```\n", + "\n", + "2. Set pipeline options. You can reuse the options defined in this notebook. Replace the Docker image location with your new Docker image.\n", + "3. 
Run the pipeline. Copy the pipeline that you ran on Dataflow, and replace the pipeline options with the pipeline options that you just defined.\n", + "\n" + ], + "metadata": { + "id": "C6OYfub6ovFK" + } + } + ] +} diff --git a/gradle.properties b/gradle.properties index f2a0b05eca09..ffd4efaaab32 100644 --- a/gradle.properties +++ b/gradle.properties @@ -39,6 +39,6 @@ docker_image_default_repo_root=apache docker_image_default_repo_prefix=beam_ # supported flink versions -flink_versions=1.15,1.16,1.17,1.18 +flink_versions=1.17,1.18,1.19 # supported python versions -python_versions=3.8,3.9,3.10,3.11,3.12 +python_versions=3.9,3.10,3.11,3.12 diff --git a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md index 6eb1c04e966a..71abe616f1ad 100644 --- a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md +++ b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md @@ -191,8 +191,8 @@ $ wordcount --input gs://dataflow-samples/shakespeare/kinglear.txt \ {{if (eq .Sdk "java")}} ##### Portable -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.15`, `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` @@ -233,8 +233,8 @@ mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount \ {{end}} {{if (eq .Sdk "python")}} -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.10`, `Flink 1.11`, `Flink 1.12`, `Flink 1.13`, `Flink 1.14`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` diff --git a/local-env-setup.sh b/local-env-setup.sh index f13dc88432a6..ba30813b2bcc 100755 --- a/local-env-setup.sh +++ b/local-env-setup.sh @@ -55,7 +55,7 @@ if [ "$kernelname" = "Linux" ]; then exit fi - for ver in 3.8 3.9 3.10 3.11 3.12 3; do + for ver in 3.9 3.10 3.11 3.12 3; do apt install --yes python$ver-venv done @@ -89,7 +89,7 @@ elif [ "$kernelname" = "Darwin" ]; then echo "Installing openjdk@8" brew install openjdk@8 fi - for ver in 3.8 3.9 3.10 3.11 3.12; do + for ver in 3.9 3.10 3.11 3.12; do if brew ls --versions python@$ver > /dev/null; then echo "python@$ver already installed. 
Skipping" brew info python@$ver diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index b03350966d6c..f102e82bafa6 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -70,6 +70,10 @@ message ManagedTransforms { "beam:schematransform:org.apache.beam:kafka_read:v1"]; KAFKA_WRITE = 3 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:schematransform:org.apache.beam:kafka_write:v1"]; + BIGQUERY_READ = 4 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:bigquery_storage_read:v1"]; + BIGQUERY_WRITE = 5 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:bigquery_write:v1"]; } } diff --git a/playground/backend/internal/preparers/python_preparers.go b/playground/backend/internal/preparers/python_preparers.go index f050237492b1..96a4ed32910a 100644 --- a/playground/backend/internal/preparers/python_preparers.go +++ b/playground/backend/internal/preparers/python_preparers.go @@ -26,7 +26,7 @@ import ( ) const ( - addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.ERROR,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" oneIndentation = " " findWithPipelinePattern = `(\s*)with.+Pipeline.+as (.+):` indentationPattern = `^(%s){0,1}\w+` diff --git a/playground/backend/internal/preparers/python_preparers_test.go b/playground/backend/internal/preparers/python_preparers_test.go index b2cfa7eccaac..f333a1639b7c 100644 --- a/playground/backend/internal/preparers/python_preparers_test.go +++ b/playground/backend/internal/preparers/python_preparers_test.go @@ -53,7 +53,7 @@ func TestGetPythonPreparers(t *testing.T) { } func Test_addCodeToFile(t *testing.T) { - wantCode := "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode + wantCode := "import logging\nlogging.basicConfig(\n level=logging.ERROR,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode type args struct { args []interface{} diff --git a/playground/infrastructure/cloudbuild/playground_cd_examples.sh b/playground/infrastructure/cloudbuild/playground_cd_examples.sh index d05773656b30..e571bc9fc9d9 100644 --- a/playground/infrastructure/cloudbuild/playground_cd_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_cd_examples.sh @@ -97,15 +97,15 @@ LogOutput "Installing python and dependencies." 
export DEBIAN_FRONTEND=noninteractive apt install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null 2>&1 add-apt-repository -y ppa:deadsnakes/ppa > /dev/null 2>&1 && apt update > /dev/null 2>&1 -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null 2>&1 -apt install -y --reinstall python3.8-distutils > /dev/null 2>&1 +apt install -y python3.9 python3-distutils python3-pip > /dev/null 2>&1 +apt install -y --reinstall python3-distutils > /dev/null 2>&1 apt install -y python3-virtualenv virtualenv play_venv source play_venv/bin/activate pip install --upgrade google-api-python-client > /dev/null 2>&1 -python3.8 -m pip install pip --upgrade > /dev/null 2>&1 -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null 2>&1 -apt install -y python3.8-venv > /dev/null 2>&1 +python3.9 -m pip install pip --upgrade > /dev/null 2>&1 +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null 2>&1 +apt install -y python3.9-venv > /dev/null 2>&1 LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" cd $BEAM_ROOT_DIR diff --git a/playground/infrastructure/cloudbuild/playground_ci_examples.sh b/playground/infrastructure/cloudbuild/playground_ci_examples.sh index 437cc337faf7..2a63382615a5 100755 --- a/playground/infrastructure/cloudbuild/playground_ci_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_ci_examples.sh @@ -94,12 +94,12 @@ export DEBIAN_FRONTEND=noninteractive LogOutput "Installing Python environment" apt-get install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null add-apt-repository -y ppa:deadsnakes/ppa > /dev/null && apt update > /dev/null -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null -apt install --reinstall python3.8-distutils > /dev/null +apt install -y python3.9 python3-distutils python3-pip > /dev/null +apt install --reinstall python3-distutils > /dev/null pip install --upgrade google-api-python-client > /dev/null -python3.8 -m pip install pip --upgrade > /dev/null -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null -apt install python3.8-venv > /dev/null +python3.9 -m pip install pip --upgrade > /dev/null +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null +apt install python3.9-venv > /dev/null LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" pip install -r $BEAM_ROOT_DIR/playground/infrastructure/requirements.txt diff --git a/release/build.gradle.kts b/release/build.gradle.kts index ca1c152c9eb5..7ec49b86aac2 100644 --- a/release/build.gradle.kts +++ b/release/build.gradle.kts @@ -39,7 +39,7 @@ task("runJavaExamplesValidationTask") { dependsOn(":runners:direct-java:runQuickstartJavaDirect") dependsOn(":runners:google-cloud-dataflow-java:runQuickstartJavaDataflow") dependsOn(":runners:spark:3:runQuickstartJavaSpark") - dependsOn(":runners:flink:1.18:runQuickstartJavaFlinkLocal") + dependsOn(":runners:flink:1.19:runQuickstartJavaFlinkLocal") dependsOn(":runners:direct-java:runMobileGamingJavaDirect") dependsOn(":runners:google-cloud-dataflow-java:runMobileGamingJavaDataflow") dependsOn(":runners:twister2:runQuickstartJavaTwister2") diff --git a/release/src/main/Dockerfile b/release/src/main/Dockerfile index 14fe6fdb5a49..6503c5c42ba8 100644 --- a/release/src/main/Dockerfile +++ b/release/src/main/Dockerfile @@ -42,12 +42,11 @@ RUN curl https://pyenv.run | bash && \ echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> /root/.bashrc && \ echo 
''eval "$(pyenv init -)"'' >> /root/.bashrc && \ source /root/.bashrc && \ - pyenv install 3.8.9 && \ pyenv install 3.9.4 && \ pyenv install 3.10.7 && \ pyenv install 3.11.3 && \ pyenv install 3.12.3 && \ - pyenv global 3.8.9 3.9.4 3.10.7 3.11.3 3.12.3 + pyenv global 3.9.4 3.10.7 3.11.3 3.12.3 # Install a Go version >= 1.16 so we can bootstrap higher # Go versions diff --git a/release/src/main/groovy/TestScripts.groovy b/release/src/main/groovy/TestScripts.groovy index d5042aa61941..dc2438007ac1 100644 --- a/release/src/main/groovy/TestScripts.groovy +++ b/release/src/main/groovy/TestScripts.groovy @@ -36,6 +36,7 @@ class TestScripts { static String gcsBucket static String bqDataset static String pubsubTopic + static String mavenLocalPath } def TestScripts(String[] args) { @@ -47,6 +48,7 @@ class TestScripts { cli.gcsBucket(args:1, 'Google Cloud Storage Bucket') cli.bqDataset(args:1, "BigQuery Dataset") cli.pubsubTopic(args:1, "PubSub Topic") + cli.mavenLocalPath(args:1, "Maven local path") def options = cli.parse(args) var.repoUrl = options.repourl @@ -73,6 +75,10 @@ class TestScripts { var.pubsubTopic = options.pubsubTopic println "PubSub Topic: ${var.pubsubTopic}" } + if (options.mavenLocalPath) { + var.mavenLocalPath = options.mavenLocalPath + println "Maven local path: ${var.mavenLocalPath}" + } } def ver() { @@ -189,11 +195,16 @@ class TestScripts { } } - // Run a maven command, setting up a new local repository and a settings.xml with a custom repository + // Run a maven command, setting up a new local repository and a settings.xml with a custom repository if needed private String _mvn(String args) { - def m2 = new File(var.startDir, ".m2/repository") + String mvnlocalPath = var.mavenLocalPath + if (!(var.mavenLocalPath)) { + mvnlocalPath = var.startDir + } + def m2 = new File(mvnlocalPath, ".m2/repository") m2.mkdirs() - def settings = new File(var.startDir, "settings.xml") + def settings = new File(mvnlocalPath, "settings.xml") + if(!settings.exists()) { settings.write """ ${m2.absolutePath} @@ -209,16 +220,17 @@ class TestScripts { - """ - def cmd = "mvn ${args} -s ${settings.absolutePath} -Ptestrel -B" - String path = System.getenv("PATH"); - // Set the path on jenkins executors to use a recent maven - // MAVEN_HOME is not set on some executors, so default to 3.5.2 - String maven_home = System.getenv("MAVEN_HOME") ?: '/home/jenkins/tools/maven/apache-maven-3.5.4' - println "Using maven ${maven_home}" - def mvnPath = "${maven_home}/bin" - def setPath = "export PATH=\"${mvnPath}:${path}\" && " - return _execute(setPath + cmd) + """ + } + def cmd = "mvn ${args} -s ${settings.absolutePath} -Ptestrel -B" + String path = System.getenv("PATH"); + // Set the path on jenkins executors to use a recent maven + // MAVEN_HOME is not set on some executors, so default to 3.5.2 + String maven_home = System.getenv("MAVEN_HOME") ?: '/usr/local/maven' + println "Using maven ${maven_home}" + def mvnPath = "${maven_home}/bin" + def setPath = "export PATH=\"${mvnPath}:${path}\" && " + return _execute(setPath + cmd) } // Clean up and report error diff --git a/release/src/main/python-release/python_release_automation.sh b/release/src/main/python-release/python_release_automation.sh index 2f6986885a96..248bdd9b65ac 100755 --- a/release/src/main/python-release/python_release_automation.sh +++ b/release/src/main/python-release/python_release_automation.sh @@ -19,7 +19,7 @@ source release/src/main/python-release/run_release_candidate_python_quickstart.sh source 
release/src/main/python-release/run_release_candidate_python_mobile_gaming.sh -for version in 3.8 3.9 3.10 3.11 3.12 +for version in 3.9 3.10 3.11 3.12 do run_release_candidate_python_quickstart "tar" "python${version}" run_release_candidate_python_mobile_gaming "tar" "python${version}" diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java index 4fc5d3beca31..5f9bb6392ec2 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java @@ -20,8 +20,8 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; import java.util.Arrays; -import java.util.HashSet; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; @@ -54,13 +54,13 @@ public static StringSetData create(Set set) { if (set.isEmpty()) { return empty(); } - HashSet combined = new HashSet<>(); + Set combined = ConcurrentHashMap.newKeySet(); long stringSize = addUntilCapacity(combined, 0L, set); return new AutoValue_StringSetData(combined, stringSize); } /** Returns a {@link StringSetData} which is made from the given set in place. */ - private static StringSetData createInPlace(HashSet set, long stringSize) { + private static StringSetData createInPlace(Set set, long stringSize) { return new AutoValue_StringSetData(set, stringSize); } @@ -76,11 +76,12 @@ public static StringSetData empty() { *

>Should only be used by {@link StringSetCell#add}. */ public StringSetData addAll(String... strings) { - HashSet combined; - if (this.stringSet() instanceof HashSet) { - combined = (HashSet) this.stringSet(); + Set combined; + if (this.stringSet() instanceof ConcurrentHashMap.KeySetView) { + combined = this.stringSet(); } else { - combined = new HashSet<>(this.stringSet()); + combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); } long stringSize = addUntilCapacity(combined, this.stringSize(), Arrays.asList(strings)); return StringSetData.createInPlace(combined, stringSize); @@ -95,7 +96,8 @@ public StringSetData combine(StringSetData other) { } else if (other.stringSet().isEmpty()) { return this; } else { - HashSet combined = new HashSet<>(this.stringSet()); + Set combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); long stringSize = addUntilCapacity(combined, this.stringSize(), other.stringSet()); return StringSetData.createInPlace(combined, stringSize); } @@ -105,7 +107,8 @@ public StringSetData combine(StringSetData other) { * Combines this {@link StringSetData} with others, all original StringSetData are left intact. */ public StringSetData combine(Iterable others) { - HashSet combined = new HashSet<>(this.stringSet()); + Set combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); long stringSize = this.stringSize(); for (StringSetData other : others) { stringSize = addUntilCapacity(combined, stringSize, other.stringSet()); @@ -120,7 +123,7 @@ public StringSetResult extractResult() { /** Add strings into set until reach capacity. Return the all string size of added set. */ private static long addUntilCapacity( - HashSet combined, long currentSize, Iterable others) { + Set combined, long currentSize, Iterable others) { if (currentSize > STRING_SET_SIZE_LIMIT) { // already at capacity return currentSize; diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java index f78ed01603fb..9497bbe43d0e 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java @@ -20,7 +20,13 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Assert; @@ -94,4 +100,42 @@ public void testReset() { assertThat(stringSetCell.getCumulative(), equalTo(StringSetData.empty())); assertThat(stringSetCell.getDirty(), equalTo(new DirtyState())); } + + @Test(timeout = 5000) + public void testStringSetCellConcurrentAddRetrieval() throws InterruptedException { + StringSetCell cell = new StringSetCell(MetricName.named("namespace", "name")); + AtomicBoolean finished = new AtomicBoolean(false); + Thread increment = + new Thread( + () -> { + for (long i = 0; !finished.get(); ++i) { + cell.add(String.valueOf(i)); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + break; + } + } + }); + increment.start(); + Instant start = 
Instant.now(); + try { + while (true) { + Set s = cell.getCumulative().stringSet(); + List snapshot = new ArrayList<>(s); + if (Instant.now().isAfter(start.plusSeconds(3)) && snapshot.size() > 0) { + finished.compareAndSet(false, true); + break; + } + } + } finally { + increment.interrupt(); + increment.join(); + } + + Set s = cell.getCumulative().stringSet(); + for (long i = 0; i < s.size(); ++i) { + assertTrue(s.contains(String.valueOf(i))); + } + } } diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index c357b8a04328..404b864c9c31 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -22,12 +22,12 @@ plugins { id 'org.apache.beam.module' } // Shade away runner execution utilities till because this causes ServiceLoader conflicts with // TransformPayloadTranslatorRegistrar amongst other runners. This only happens in the DirectRunner // because it is likely to appear on the classpath of another runner. -def dependOnProjects = [ - ":runners:core-java", - ":runners:local-java", - ":runners:java-fn-execution", - ":sdks:java:core", - ] +def dependOnProjectsAndConfigs = [ + ":runners:core-java":null, + ":runners:local-java":null, + ":runners:java-fn-execution":null, + ":sdks:java:core":"shadow", +] applyJavaNature( automaticModuleName: 'org.apache.beam.runners.direct', @@ -36,8 +36,8 @@ applyJavaNature( ], shadowClosure: { dependencies { - dependOnProjects.each { - include(project(path: it, configuration: "shadow")) + dependOnProjectsAndConfigs.each { + include(project(path: it.key, configuration: "shadow")) } } }, @@ -63,8 +63,10 @@ configurations { dependencies { shadow library.java.vendored_guava_32_1_2_jre shadow project(path: ":model:pipeline", configuration: "shadow") - dependOnProjects.each { - implementation project(it) + dependOnProjectsAndConfigs.each { + // For projects producing shadowjar, use the packaged jar as dependency to + // handle redirected packages from it + implementation project(path: it.key, configuration: it.value) } shadow library.java.vendored_grpc_1_60_1 shadow library.java.joda_time diff --git a/runners/flink/1.15/build.gradle b/runners/flink/1.15/build.gradle deleted file mode 100644 index 8055cf593ad0..000000000000 --- a/runners/flink/1.15/build.gradle +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -project.ext { - flink_major = '1.15' - flink_version = '1.15.0' -} - -// Load the main build script which contains all build logic. 
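The StringSetData change above replaces HashSet with ConcurrentHashMap.newKeySet() so that addAll/combine and the snapshot reads exercised by the new StringSetCellTest can run concurrently without corrupting the set. A minimal standalone sketch (not Beam code; the class name and counts are illustrative) of the guarantee being relied on:

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;

// Minimal sketch (not Beam code): one thread keeps adding strings while the
// main thread iterates, mirroring what testStringSetCellConcurrentAddRetrieval
// exercises. ConcurrentHashMap.newKeySet() yields a thread-safe set whose
// iterators are weakly consistent, whereas iterating a plain HashSet during a
// concurrent add may throw ConcurrentModificationException or lose updates.
public class ConcurrentKeySetSketch {
  public static void main(String[] args) throws InterruptedException {
    Set<String> set = ConcurrentHashMap.newKeySet();
    CountDownLatch writerStarted = new CountDownLatch(1);
    Thread writer =
        new Thread(
            () -> {
              writerStarted.countDown();
              for (int i = 0; i < 100_000; i++) {
                set.add(String.valueOf(i));
              }
            });
    writer.start();
    writerStarted.await();

    int observed = 0;
    for (String ignored : set) { // safe while the writer is still adding
      observed++;
    }
    writer.join();
    System.out.println("saw " + observed + " of " + set.size() + " elements mid-write");
  }
}
```

A ConcurrentHashMap key-set iterator reflects some or all of the writer's additions but never fails, which is what the timed loop in the new test depends on.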
-apply from: "../flink_runner.gradle" diff --git a/runners/flink/1.15/job-server/build.gradle b/runners/flink/1.15/job-server/build.gradle deleted file mode 100644 index 05ad8feb5b78..000000000000 --- a/runners/flink/1.15/job-server/build.gradle +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server' - -project.ext { - // Look for the source code in the parent module - main_source_dirs = ["$basePath/src/main/java"] - test_source_dirs = ["$basePath/src/test/java"] - main_resources_dirs = ["$basePath/src/main/resources"] - test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.15-job-server' -} - -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server.gradle" diff --git a/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java deleted file mode 100644 index 956aad428d8b..000000000000 --- a/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink.translation.types; - -import java.io.EOFException; -import java.io.IOException; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; -import org.apache.beam.runners.flink.translation.wrappers.DataOutputViewWrapper; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.core.io.VersionedIOReadableWritable; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** - * Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for Beam {@link - * org.apache.beam.sdk.coders.Coder Coders}. - */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class CoderTypeSerializer extends TypeSerializer { - - private static final long serialVersionUID = 7247319138941746449L; - - private final Coder coder; - - /** - * {@link SerializablePipelineOptions} deserialization will cause {@link - * org.apache.beam.sdk.io.FileSystems} registration needed for {@link - * org.apache.beam.sdk.transforms.Reshuffle} translation. 
- */ - private final SerializablePipelineOptions pipelineOptions; - - private final boolean fasterCopy; - - public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { - Preconditions.checkNotNull(coder); - Preconditions.checkNotNull(pipelineOptions); - this.coder = coder; - this.pipelineOptions = pipelineOptions; - - FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); - this.fasterCopy = options.getFasterCopy(); - } - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public CoderTypeSerializer duplicate() { - return new CoderTypeSerializer<>(coder, pipelineOptions); - } - - @Override - public T createInstance() { - return null; - } - - @Override - public T copy(T t) { - if (fasterCopy) { - return t; - } - try { - return CoderUtils.clone(coder, t); - } catch (CoderException e) { - throw new RuntimeException("Could not clone.", e); - } - } - - @Override - public T copy(T t, T reuse) { - return copy(t); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize(T t, DataOutputView dataOutputView) throws IOException { - DataOutputViewWrapper outputWrapper = new DataOutputViewWrapper(dataOutputView); - coder.encode(t, outputWrapper); - } - - @Override - public T deserialize(DataInputView dataInputView) throws IOException { - try { - DataInputViewWrapper inputWrapper = new DataInputViewWrapper(dataInputView); - return coder.decode(inputWrapper); - } catch (CoderException e) { - Throwable cause = e.getCause(); - if (cause instanceof EOFException) { - throw (EOFException) cause; - } else { - throw e; - } - } - } - - @Override - public T deserialize(T t, DataInputView dataInputView) throws IOException { - return deserialize(dataInputView); - } - - @Override - public void copy(DataInputView dataInputView, DataOutputView dataOutputView) throws IOException { - serialize(deserialize(dataInputView), dataOutputView); - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - CoderTypeSerializer that = (CoderTypeSerializer) o; - return coder.equals(that.coder); - } - - @Override - public int hashCode() { - return coder.hashCode(); - } - - @Override - public TypeSerializerSnapshot snapshotConfiguration() { - return new UnversionedTypeSerializerSnapshot<>(this); - } - - /** - * A legacy snapshot which does not care about schema compatibility. This is used only for state - * restore of state created by Beam 2.54.0 and below for Flink 1.16 and below. - */ - public static class LegacySnapshot extends TypeSerializerConfigSnapshot { - - /** Needs to be public to work with {@link VersionedIOReadableWritable}. 
*/ - public LegacySnapshot() {} - - public LegacySnapshot(CoderTypeSerializer serializer) { - setPriorSerializer(serializer); - } - - @Override - public int getVersion() { - // We always return the same version - return 1; - } - - @Override - public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( - TypeSerializer newSerializer) { - - // We assume compatibility because we don't have a way of checking schema compatibility - return TypeSerializerSchemaCompatibility.compatibleAsIs(); - } - } - - @Override - public String toString() { - return "CoderTypeSerializer{" + "coder=" + coder + '}'; - } -} diff --git a/runners/flink/1.16/job-server-container/build.gradle b/runners/flink/1.16/job-server-container/build.gradle deleted file mode 100644 index afdb68a0fc91..000000000000 --- a/runners/flink/1.16/job-server-container/build.gradle +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server-container' - -project.ext { - resource_path = basePath -} - -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server_container.gradle" diff --git a/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java new file mode 100644 index 000000000000..7317788a72ee --- /dev/null +++ b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import java.util.Collection; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.BackendBuildingException; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.state.ttl.TtlTimeProvider; + +class MemoryStateBackendWrapper { + static AbstractKeyedStateBackend createKeyedStateBackend( + Environment env, + JobID jobID, + String operatorIdentifier, + TypeSerializer keySerializer, + int numberOfKeyGroups, + KeyGroupRange keyGroupRange, + TaskKvStateRegistry kvStateRegistry, + TtlTimeProvider ttlTimeProvider, + MetricGroup metricGroup, + Collection stateHandles, + CloseableRegistry cancelStreamRegistry) + throws BackendBuildingException { + + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createKeyedStateBackend( + env, + jobID, + operatorIdentifier, + keySerializer, + numberOfKeyGroups, + keyGroupRange, + kvStateRegistry, + ttlTimeProvider, + metricGroup, + stateHandles, + cancelStreamRegistry); + } + + static OperatorStateBackend createOperatorStateBackend( + Environment env, + String operatorIdentifier, + Collection stateHandles, + CloseableRegistry cancelStreamRegistry) + throws Exception { + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createOperatorStateBackend( + env, operatorIdentifier, stateHandles, cancelStreamRegistry); + } +} diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java similarity index 100% rename from runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java rename to runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java diff --git a/runners/flink/1.16/build.gradle b/runners/flink/1.19/build.gradle similarity index 94% rename from runners/flink/1.16/build.gradle rename to runners/flink/1.19/build.gradle index 21a222864a27..1545da258477 100644 --- a/runners/flink/1.16/build.gradle +++ b/runners/flink/1.19/build.gradle @@ -17,8 +17,8 @@ */ project.ext { - flink_major = '1.16' - flink_version = '1.16.0' + flink_major = '1.19' + flink_version = '1.19.0' } // Load the main build script which contains all build logic. 
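The per-version MemoryStateBackendWrapper above exists because the MemoryStateBackend factory methods changed shape across Flink releases: the 1.17 copy calls createKeyedStateBackend/createOperatorStateBackend with positional arguments, while the 1.19 copy later in this diff passes KeyedStateBackendParametersImpl and OperatorStateBackendParametersImpl instead. Shared tests then only ever call the wrapper; a sketch of that call site, mirroring the FlinkBroadcastStateInternalsTest change further down (the class name here is illustrative):

```java
package org.apache.beam.runners.flink.streaming; // the wrapper is package-private

import java.util.Collections;
import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.state.OperatorStateBackend;

// Version-agnostic call site: the test depends only on MemoryStateBackendWrapper,
// and each runners/flink/1.x source tree ships the wrapper matching its Flink API.
class OperatorStateBackendSketch {
  static OperatorStateBackend createForTest() throws Exception {
    return MemoryStateBackendWrapper.createOperatorStateBackend(
        new DummyEnvironment("test", 1, 0),
        /* operatorIdentifier= */ "",
        /* stateHandles= */ Collections.emptyList(),
        /* cancelStreamRegistry= */ null);
  }
}
```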
diff --git a/runners/flink/1.15/job-server-container/build.gradle b/runners/flink/1.19/job-server-container/build.gradle similarity index 100% rename from runners/flink/1.15/job-server-container/build.gradle rename to runners/flink/1.19/job-server-container/build.gradle diff --git a/runners/flink/1.16/job-server/build.gradle b/runners/flink/1.19/job-server/build.gradle similarity index 95% rename from runners/flink/1.16/job-server/build.gradle rename to runners/flink/1.19/job-server/build.gradle index 99dc00275a0c..332f04e08ceb 100644 --- a/runners/flink/1.16/job-server/build.gradle +++ b/runners/flink/1.19/job-server/build.gradle @@ -24,7 +24,7 @@ project.ext { test_source_dirs = ["$basePath/src/test/java"] main_resources_dirs = ["$basePath/src/main/resources"] test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.16-job-server' + archives_base_name = 'beam-runners-flink-1.19-job-server' } // Load the main build script which contains all build logic. diff --git a/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java new file mode 100644 index 000000000000..cbaa6fd3a8c4 --- /dev/null +++ b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import java.util.Collection; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.query.TaskKvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.BackendBuildingException; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateBackendParametersImpl; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackendParametersImpl; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.runtime.state.ttl.TtlTimeProvider; + +class MemoryStateBackendWrapper { + static AbstractKeyedStateBackend createKeyedStateBackend( + Environment env, + JobID jobID, + String operatorIdentifier, + TypeSerializer keySerializer, + int numberOfKeyGroups, + KeyGroupRange keyGroupRange, + TaskKvStateRegistry kvStateRegistry, + TtlTimeProvider ttlTimeProvider, + MetricGroup metricGroup, + Collection stateHandles, + CloseableRegistry cancelStreamRegistry) + throws BackendBuildingException { + + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createKeyedStateBackend( + new KeyedStateBackendParametersImpl<>( + env, + jobID, + operatorIdentifier, + keySerializer, + numberOfKeyGroups, + keyGroupRange, + kvStateRegistry, + ttlTimeProvider, + metricGroup, + stateHandles, + cancelStreamRegistry)); + } + + static OperatorStateBackend createOperatorStateBackend( + Environment env, + String operatorIdentifier, + Collection stateHandles, + CloseableRegistry cancelStreamRegistry) + throws Exception { + MemoryStateBackend backend = new MemoryStateBackend(); + return backend.createOperatorStateBackend( + new OperatorStateBackendParametersImpl( + env, operatorIdentifier, stateHandles, cancelStreamRegistry)); + } +} diff --git a/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java new file mode 100644 index 000000000000..c03799d09535 --- /dev/null +++ b/runners/flink/1.19/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.streaming.runtime.streamrecord.RecordAttributes; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorChain; +import org.apache.flink.streaming.runtime.tasks.RegularOperatorChain; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; + +/** {@link StreamSource} utilities, that bridge incompatibilities between Flink releases. */ +public class StreamSources { + + public static > void run( + StreamSource streamSource, + Object lockingObject, + Output> collector) + throws Exception { + streamSource.run(lockingObject, collector, createOperatorChain(streamSource)); + } + + private static OperatorChain createOperatorChain(AbstractStreamOperator operator) { + return new RegularOperatorChain<>( + operator.getContainingTask(), + StreamTask.createRecordWriterDelegate( + operator.getOperatorConfig(), new MockEnvironmentBuilder().build())); + } + + /** The emitWatermarkStatus method was added in Flink 1.14, so we need to wrap Output. */ + public interface OutputWrapper extends Output { + @Override + default void emitWatermarkStatus(WatermarkStatus watermarkStatus) {} + + /** In Flink 1.19 the {@code emitRecordAttributes} method was added. */ + @Override + default void emitRecordAttributes(RecordAttributes recordAttributes) { + throw new UnsupportedOperationException("emitRecordAttributes not implemented"); + } + } +} diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index c8f492a901d3..d13e1c5faf6e 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -173,36 +173,19 @@ dependencies { implementation library.java.joda_time implementation library.java.args4j - // Flink 1.15 shades all remaining scala dependencies and therefor does not depend on a specific version of Scala anymore - if (flink_version.compareTo("1.15") >= 0) { - implementation "org.apache.flink:flink-clients:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). - permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" - - implementation "org.apache.flink:flink-streaming-java:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web:$flink_version" - } else { - implementation "org.apache.flink:flink-clients_2.12:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). 
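Going back to the 1.19 StreamSources above: emitRecordAttributes is new in Flink 1.19, so the OutputWrapper interface gives it (and emitWatermarkStatus) a default body, and version-shared tests implement only the methods they actually use. A hypothetical collecting output in that style (this class is illustrative, not part of this PR):

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.flink.streaming.StreamSources;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.OutputTag;

// Hypothetical test output: collects emitted elements and ignores watermarks
// and latency markers; the Flink-1.19-only emitRecordAttributes is absorbed
// by the OutputWrapper default.
class CollectingOutput<T> implements StreamSources.OutputWrapper<StreamRecord<T>> {
  private final List<T> elements = new ArrayList<>();

  @Override
  public void collect(StreamRecord<T> record) {
    elements.add(record.getValue());
  }

  @Override
  public <X> void collect(OutputTag<X> outputTag, StreamRecord<X> record) {
    // Side outputs are not needed for this sketch.
  }

  @Override
  public void emitWatermark(Watermark watermark) {}

  @Override
  public void emitLatencyMarker(LatencyMarker latencyMarker) {}

  @Override
  public void close() {}

  List<T> getElements() {
    return elements;
  }
}
```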
- permitUnusedDeclared "org.apache.flink:flink-clients_2.12:$flink_version" - - implementation "org.apache.flink:flink-streaming-java_2.12:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java_2.12:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils_2.12:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web_2.12:$flink_version" - } + implementation "org.apache.flink:flink-clients:$flink_version" + // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation + // configuration (https://issues.apache.org/jira/browse/BEAM-11732). + permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" + + implementation "org.apache.flink:flink-streaming-java:$flink_version" + // RocksDB state backend (included in the Flink distribution) + provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" + testImplementation "org.apache.flink:flink-test-utils:$flink_version" + + miniCluster "org.apache.flink:flink-runtime-web:$flink_version" implementation "org.apache.flink:flink-core:$flink_version" implementation "org.apache.flink:flink-metrics-core:$flink_version" diff --git a/runners/flink/job-server-container/Dockerfile b/runners/flink/job-server-container/Dockerfile index c5a81ecf6466..5f19aa0dc851 100644 --- a/runners/flink/job-server-container/Dockerfile +++ b/runners/flink/job-server-container/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. ############################################################################### -FROM openjdk:8 +FROM eclipse-temurin:11 MAINTAINER "Apache Beam " RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libltdl7 @@ -28,4 +28,6 @@ COPY target/LICENSE /opt/apache/beam/ COPY target/NOTICE /opt/apache/beam/ WORKDIR /opt/apache/beam -ENTRYPOINT ["./flink-job-server.sh"] + +# Add a conditional check for a mounted volume. This allows passing flink configs. 
+ENTRYPOINT ["/bin/sh", "-c", "if [ -d \"/flink-conf\" ]; then /opt/apache/beam/flink-job-server.sh --flink-conf-dir /flink-conf; else /opt/apache/beam/flink-job-server.sh; fi"] diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java index 10e20a6d47d3..c679c0725051 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java @@ -25,7 +25,6 @@ import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.OperatorStateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.junit.Ignore; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,10 +40,9 @@ public class FlinkBroadcastStateInternalsTest extends StateInternalsTest { @Override protected StateInternals createStateInternals() { - MemoryStateBackend backend = new MemoryStateBackend(); try { OperatorStateBackend operatorStateBackend = - backend.createOperatorStateBackend( + MemoryStateBackendWrapper.createOperatorStateBackend( new DummyEnvironment("test", 1, 0), "", Collections.emptyList(), null); return new FlinkBroadcastStateInternals<>( 1, diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index a2d6f5027abb..d0338ec3b0d3 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -47,7 +47,6 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.runtime.state.ttl.TtlTimeProvider; import org.hamcrest.Matchers; import org.joda.time.Instant; @@ -185,9 +184,8 @@ public void testGlobalWindowWatermarkHoldClear() throws Exception { } public static KeyedStateBackend createStateBackend() throws Exception { - MemoryStateBackend backend = new MemoryStateBackend(); AbstractKeyedStateBackend keyedStateBackend = - backend.createKeyedStateBackend( + MemoryStateBackendWrapper.createKeyedStateBackend( new DummyEnvironment("test", 1, 0), new JobID(), "test_op", diff --git a/runners/google-cloud-dataflow-java/worker/build.gradle b/runners/google-cloud-dataflow-java/worker/build.gradle index b7e6e981effe..92beccd067e2 100644 --- a/runners/google-cloud-dataflow-java/worker/build.gradle +++ b/runners/google-cloud-dataflow-java/worker/build.gradle @@ -54,6 +54,7 @@ def sdk_provided_project_dependencies = [ ":runners:google-cloud-dataflow-java", ":sdks:java:extensions:avro", ":sdks:java:extensions:google-cloud-platform-core", + ":sdks:java:io:kafka", // For metric propagation into worker ":sdks:java:io:google-cloud-platform", ] diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java index 94c894608a47..a28a5e989c88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java @@ -82,7 +82,9 @@ public static T initializeGlobalStateAn @SuppressWarnings("Slf4jIllegalPassedClass") public static void initializeLogging(Class workerHarnessClass) { - /* Set up exception handling tied to the workerHarnessClass. */ + // Set up exception handling for raw Threads tied to the workerHarnessClass. + // Does NOT handle exceptions thrown by threads created by + // ScheduledExecutors/ScheduledExecutorServices. Thread.setDefaultUncaughtExceptionHandler( new WorkerUncaughtExceptionHandler(LoggerFactory.getLogger(workerHarnessClass))); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java index 30e920119120..77f867793ae2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java @@ -32,13 +32,15 @@ import java.util.Map.Entry; import java.util.Optional; import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; +import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics; import org.apache.beam.sdk.metrics.LabeledMetricNameUtils; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.util.HistogramData; /** * Converts metric updates to {@link PerStepNamespaceMetrics} protos. Currently we only support - * converting metrics from {@link BigQuerySinkMetrics} with this converter. + * converting metrics from {@link BigQuerySinkMetrics} and from {@link KafkaSinkMetrics} with this + * converter. */ public class MetricsToPerStepNamespaceMetricsConverter { @@ -65,7 +67,10 @@ private static Optional convertCounterToMetricValue( MetricName metricName, Long value, Map parsedPerWorkerMetricsCache) { - if (value == 0 || !metricName.getNamespace().equals(BigQuerySinkMetrics.METRICS_NAMESPACE)) { + + if (value == 0 + || (!metricName.getNamespace().equals(BigQuerySinkMetrics.METRICS_NAMESPACE) + && !metricName.getNamespace().equals(KafkaSinkMetrics.METRICS_NAMESPACE))) { return Optional.empty(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java index a413c2c03dbe..558848f488a7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java @@ -77,6 +77,7 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class SimpleParDoFn implements ParDoFn { + // TODO: Remove once Distributions has shipped. 
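One point worth spelling out from the DataflowWorkerHarnessHelper comment above: Thread.setDefaultUncaughtExceptionHandler only covers exceptions that escape a thread's run(), while a ScheduledExecutorService wraps every task in a Future that silently captures the throwable. A standalone sketch of that difference (not Beam code; names are illustrative):

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// Standalone sketch (not Beam code): the default uncaught-exception handler
// fires for a raw Thread, but an exception thrown inside a submitted task is
// captured in the returned Future and never reaches the handler unless the
// Future is inspected.
public class UncaughtHandlerScopeSketch {
  public static void main(String[] args) throws Exception {
    Thread.setDefaultUncaughtExceptionHandler(
        (t, e) -> System.err.println("handler saw: " + e + " on " + t.getName()));

    Thread raw = new Thread(() -> {
      throw new RuntimeException("from raw thread");
    });
    raw.start();
    raw.join(); // the handler logs "from raw thread"

    ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    Runnable failing = () -> {
      throw new RuntimeException("from executor task"); // swallowed by the FutureTask
    };
    executor.submit(failing);
    executor.shutdown();
    executor.awaitTermination(1, TimeUnit.SECONDS); // nothing is logged for the task
  }
}
```

Failures inside scheduled work therefore need their own logging or rethrow path rather than relying on the handler installed in initializeLogging.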
@VisibleForTesting static final String OUTPUTS_PER_ELEMENT_EXPERIMENT = "outputs_per_element_counter"; @@ -174,6 +175,7 @@ private boolean hasExperiment(String experiment) { /** Simple state tracker to calculate PerElementOutputCount counter. */ private interface OutputsPerElementTracker { + void onOutput(); void onProcessElement(); @@ -182,6 +184,7 @@ private interface OutputsPerElementTracker { } private class OutputsPerElementTrackerImpl implements OutputsPerElementTracker { + private long outputsPerElement; private final Counter counter; @@ -214,6 +217,7 @@ private void reset() { /** No-op {@link OutputsPerElementTracker} implementation used when the counter is disabled. */ private static class NoopOutputsPerElementTracker implements OutputsPerElementTracker { + private NoopOutputsPerElementTracker() {} public static final OutputsPerElementTracker INSTANCE = new NoopOutputsPerElementTracker(); @@ -516,10 +520,14 @@ private void registerStateCleanup( private Instant earliestAllowableCleanupTime( BoundedWindow window, WindowingStrategy windowingStrategy) { - return window - .maxTimestamp() - .plus(windowingStrategy.getAllowedLateness()) - .plus(Duration.millis(1L)); + Instant cleanupTime = + window + .maxTimestamp() + .plus(windowingStrategy.getAllowedLateness()) + .plus(Duration.millis(1L)); + return cleanupTime.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE) + ? BoundedWindow.TIMESTAMP_MAX_VALUE + : cleanupTime; } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index ecdba404151e..6ce60283735f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -65,6 +65,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter; @@ -93,6 +94,7 @@ import org.apache.beam.sdk.fn.JvmInitializers; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; +import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -140,8 +142,6 @@ public final class StreamingDataflowWorker { private static final int DEFAULT_STATUS_PORT = 8081; private static final Random CLIENT_ID_GENERATOR = new Random(); private static final String CHANNELZ_PATH = "/channelz"; - public static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL = - "streaming_engine_use_job_settings_for_heartbeat_pool"; private final WindmillStateCache stateCache; private final 
StreamingWorkerStatusPages statusPages; @@ -176,7 +176,7 @@ private StreamingDataflowWorker( StreamingCounters streamingCounters, MemoryMonitor memoryMonitor, GrpcWindmillStreamFactory windmillStreamFactory, - Function executorSupplier, + ScheduledExecutorService activeWorkRefreshExecutorFn, ConcurrentMap stageInfoMap) { // Register standard file systems. FileSystems.setDefaultPipelineOptions(options); @@ -200,6 +200,7 @@ private StreamingDataflowWorker( this.workCommitter = windmillServiceEnabled ? StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory( WindmillStreamPool.create( numCommitThreads, @@ -249,10 +250,7 @@ private StreamingDataflowWorker( GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream); getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); - // Experiment gates the logic till backend changes are rollback safe - if (!DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL) - || options.getUseSeparateWindmillHeartbeatStreams() != null) { + if (options.getUseSeparateWindmillHeartbeatStreams() != null) { heartbeatSender = StreamPoolHeartbeatSender.Create( Boolean.TRUE.equals(options.getUseSeparateWindmillHeartbeatStreams()) @@ -289,7 +287,7 @@ private StreamingDataflowWorker( stuckCommitDurationMillis, computationStateCache::getAllPresentComputations, sampler, - executorSupplier.apply("RefreshWork"), + activeWorkRefreshExecutorFn, getDataMetricTracker::trackHeartbeats); this.statusPages = @@ -351,10 +349,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o .setSizeMb(options.getWorkerCacheMb()) .setSupportMapViaMultimap(options.isEnableStreamingEngine()) .build(); - Function executorSupplier = - threadName -> - Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); + GrpcWindmillStreamFactory.Builder windmillStreamFactoryBuilder = createGrpcwindmillStreamFactoryBuilder(options, clientId); @@ -421,7 +416,8 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o streamingCounters, memoryMonitor, configFetcherComputationStateCacheAndWindmillClient.windmillStreamFactory(), - executorSupplier, + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat("RefreshWork").build()), stageInfo); } @@ -599,7 +595,7 @@ static StreamingDataflowWorker forTesting( options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) .build() : windmillStreamFactory.build(), - executorSupplier, + executorSupplier.apply("RefreshWork"), stageInfo); } @@ -668,6 +664,10 @@ public static void main(String[] args) throws Exception { enableBigQueryMetrics(); } + if (DataflowRunner.hasExperiment(options, "enable_kafka_metrics")) { + KafkaSinkMetrics.setSupportKafkaMetrics(true); + } + JvmInitializers.runBeforeProcessing(options); worker.startStatusPages(); worker.start(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java index 5a8e87d23ab9..b4ec170099d5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java @@ -28,16 +28,16 @@ * This uncaught exception handler logs the {@link Throwable} to the logger, {@link System#err} and * exits the application with status code 1. */ -class WorkerUncaughtExceptionHandler implements UncaughtExceptionHandler { +public final class WorkerUncaughtExceptionHandler implements UncaughtExceptionHandler { + @VisibleForTesting public static final int JVM_TERMINATED_STATUS_CODE = 1; private final JvmRuntime runtime; private final Logger logger; - WorkerUncaughtExceptionHandler(Logger logger) { + public WorkerUncaughtExceptionHandler(Logger logger) { this(JvmRuntime.INSTANCE, logger); } - @VisibleForTesting - WorkerUncaughtExceptionHandler(JvmRuntime runtime, Logger logger) { + public WorkerUncaughtExceptionHandler(JvmRuntime runtime, Logger logger) { this.runtime = runtime; this.logger = logger; } @@ -59,7 +59,7 @@ public void uncaughtException(Thread thread, Throwable e) { t.printStackTrace(originalStdErr); } } finally { - runtime.halt(1); + runtime.halt(JVM_TERMINATED_STATUS_CODE); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java index 199ad26aed00..4b4acb73f4a7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java @@ -147,8 +147,10 @@ public Optional get(String computationId) { | ComputationStateNotFoundException e) { if (e.getCause() instanceof ComputationStateNotFoundException || e instanceof ComputationStateNotFoundException) { - LOG.error( - "Trying to fetch unknown computation={}, known computations are {}.", + LOG.warn( + "Computation {} is currently unknown, " + + "known computations are {}. " + + "This is transient and will get retried.", computationId, ImmutableSet.copyOf(computationCache.asMap().keySet())); } else { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java index a18ca8cfd6dc..525464ef2e1f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java @@ -35,6 +35,7 @@ import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; +import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; /** Contains a few of the stage specific fields. E.g. metrics container registry, counters etc. 
*/ @@ -118,7 +119,9 @@ public List extractPerWorkerMetricValues() { private void translateKnownPerWorkerCounters(List metrics) { for (PerStepNamespaceMetrics perStepnamespaceMetrics : metrics) { if (!BigQuerySinkMetrics.METRICS_NAMESPACE.equals( - perStepnamespaceMetrics.getMetricsNamespace())) { + perStepnamespaceMetrics.getMetricsNamespace()) + && !KafkaSinkMetrics.METRICS_NAMESPACE.equals( + perStepnamespaceMetrics.getMetricsNamespace())) { continue; } for (MetricValue metric : perStepnamespaceMetrics.getMetricValues()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java index f2893f3e7191..5f039be7b00f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java @@ -18,33 +18,24 @@ package org.apache.beam.runners.dataflow.worker.streaming; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.function.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; -/** Bounded set of queues, with a maximum total weight. */ +/** Queue bounded by a {@link WeightedSemaphore}. */ public final class WeightedBoundedQueue { private final LinkedBlockingQueue queue; - private final int maxWeight; - private final Semaphore limit; - private final Function weigher; + private final WeightedSemaphore weightedSemaphore; private WeightedBoundedQueue( - LinkedBlockingQueue linkedBlockingQueue, - int maxWeight, - Semaphore limit, - Function weigher) { + LinkedBlockingQueue linkedBlockingQueue, WeightedSemaphore weightedSemaphore) { this.queue = linkedBlockingQueue; - this.maxWeight = maxWeight; - this.limit = limit; - this.weigher = weigher; + this.weightedSemaphore = weightedSemaphore; } - public static WeightedBoundedQueue create(int maxWeight, Function weigherFn) { - return new WeightedBoundedQueue<>( - new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn); + public static WeightedBoundedQueue create(WeightedSemaphore weightedSemaphore) { + return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weightedSemaphore); } /** @@ -52,15 +43,15 @@ public static WeightedBoundedQueue create(int maxWeight, Function { + private final int maxWeight; + private final Semaphore limit; + private final Function weigher; + + private WeightedSemaphore(int maxWeight, Semaphore limit, Function weigher) { + this.maxWeight = maxWeight; + this.limit = limit; + this.weigher = weigher; + } + + public static WeightedSemaphore create(int maxWeight, Function weigherFn) { + return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn); + } + + public void acquireUninterruptibly(V value) { + limit.acquireUninterruptibly(computePermits(value)); + } + + public void release(V value) { + limit.release(computePermits(value)); + } + + private int computePermits(V value) { + return Math.min(weigher.apply(value), maxWeight); + } + + public int currentWeight() { + return maxWeight - limit.availablePermits(); + } +} diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 3556b7ce2919..3eed4ee6d835 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,20 +20,25 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; -import java.util.Collection; -import java.util.List; +import java.io.Closeable; +import java.util.HashSet; import java.util.Map.Entry; +import java.util.NoSuchElementException; import java.util.Optional; -import java.util.Queue; -import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import java.util.stream.Collectors; import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -54,18 +59,14 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; -import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.MoreFutures; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,32 +81,39 @@ public final class 
FanOutStreamingEngineWorkerHarness implements StreamingWorkerHarness { private static final Logger LOG = LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class); - private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread"; - private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread"; + private static final String WORKER_METADATA_CONSUMER_THREAD_NAME = + "WindmillWorkerMetadataConsumerThread"; + private static final String STREAM_MANAGER_THREAD_NAME = "WindmillStreamManager-%d"; private final JobHeader jobHeader; private final GrpcWindmillStreamFactory streamFactory; private final WorkItemScheduler workItemScheduler; private final ChannelCachingStubFactory channelCachingStubFactory; private final GrpcDispatcherClient dispatcherClient; - private final AtomicBoolean isBudgetRefreshPaused; - private final GetWorkBudgetRefresher getWorkBudgetRefresher; - private final AtomicReference lastBudgetRefresh; + private final GetWorkBudgetDistributor getWorkBudgetDistributor; + private final GetWorkBudget totalGetWorkBudget; private final ThrottleTimer getWorkerMetadataThrottleTimer; - private final ExecutorService newWorkerMetadataPublisher; - private final ExecutorService newWorkerMetadataConsumer; - private final long clientId; - private final Supplier getWorkerMetadataStream; - private final Queue newWindmillEndpoints; private final Function workCommitterFactory; private final ThrottlingGetDataMetricTracker getDataMetricTracker; + private final ExecutorService windmillStreamManager; + private final ExecutorService workerMetadataConsumer; + private final Object metadataLock = new Object(); /** Writes are guarded by synchronization, reads are lock free. */ - private final AtomicReference connections; + private final AtomicReference backends; - private volatile boolean started; + @GuardedBy("this") + private long activeMetadataVersion; + + @GuardedBy("metadataLock") + private long pendingMetadataVersion; + + @GuardedBy("this") + private boolean started; + + @GuardedBy("this") + private @Nullable GetWorkerMetadataStream getWorkerMetadataStream; - @SuppressWarnings("FutureReturnValueIgnored") private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, @@ -114,7 +122,6 @@ private FanOutStreamingEngineWorkerHarness( ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { this.jobHeader = jobHeader; @@ -122,42 +129,21 @@ private FanOutStreamingEngineWorkerHarness( this.started = false; this.streamFactory = streamFactory; this.workItemScheduler = workItemScheduler; - this.connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); + this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY); this.channelCachingStubFactory = channelCachingStubFactory; this.dispatcherClient = dispatcherClient; - this.isBudgetRefreshPaused = new AtomicBoolean(false); this.getWorkerMetadataThrottleTimer = new ThrottleTimer(); - this.newWorkerMetadataPublisher = - singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD); - this.newWorkerMetadataConsumer = - singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD); - this.clientId = clientId; - this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH); - this.newWindmillEndpoints = 
Queues.synchronizedQueue(EvictingQueue.create(1)); - this.getWorkBudgetRefresher = - new GetWorkBudgetRefresher( - isBudgetRefreshPaused::get, - () -> { - getWorkBudgetDistributor.distributeBudget( - connections.get().windmillStreams().values(), totalGetWorkBudget); - lastBudgetRefresh.set(Instant.now()); - }); - this.getWorkerMetadataStream = - Suppliers.memoize( - () -> - streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStubBlocking(), - getWorkerMetadataThrottleTimer, - endpoints -> - // Run this on a separate thread than the grpc stream thread. - newWorkerMetadataPublisher.submit( - () -> newWindmillEndpoints.add(endpoints)))); + this.windmillStreamManager = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build()); + this.workerMetadataConsumer = + Executors.newSingleThreadExecutor( + new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build()); + this.getWorkBudgetDistributor = getWorkBudgetDistributor; + this.totalGetWorkBudget = totalGetWorkBudget; + this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; - } - - private static ExecutorService singleThreadedExecutorServiceOf(String threadName) { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); + this.getWorkerMetadataStream = null; } /** @@ -183,7 +169,6 @@ public static FanOutStreamingEngineWorkerHarness create( channelCachingStubFactory, getWorkBudgetDistributor, dispatcherClient, - /* clientId= */ new Random().nextLong(), workCommitterFactory, getDataMetricTracker); } @@ -197,7 +182,6 @@ static FanOutStreamingEngineWorkerHarness forTesting( ChannelCachingStubFactory stubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkProvider = @@ -209,201 +193,218 @@ static FanOutStreamingEngineWorkerHarness forTesting( stubFactory, getWorkBudgetDistributor, dispatcherClient, - clientId, workCommitterFactory, getDataMetricTracker); fanOutStreamingEngineWorkProvider.start(); return fanOutStreamingEngineWorkProvider; } - @SuppressWarnings("ReturnValueIgnored") @Override public synchronized void start() { - Preconditions.checkState(!started, "StreamingEngineClient cannot start twice."); - // Starts the stream, this value is memoized. - getWorkerMetadataStream.get(); - startWorkerMetadataConsumer(); - getWorkBudgetRefresher.start(); + Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); + getWorkerMetadataStream = + streamFactory.createGetWorkerMetadataStream( + dispatcherClient.getWindmillMetadataServiceStubBlocking(), + getWorkerMetadataThrottleTimer, + this::consumeWorkerMetadata); started = true; } public ImmutableSet currentWindmillEndpoints() { - return connections.get().windmillConnections().keySet().stream() + return backends.get().windmillStreams().keySet().stream() .map(Endpoint::directEndpoint) .filter(Optional::isPresent) .map(Optional::get) - .filter( - windmillServiceAddress -> - windmillServiceAddress.getKind() != WindmillServiceAddress.Kind.IPV6) - .map( - windmillServiceAddress -> - windmillServiceAddress.getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS - ? 
windmillServiceAddress.gcpServiceAddress() - : windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress()) + .map(WindmillServiceAddress::getServiceAddress) .collect(toImmutableSet()); } /** - * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link - * GetDataStream} pointing to dispatcher. + * Fetches {@link GetDataStream} mapped to globalDataKey if or throws {@link + * NoSuchElementException} if one is not found. */ private GetDataStream getGlobalDataStream(String globalDataKey) { - return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey)) - .map(Supplier::get) - .orElseGet( - () -> - streamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), new ThrottleTimer())); - } - - @SuppressWarnings("FutureReturnValueIgnored") - private void startWorkerMetadataConsumer() { - newWorkerMetadataConsumer.submit( - () -> { - while (true) { - Optional.ofNullable(newWindmillEndpoints.poll()) - .ifPresent(this::consumeWindmillWorkerEndpoints); - } - }); + return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey)) + .map(GlobalDataStreamSender::get) + .orElseThrow( + () -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey)); } @VisibleForTesting @Override public synchronized void shutdown() { - Preconditions.checkState(started, "StreamingEngineClient never started."); - getWorkerMetadataStream.get().halfClose(); - getWorkBudgetRefresher.stop(); - newWorkerMetadataPublisher.shutdownNow(); - newWorkerMetadataConsumer.shutdownNow(); + Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); + Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); + workerMetadataConsumer.shutdownNow(); + closeStreamsNotIn(WindmillEndpoints.none()); channelCachingStubFactory.shutdown(); + + try { + Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", e); + } + + windmillStreamManager.shutdown(); + boolean isStreamManagerShutdown = false; + try { + isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", e); + } + if (!isStreamManagerShutdown) { + windmillStreamManager.shutdownNow(); + } + } + + private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) { + synchronized (metadataLock) { + // Only process versions greater than what we currently have to prevent double processing of + // metadata. workerMetadataConsumer is single-threaded so we maintain ordering. + if (windmillEndpoints.version() > pendingMetadataVersion) { + pendingMetadataVersion = windmillEndpoints.version(); + workerMetadataConsumer.execute(() -> consumeWindmillWorkerEndpoints(windmillEndpoints)); + } + } } - /** - * {@link java.util.function.Consumer} used to update {@link #connections} on - * new backend worker metadata. 
- */ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) { - isBudgetRefreshPaused.set(true); - LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints); - ImmutableMap newWindmillConnections = - createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints()); - - StreamingEngineConnectionState newConnectionsState = - StreamingEngineConnectionState.builder() - .setWindmillConnections(newWindmillConnections) - .setWindmillStreams( - closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values())) + // Since this is run on a single threaded executor, multiple versions of the metadata maybe + // queued up while a previous version of the windmillEndpoints were being consumed. Only consume + // the endpoints if they are the most current version. + synchronized (metadataLock) { + if (newWindmillEndpoints.version() < pendingMetadataVersion) { + return; + } + } + + LOG.debug( + "Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}", + newWindmillEndpoints, + activeMetadataVersion, + newWindmillEndpoints.version()); + closeStreamsNotIn(newWindmillEndpoints); + ImmutableMap newStreams = + createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join(); + StreamingEngineBackends newBackends = + StreamingEngineBackends.builder() + .setWindmillStreams(newStreams) .setGlobalDataStreams( createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints())) .build(); + backends.set(newBackends); + getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget); + activeMetadataVersion = newWindmillEndpoints.version(); + } + + /** Close the streams that are no longer valid asynchronously. */ + private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { + StreamingEngineBackends currentBackends = backends.get(); + currentBackends.windmillStreams().entrySet().stream() + .filter( + connectionAndStream -> + !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) + .forEach( + entry -> + windmillStreamManager.execute( + () -> closeStreamSender(entry.getKey(), entry.getValue()))); - LOG.info( - "Setting new connections: {}. 
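Worker metadata is now versioned: consumeWorkerMetadata records the highest version seen under metadataLock and hands the snapshot to a single-threaded executor, and consumeWindmillWorkerEndpoints re-checks the version so stale snapshots queued behind a newer one are simply dropped. The sketch below is a minimal standalone illustration of that latest-version-wins pattern; the Snapshot type and field names are placeholders, not the Beam classes.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Illustrative "latest version wins" consumer: snapshots may queue up behind a slow
// consumer, but only the newest pending version is actually applied.
final class VersionedConsumerSketch {
  static final class Snapshot {
    final long version;
    Snapshot(long version) { this.version = version; }
  }

  private final ExecutorService consumer = Executors.newSingleThreadExecutor();
  private final Object lock = new Object();
  private long pendingVersion = Long.MIN_VALUE;
  private long activeVersion = Long.MIN_VALUE; // mirrors activeMetadataVersion

  void offer(Snapshot snapshot) {
    synchronized (lock) {
      // Only schedule versions newer than anything already seen or scheduled.
      if (snapshot.version > pendingVersion) {
        pendingVersion = snapshot.version;
        consumer.execute(() -> apply(snapshot));
      }
    }
  }

  private void apply(Snapshot snapshot) {
    synchronized (lock) {
      // A newer snapshot was scheduled while this one waited in the queue; drop it.
      if (snapshot.version < pendingVersion) {
        return;
      }
    }
    // ... rebuild streams/backends for this snapshot here ...
    activeVersion = snapshot.version; // only touched by the single consumer thread
  }
}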
Previous connections: {}.", - newConnectionsState, - connections.get()); - connections.set(newConnectionsState); - isBudgetRefreshPaused.set(false); - getWorkBudgetRefresher.requestBudgetRefresh(); + Set newGlobalDataEndpoints = + new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); + currentBackends.globalDataStreams().values().stream() + .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) + .forEach( + sender -> + windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); + } + + private void closeStreamSender(Endpoint endpoint, Closeable sender) { + LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender); + try { + sender.close(); + endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove); + LOG.debug("Successfully closed streams to {}", endpoint); + } catch (Exception e) { + LOG.error("Error closing streams to endpoint={}, sender={}", endpoint, sender); + } + } + + private synchronized CompletableFuture> + createAndStartNewStreams(ImmutableSet newWindmillEndpoints) { + ImmutableMap currentStreams = backends.get().windmillStreams(); + return MoreFutures.allAsList( + newWindmillEndpoints.stream() + .map(endpoint -> getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams)) + .collect(Collectors.toList())) + .thenApply( + backends -> backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight))) + .toCompletableFuture(); + } + + private CompletionStage> + getOrCreateWindmillStreamSenderFuture( + Endpoint endpoint, ImmutableMap currentStreams) { + return MoreFutures.supplyAsync( + () -> + Pair.of( + endpoint, + Optional.ofNullable(currentStreams.get(endpoint)) + .orElseGet(() -> createAndStartWindmillStreamSender(endpoint))), + windmillStreamManager); } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. */ - public long getAndResetThrottleTimes() { - return connections.get().windmillStreams().values().stream() + public long getAndResetThrottleTime() { + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getAndResetThrottleTime) .reduce(0L, Long::sum) + getWorkerMetadataThrottleTimer.getAndResetThrottleTime(); } public long currentActiveCommitBytes() { - return connections.get().windmillStreams().values().stream() + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getCurrentActiveCommitBytes) .reduce(0L, Long::sum); } @VisibleForTesting - StreamingEngineConnectionState getCurrentConnections() { - return connections.get(); - } - - private synchronized ImmutableMap createNewWindmillConnections( - List newWindmillEndpoints) { - ImmutableMap currentConnections = - connections.get().windmillConnections(); - return newWindmillEndpoints.stream() - .collect( - toImmutableMap( - Function.identity(), - endpoint -> - // Reuse existing stubs if they exist. Optional.orElseGet only calls the - // supplier if the value is not present, preventing constructing expensive - // objects. - Optional.ofNullable(currentConnections.get(endpoint)) - .orElseGet( - () -> WindmillConnection.from(endpoint, this::createWindmillStub)))); + StreamingEngineBackends currentBackends() { + return backends.get(); } - private synchronized ImmutableMap - closeStaleStreamsAndCreateNewStreams(Collection newWindmillConnections) { - ImmutableMap currentStreams = - connections.get().windmillStreams(); - - // Close the streams that are no longer valid. 
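When a new endpoint set is applied, existing WindmillStreamSenders are reused and only missing endpoints get new senders, created in parallel on the windmillStreamManager pool and joined before the backends reference is swapped. The compact sketch below shows that reuse-or-create fan-out using plain CompletableFuture instead of Beam's MoreFutures; the generic Endpoint/Sender types are placeholders.

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;

// Illustrative reuse-or-create fan-out: existing senders are kept, missing ones are
// built in parallel, and the result is materialized into a fresh map before use.
final class ReuseOrCreateSketch {
  static <E, S> Map<E, S> rebuild(
      Set<E> newEndpoints, Map<E, S> current, Function<E, S> factory, ExecutorService pool) {
    List<CompletableFuture<Map.Entry<E, S>>> futures =
        newEndpoints.stream()
            .map(
                endpoint ->
                    CompletableFuture.supplyAsync(
                        () ->
                            Map.entry(
                                endpoint,
                                current.containsKey(endpoint)
                                    ? current.get(endpoint) // reuse the live sender
                                    : factory.apply(endpoint)), // create and start a new one
                        pool))
            .collect(Collectors.toList());
    // Wait for every sender to be ready before swapping in the new backend map.
    return futures.stream()
        .map(CompletableFuture::join)
        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
  }
}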
- currentStreams.entrySet().stream() - .filter( - connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey())) - .forEach( - entry -> { - entry.getValue().closeAllStreams(); - entry.getKey().directEndpoint().ifPresent(channelCachingStubFactory::remove); - }); - - return newWindmillConnections.stream() - .collect( - toImmutableMap( - Function.identity(), - newConnection -> - Optional.ofNullable(currentStreams.get(newConnection)) - .orElseGet(() -> createAndStartWindmillStreamSenderFor(newConnection)))); - } - - private ImmutableMap> createNewGlobalDataStreams( + private ImmutableMap createNewGlobalDataStreams( ImmutableMap newGlobalDataEndpoints) { - ImmutableMap> currentGlobalDataStreams = - connections.get().globalDataStreams(); + ImmutableMap currentGlobalDataStreams = + backends.get().globalDataStreams(); return newGlobalDataEndpoints.entrySet().stream() .collect( toImmutableMap( Entry::getKey, keyedEndpoint -> - existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams))); + getOrCreateGlobalDataSteam(keyedEndpoint, currentGlobalDataStreams))); } - private Supplier existingOrNewGetDataStreamFor( + private GlobalDataStreamSender getOrCreateGlobalDataSteam( Entry keyedEndpoint, - ImmutableMap> currentGlobalDataStreams) { - return Preconditions.checkNotNull( - currentGlobalDataStreams.getOrDefault( - keyedEndpoint.getKey(), + ImmutableMap currentGlobalDataStreams) { + return Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey())) + .orElseGet( () -> - streamFactory.createGetDataStream( - newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer()))); - } - - private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) { - return Optional.ofNullable(connections.get().windmillConnections().get(endpoint)) - .map(WindmillConnection::stub) - .orElseGet(() -> createWindmillStub(endpoint)); + new GlobalDataStreamSender( + () -> + streamFactory.createGetDataStream( + createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), + keyedEndpoint.getValue())); } - private WindmillStreamSender createAndStartWindmillStreamSenderFor( - WindmillConnection connection) { - // Initially create each stream with no budget. The budget will be eventually assigned by the - // GetWorkBudgetDistributor. + private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( - connection, + WindmillConnection.from(endpoint, this::createWindmillStub), GetWorkRequest.newBuilder() - .setClientId(clientId) + .setClientId(jobHeader.getClientId()) .setJobId(jobHeader.getJobId()) .setProjectId(jobHeader.getProjectId()) .setWorkerId(jobHeader.getWorkerId()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java new file mode 100644 index 000000000000..ce5f3a7b6bfc --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import java.io.Closeable; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; + +@Internal +@ThreadSafe +// TODO (m-trieu): replace Supplier with Stream after github.com/apache/beam/pull/32774/ is +// merged +final class GlobalDataStreamSender implements Closeable, Supplier { + private final Endpoint endpoint; + private final Supplier delegate; + private volatile boolean started; + + GlobalDataStreamSender(Supplier delegate, Endpoint endpoint) { + // Ensures that the Supplier is thread-safe + this.delegate = Suppliers.memoize(delegate::get); + this.started = false; + this.endpoint = endpoint; + } + + @Override + public GetDataStream get() { + if (!started) { + started = true; + } + + return delegate.get(); + } + + @Override + public void close() { + if (started) { + delegate.get().shutdown(); + } + } + + Endpoint endpoint() { + return endpoint; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java index bc93e6d89c41..06598b61c458 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java @@ -82,7 +82,7 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { this.waitForResources = waitForResources; this.computationStateFetcher = computationStateFetcher; this.workProviderExecutor = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MIN_PRIORITY) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java similarity index 55% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java index 
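GlobalDataStreamSender wraps its GetDataStream in a memoized Supplier so the stream is only opened on first use, and close() only shuts the delegate down if it was ever started. Below is a standalone sketch of that lazy, close-aware holder, using plain synchronization instead of Guava's Suppliers.memoize and a generic AutoCloseable in place of GetDataStream (which uses shutdown() rather than close()); it is an illustration, not the Beam class.

import java.io.Closeable;
import java.util.function.Supplier;

// Illustrative lazy stream holder: the expensive resource is created on first get(),
// and close() is a no-op unless the resource was actually started.
final class LazyStreamSenderSketch<T extends AutoCloseable> implements Closeable, Supplier<T> {
  private final Supplier<T> factory;
  private volatile T delegate; // created at most once
  private volatile boolean started;

  LazyStreamSenderSketch(Supplier<T> factory) {
    this.factory = factory;
  }

  @Override
  public synchronized T get() {
    if (delegate == null) {
      delegate = factory.get(); // opening the stream is the expensive part
      started = true;
    }
    return delegate;
  }

  @Override
  public void close() {
    if (started) {
      try {
        delegate.close();
      } catch (Exception e) {
        // Best-effort close, mirroring closeStreamSender in the diff: log and move on.
      }
    }
  }
}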
3c85ee6abe1f..14290b486830 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java @@ -18,47 +18,37 @@ package org.apache.beam.runners.dataflow.worker.streaming.harness; import com.google.auto.value.AutoValue; -import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; /** - * Represents the current state of connections to Streaming Engine. Connections are updated when - * backend workers assigned to the key ranges being processed by this user worker change during + * Represents the current state of connections to the Streaming Engine backend. Backends are updated + * when backend workers assigned to the key ranges being processed by this user worker change during * pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other * backend updates. */ @AutoValue -abstract class StreamingEngineConnectionState { - static final StreamingEngineConnectionState EMPTY = builder().build(); +abstract class StreamingEngineBackends { + static final StreamingEngineBackends EMPTY = builder().build(); static Builder builder() { - return new AutoValue_StreamingEngineConnectionState.Builder() - .setWindmillConnections(ImmutableMap.of()) + return new AutoValue_StreamingEngineBackends.Builder() .setWindmillStreams(ImmutableMap.of()) .setGlobalDataStreams(ImmutableMap.of()); } - abstract ImmutableMap windmillConnections(); - - abstract ImmutableMap windmillStreams(); + abstract ImmutableMap windmillStreams(); /** Mapping of GlobalDataIds and the direct GetDataStreams used fetch them. 
*/ - abstract ImmutableMap> globalDataStreams(); + abstract ImmutableMap globalDataStreams(); @AutoValue.Builder abstract static class Builder { - public abstract Builder setWindmillConnections( - ImmutableMap value); - - public abstract Builder setWindmillStreams( - ImmutableMap value); + public abstract Builder setWindmillStreams(ImmutableMap value); public abstract Builder setGlobalDataStreams( - ImmutableMap> value); + ImmutableMap value); - public abstract StreamingEngineConnectionState build(); + public abstract StreamingEngineBackends build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 45aa403ee71b..744c3d74445f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import java.io.Closeable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -49,7 +50,7 @@ * {@link GetWorkBudget} is set. * *

Once started, the underlying streams are "alive" until they are manually closed via {@link - * #closeAllStreams()}. + * #close()} ()}. * *

If closed, it means that the backend endpoint is no longer in the worker set. Once closed, * these instances are not reused. @@ -59,7 +60,7 @@ */ @Internal @ThreadSafe -final class WindmillStreamSender implements GetWorkBudgetSpender { +final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable { private final AtomicBoolean started; private final AtomicReference getWorkBudget; private final Supplier getWorkStream; @@ -103,9 +104,9 @@ private WindmillStreamSender( connection, withRequestBudget(getWorkRequest, getWorkBudget.get()), streamingEngineThrottleTimers.getWorkThrottleTimer(), - () -> FixedStreamHeartbeatSender.create(getDataStream.get()), - () -> getDataClientFactory.apply(getDataStream.get()), - workCommitter, + FixedStreamHeartbeatSender.create(getDataStream.get()), + getDataClientFactory.apply(getDataStream.get()), + workCommitter.get(), workItemScheduler)); } @@ -141,7 +142,8 @@ void startStreams() { started.set(true); } - void closeAllStreams() { + @Override + public void close() { // Supplier.get() starts the stream which is an expensive operation as it initiates the // streaming RPCs by possibly making calls over the network. Do not close the streams unless // they have already been started. @@ -154,18 +156,13 @@ void closeAllStreams() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta)); + public void setBudget(long items, long bytes) { + getWorkBudget.set(getWorkBudget.get().apply(items, bytes)); if (started.get()) { - getWorkStream.get().adjustBudget(itemsDelta, bytesDelta); + getWorkStream.get().setBudget(items, bytes); } } - @Override - public GetWorkBudget remainingBudget() { - return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get(); - } - long getAndResetThrottleTime() { return streamingEngineThrottleTimers.getAndResetThrottleTime(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index d7ed83def43e..eb269eef848f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.auto.value.AutoValue; import java.net.Inet6Address; @@ -27,8 +27,8 @@ import java.util.Map; import java.util.Optional; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -41,6 +41,14 @@ public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + public static WindmillEndpoints none() { + return WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + .setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); + } + public static WindmillEndpoints from( Windmill.WorkerMetadataResponse workerMetadataResponseProto) { ImmutableMap globalDataServers = @@ -53,14 +61,15 @@ public static WindmillEndpoints from( endpoint.getValue(), workerMetadataResponseProto.getExternalEndpoint()))); - ImmutableList windmillServers = + ImmutableSet windmillServers = workerMetadataResponseProto.getWorkEndpointsList().stream() .map( endpointProto -> Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint())) - .collect(toImmutableList()); + .collect(toImmutableSet()); return WindmillEndpoints.builder() + .setVersion(workerMetadataResponseProto.getMetadataVersion()) .setGlobalDataEndpoints(globalDataServers) .setWindmillEndpoints(windmillServers) .build(); @@ -123,6 +132,9 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); } + /** Version of the endpoints which increases with every modification. */ + public abstract long version(); + /** * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key * is a global data tag and the value is the endpoint where the data associated with the global @@ -138,7 +150,7 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( * Windmill servers. Returns a list of endpoints used to communicate with the corresponding * Windmill servers. */ - public abstract ImmutableList windmillEndpoints(); + public abstract ImmutableSet windmillEndpoints(); /** * Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with @@ -204,13 +216,15 @@ public abstract static class Builder { @AutoValue.Builder public abstract static class Builder { + public abstract Builder setVersion(long version); + public abstract Builder setGlobalDataEndpoints( ImmutableMap globalDataServers); public abstract Builder setWindmillEndpoints( - ImmutableList windmillServers); + ImmutableSet windmillServers); - abstract ImmutableList.Builder windmillEndpointsBuilder(); + abstract ImmutableSet.Builder windmillEndpointsBuilder(); public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) { windmillEndpointsBuilder().add(endpoint); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java index 90f93b072673..0b895652efe2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java @@ -19,38 +19,36 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; -import java.net.Inet6Address; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; /** Used to create channels to communicate with Streaming Engine via gRpc. 
*/ @AutoOneOf(WindmillServiceAddress.Kind.class) public abstract class WindmillServiceAddress { - public static WindmillServiceAddress create(Inet6Address ipv6Address) { - return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address); - } public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) { return AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress); } - public abstract Kind getKind(); + public static WindmillServiceAddress create( + AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { + return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( + authenticatedGcpServiceAddress); + } - public abstract Inet6Address ipv6(); + public abstract Kind getKind(); public abstract HostAndPort gcpServiceAddress(); public abstract AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress(); - public static WindmillServiceAddress create( - AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { - return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( - authenticatedGcpServiceAddress); + public final HostAndPort getServiceAddress() { + return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS + ? gcpServiceAddress() + : authenticatedGcpServiceAddress().gcpServiceAddress(); } public enum Kind { - IPV6, GCP_SERVICE_ADDRESS, - // TODO(m-trieu): Use for direct connections when ALTS is enabled. AUTHENTICATED_GCP_SERVICE_ADDRESS } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 31bd4e146a78..f26c56b14ec2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -56,10 +56,11 @@ public interface WindmillStream { @ThreadSafe interface GetWorkStream extends WindmillStream { /** Adjusts the {@link GetWorkBudget} for the stream. */ - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(GetWorkBudget newBudget); - /** Returns the remaining in-flight {@link GetWorkBudget}. */ - GetWorkBudget remainingBudget(); + default void setBudget(long newItems, long newBytes) { + setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); + } } /** Interface for streaming GetDataRequests to Windmill. 
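On the GetWorkStream interface, adjustBudget(itemsDelta, bytesDelta) and remainingBudget() give way to an absolute setBudget(GetWorkBudget), plus a default overload that builds the budget from raw items/bytes; callers now state the target budget and the stream tracks what is in flight itself. The tiny sketch below shows that interface shape with an illustrative Budget value class standing in for GetWorkBudget.

// Illustrative absolute-budget interface: callers set the desired budget outright
// instead of sending deltas, and a default overload keeps the primitive-args call site.
final class BudgetApiSketch {
  static final class Budget {
    final long items;
    final long bytes;
    Budget(long items, long bytes) {
      this.items = items;
      this.bytes = bytes;
    }
  }

  interface GetWorkStreamLike {
    void setBudget(Budget newBudget);

    // Convenience overload mirroring the default method added in the diff.
    default void setBudget(long newItems, long newBytes) {
      setBudget(new Budget(newItems, newBytes));
    }
  }

  public static void main(String[] args) {
    GetWorkStreamLike stream =
        budget ->
            System.out.println("budget set to " + budget.items + " items / " + budget.bytes + " bytes");
    stream.setBudget(100, 10 << 20); // absolute target, not a delta
  }
}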
*/ diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java similarity index 51% rename from sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java index f15fc5307374..498e90f78e29 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java @@ -15,27 +15,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.junit.Test; +/** Utility class for commits. */ +@Internal +public final class Commits { -public class NameGeneratorTest { - private static final int MAXIMUM_POSTGRES_TABLE_NAME_LENGTH = 63; + /** Max bytes of commits queued on the user worker. */ + @VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB - @Test - public void testGenerateMetadataTableNameRemovesHyphens() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id-12345"); - assertFalse(tableName.contains("-")); - } + private Commits() {} - @Test - public void testGenerateMetadataTableNameIsShorterThan64Characters() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id1-maximum-length"); - assertTrue(tableName.length() <= MAXIMUM_POSTGRES_TABLE_NAME_LENGTH); + public static WeightedSemaphore maxCommitByteSemaphore() { + return WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index d092ebf53fc1..20b95b0661d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -42,7 +42,6 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class); private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private final Consumer commitWorkFn; private final WeightedBoundedQueue commitQueue; @@ 
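Commits.maxCommitByteSemaphore() replaces the per-committer MAX_COMMIT_QUEUE_BYTES constant with an injectable WeightedSemaphore weighted by Commit::getSize, so a single 500MB cap can be shared across commit queues that are handed the same instance. The self-contained sketch below shows one shared byte budget bounding two queues, using a raw java.util.concurrent.Semaphore and a FakeCommit stand-in; it illustrates the idea rather than the Beam types.

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;

// Illustrative shared byte budget: two commit queues draw permits from one semaphore,
// so the 500MB-style cap bounds the total bytes queued across both of them.
final class SharedCommitBudgetSketch {
  static final class FakeCommit {
    final int sizeBytes;
    FakeCommit(int sizeBytes) { this.sizeBytes = sizeBytes; }
  }

  static final int MAX_QUEUED_COMMIT_BYTES = 500 << 20; // 500MB, matching the diff

  public static void main(String[] args) throws InterruptedException {
    Semaphore sharedBytes = new Semaphore(MAX_QUEUED_COMMIT_BYTES, /* fair= */ true);
    LinkedBlockingQueue<FakeCommit> engineQueue = new LinkedBlockingQueue<>();
    LinkedBlockingQueue<FakeCommit> applianceQueue = new LinkedBlockingQueue<>();

    FakeCommit first = new FakeCommit(64 << 20);  // 64MB commit
    FakeCommit second = new FakeCommit(32 << 20); // 32MB commit

    // Producers: acquire each commit's weight from the shared semaphore before enqueueing.
    sharedBytes.acquire(first.sizeBytes);
    engineQueue.put(first);
    sharedBytes.acquire(second.sizeBytes);
    applianceQueue.put(second);

    // The combined weight across both queues counts against the single cap.
    System.out.println("queued bytes: " + (MAX_QUEUED_COMMIT_BYTES - sharedBytes.availablePermits()));

    // Consumers: release the weight only after a commit leaves its queue.
    sharedBytes.release(engineQueue.take().sizeBytes);
    sharedBytes.release(applianceQueue.take().sizeBytes);
  }
}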
-53,11 +52,9 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private StreamingApplianceWorkCommitter( Consumer commitWorkFn, Consumer onCommitComplete) { this.commitWorkFn = commitWorkFn; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore()); this.commitWorkers = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MAX_PRIORITY) @@ -73,10 +70,9 @@ public static StreamingApplianceWorkCommitter create( } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { if (!commitWorkers.isShutdown()) { - commitWorkers.submit(this::commitLoop); + commitWorkers.execute(this::commitLoop); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index bf1007bc4bfb..85fa1d67c6c3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -28,6 +28,7 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; @@ -46,7 +47,6 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); private static final int TARGET_COMMIT_BATCH_KEYS = 5; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private static final String NO_BACKEND_WORKER_TOKEN = ""; private final Supplier> commitWorkStreamFactory; @@ -61,11 +61,10 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { Supplier> commitWorkStreamFactory, int numCommitSenders, Consumer onCommitComplete, - String backendWorkerToken) { + String backendWorkerToken, + WeightedSemaphore commitByteSemaphore) { this.commitWorkStreamFactory = commitWorkStreamFactory; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore); this.commitSenders = Executors.newFixedThreadPool( numCommitSenders, @@ -90,12 +89,11 @@ public static Builder builder() { } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { Preconditions.checkState( isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start()."); for (int i = 0; i < numCommitSenders; i++) { - commitSenders.submit(this::streamingCommitLoop); + commitSenders.execute(this::streamingCommitLoop); } } @@ -166,6 +164,8 @@ private void 
streamingCommitLoop() { return; } } + + // take() blocks until a value is available in the commitQueue. Preconditions.checkNotNull(initialCommit); if (initialCommit.work().isFailed()) { @@ -258,6 +258,8 @@ public interface Builder { Builder setCommitWorkStreamFactory( Supplier> commitWorkStreamFactory); + Builder setCommitByteSemaphore(WeightedSemaphore commitByteSemaphore); + Builder setNumCommitSenders(int numCommitSenders); Builder setOnCommitComplete(Consumer onCommitComplete); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 19de998b1da8..b27ebc8e9eee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -21,9 +21,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.concurrent.GuardedBy; +import net.jcip.annotations.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -44,8 +46,8 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of {@link GetWorkStream} that passes along a specific {@link @@ -55,9 +57,10 @@ * these direct streams are used to facilitate these RPC calls to specific backend workers. */ @Internal -public final class GrpcDirectGetWorkStream +final class GrpcDirectGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class); private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST = StreamingGetWorkRequest.newBuilder() .setRequestExtension( @@ -67,15 +70,14 @@ public final class GrpcDirectGetWorkStream .build()) .build(); - private final AtomicReference inFlightBudget; - private final AtomicReference nextBudgetAdjustment; - private final AtomicReference pendingResponseBudget; - private final GetWorkRequest request; + private final GetWorkBudgetTracker budgetTracker; + private final GetWorkRequest requestHeader; private final WorkItemScheduler workItemScheduler; private final ThrottleTimer getWorkThrottleTimer; - private final Supplier heartbeatSender; - private final Supplier workCommitter; - private final Supplier getDataClient; + private final HeartbeatSender heartbeatSender; + private final WorkCommitter workCommitter; + private final GetDataClient getDataClient; + private final AtomicReference lastRequest; /** * Map of stream IDs to their buffers. 
Used to aggregate streaming gRPC response chunks as they @@ -92,15 +94,15 @@ private GrpcDirectGetWorkStream( StreamObserver, StreamObserver> startGetWorkRpcFn, - GetWorkRequest request, + GetWorkRequest requestHeader, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { super( "GetWorkStream", @@ -110,19 +112,23 @@ private GrpcDirectGetWorkStream( streamRegistry, logEveryNStreamFailures, backendWorkerToken); - this.request = request; + this.requestHeader = requestHeader; this.getWorkThrottleTimer = getWorkThrottleTimer; this.workItemScheduler = workItemScheduler; this.workItemAssemblers = new ConcurrentHashMap<>(); - this.heartbeatSender = Suppliers.memoize(heartbeatSender::get); - this.workCommitter = Suppliers.memoize(workCommitter::get); - this.getDataClient = Suppliers.memoize(getDataClient::get); - this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); - this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); - this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + this.heartbeatSender = heartbeatSender; + this.workCommitter = workCommitter; + this.getDataClient = getDataClient; + this.lastRequest = new AtomicReference<>(); + this.budgetTracker = + new GetWorkBudgetTracker( + GetWorkBudget.builder() + .setItems(requestHeader.getMaxItems()) + .setBytes(requestHeader.getMaxBytes()) + .build()); } - public static GrpcDirectGetWorkStream create( + static GrpcDirectGetWorkStream create( String backendWorkerToken, Function< StreamObserver, @@ -134,9 +140,9 @@ public static GrpcDirectGetWorkStream create( Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { GrpcDirectGetWorkStream getWorkStream = new GrpcDirectGetWorkStream( @@ -165,46 +171,52 @@ private static Watermarks createWatermarks( .build(); } - private void sendRequestExtension(GetWorkBudget adjustment) { - inFlightBudget.getAndUpdate(budget -> budget.apply(adjustment)); - StreamingGetWorkRequest extension = - StreamingGetWorkRequest.newBuilder() - .setRequestExtension( - Windmill.StreamingGetWorkRequestExtension.newBuilder() - .setMaxItems(adjustment.items()) - .setMaxBytes(adjustment.bytes())) - .build(); - - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + /** + * @implNote Do not lock/synchronize here due to this running on grpc serial executor for message + * which can deadlock since we send on the stream beneath the synchronization. {@link + * AbstractWindmillStream#send(Object)} is synchronized so the sends are already guarded. 
+ */ + private void maybeSendRequestExtension(GetWorkBudget extension) { + if (extension.items() > 0 || extension.bytes() > 0) { + executeSafely( + () -> { + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(extension.items()) + .setMaxBytes(extension.bytes())) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(extension); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was closed. + } + }); + } } @Override protected synchronized void onNewStream() { workItemAssemblers.clear(); - // Add the current in-flight budget to the next adjustment. Only positive values are allowed - // here - // with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative values. - GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get()); - inFlightBudget.set(budgetAdjustment); - send( - StreamingGetWorkRequest.newBuilder() - .setRequest( - request - .toBuilder() - .setMaxBytes(budgetAdjustment.bytes()) - .setMaxItems(budgetAdjustment.items())) - .build()); - - // We just sent the budget, reset it. - nextBudgetAdjustment.set(GetWorkBudget.noBudget()); + if (!isShutdown()) { + budgetTracker.reset(); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequest( + requestHeader + .toBuilder() + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build()) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(initialGetWorkBudget); + send(request); + } } @Override @@ -216,8 +228,9 @@ protected boolean hasPendingRequests() { public void appendSpecificHtml(PrintWriter writer) { // Number of buffers is same as distinct workers that sent work on this stream. writer.format( - "GetWorkStream: %d buffers, %s inflight budget allowed.", - workItemAssemblers.size(), inFlightBudget.get()); + "GetWorkStream: %d buffers, " + "last sent request: %s; ", + workItemAssemblers.size(), lastRequest.get()); + writer.print(budgetTracker.debugString()); } @Override @@ -235,30 +248,22 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) { } private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { - // Record the fact that there are now fewer outstanding messages and bytes on the stream. 
- inFlightBudget.updateAndGet(budget -> budget.subtract(1, assembledWorkItem.bufferedSize())); WorkItem workItem = assembledWorkItem.workItem(); GetWorkResponseChunkAssembler.ComputationMetadata metadata = assembledWorkItem.computationMetadata(); - pendingResponseBudget.getAndUpdate(budget -> budget.apply(1, workItem.getSerializedSize())); - try { - workItemScheduler.scheduleWork( - workItem, - createWatermarks(workItem, Preconditions.checkNotNull(metadata)), - createProcessingContext(Preconditions.checkNotNull(metadata.computationId())), - assembledWorkItem.latencyAttributions()); - } finally { - pendingResponseBudget.getAndUpdate(budget -> budget.apply(-1, -workItem.getSerializedSize())); - } + workItemScheduler.scheduleWork( + workItem, + createWatermarks(workItem, metadata), + createProcessingContext(metadata.computationId()), + assembledWorkItem.latencyAttributions()); + budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize()); + GetWorkBudget extension = budgetTracker.computeBudgetExtension(); + maybeSendRequestExtension(extension); } private Work.ProcessingContext createProcessingContext(String computationId) { return Work.createProcessingContext( - computationId, - getDataClient.get(), - workCommitter.get()::commit, - heartbeatSender.get(), - backendWorkerToken()); + computationId, getDataClient, workCommitter::commit, heartbeatSender, backendWorkerToken()); } @Override @@ -267,25 +272,110 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - GetWorkBudget adjustment = - nextBudgetAdjustment - // Get the current value, and reset the nextBudgetAdjustment. This will be set again - // when adjustBudget is called. - .getAndUpdate(unused -> GetWorkBudget.noBudget()) - .apply(itemsDelta, bytesDelta); - sendRequestExtension(adjustment); + public void setBudget(GetWorkBudget newBudget) { + GetWorkBudget extension = budgetTracker.consumeAndComputeBudgetUpdate(newBudget); + maybeSendRequestExtension(extension); } - @Override - public GetWorkBudget remainingBudget() { - // Snapshot the current budgets. - GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); - GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); - GetWorkBudget currentInflightBudget = inFlightBudget.get(); - - return currentPendingResponseBudget - .apply(currentNextBudgetAdjustment) - .apply(currentInflightBudget); + private void executeSafely(Runnable runnable) { + try { + executor().execute(runnable); + } catch (RejectedExecutionException e) { + LOG.debug("{} has been shutdown.", getClass()); + } + } + + /** + * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request + * extensions. 
+ */ + @ThreadSafe + private static final class GetWorkBudgetTracker { + + @GuardedBy("GetWorkBudgetTracker.this") + private GetWorkBudget maxGetWorkBudget; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsReceived = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesReceived = 0; + + private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) { + this.maxGetWorkBudget = maxGetWorkBudget; + } + + private synchronized void reset() { + itemsRequested = 0; + bytesRequested = 0; + itemsReceived = 0; + bytesReceived = 0; + } + + private synchronized String debugString() { + return String.format( + "max budget: %s; " + + "in-flight budget: %s; " + + "total budget requested: %s; " + + "total budget received: %s.", + maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget()); + } + + /** Consumes the new budget and computes an extension based on the new budget. */ + private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) { + maxGetWorkBudget = newBudget; + return computeBudgetExtension(); + } + + private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) { + itemsRequested += budgetRequested.items(); + bytesRequested += budgetRequested.bytes(); + } + + private synchronized void recordBudgetReceived(long returnedBudget) { + itemsReceived++; + bytesReceived += returnedBudget; + } + + /** + * If the outstanding items or bytes limit has gotten too low, top both off with a + * GetWorkExtension. The goal is to keep the limits relatively close to their maximum values + * without sending too many extension requests. + */ + private synchronized GetWorkBudget computeBudgetExtension() { + // Expected items and bytes can go negative here, since WorkItems returned might be larger + // than the initially requested budget. + long inFlightItems = itemsRequested - itemsReceived; + long inFlightBytes = bytesRequested - bytesReceived; + + // Don't send negative budget extensions. + long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes); + long requestItems = Math.max(0, maxGetWorkBudget.items() - inFlightItems); + + return (inFlightItems > requestItems / 2 && inFlightBytes > requestBytes / 2) + ? 
GetWorkBudget.noBudget() + : GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build(); + } + + private synchronized GetWorkBudget inFlightBudget() { + return GetWorkBudget.builder() + .setItems(itemsRequested - itemsReceived) + .setBytes(bytesRequested - bytesReceived) + .build(); + } + + private synchronized GetWorkBudget totalRequestedBudget() { + return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build(); + } + + private synchronized GetWorkBudget totalReceivedBudget() { + return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index f96464150d4a..6bae84483d16 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -30,7 +30,6 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; @@ -53,8 +52,6 @@ public class GrpcDispatcherClient { private static final Logger LOG = LoggerFactory.getLogger(GrpcDispatcherClient.class); - static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS = - "streaming_engine_use_job_settings_for_isolated_channels"; private final CountDownLatch onInitializedEndpoints; /** @@ -80,18 +77,12 @@ private GrpcDispatcherClient( DispatcherStubs initialDispatcherStubs, Random rand) { this.windmillStubFactoryFactory = windmillStubFactoryFactory; - if (DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)) { - if (options.getUseWindmillIsolatedChannels() != null) { - this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); - this.reactToIsolatedChannelsJobSetting = false; - } else { - this.useIsolatedChannels.set(false); - this.reactToIsolatedChannelsJobSetting = true; - } - } else { - this.useIsolatedChannels.set(Boolean.TRUE.equals(options.getUseWindmillIsolatedChannels())); + if (options.getUseWindmillIsolatedChannels() != null) { + this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); this.reactToIsolatedChannelsJobSetting = false; + } else { + this.useIsolatedChannels.set(false); + this.reactToIsolatedChannelsJobSetting = true; } this.windmillStubFactory.set( windmillStubFactoryFactory.makeWindmillStubFactory(useIsolatedChannels.get())); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 0e9a0c6316ee..c99e05a77074 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -59,7 +59,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcGetDataStream +final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 09ecbf3f3051..a368f3fec235 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -194,15 +194,7 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op } - - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setBytes(request.getMaxBytes() - inflightBytes.get()) - .setItems(request.getMaxItems() - inflightMessages.get()) - .build(); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 92f031db9972..9e6a02d135e2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -198,9 +198,9 @@ public GetWorkStream createDirectGetWorkStream( WindmillConnection connection, GetWorkRequest request, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 9aec29a3ba4d..f0ea2f550a74 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -36,7 +36,6 @@ /** Utility class used to create different RPC Channels. 
*/ public final class WindmillChannelFactory { public static final String LOCALHOST = "localhost"; - private static final int DEFAULT_GRPC_PORT = 443; private static final int MAX_REMOTE_TRACE_EVENTS = 100; private WindmillChannelFactory() {} @@ -55,8 +54,6 @@ public static Channel localhostChannel(int port) { public static ManagedChannel remoteChannel( WindmillServiceAddress windmillServiceAddress, int windmillServiceRpcChannelTimeoutSec) { switch (windmillServiceAddress.getKind()) { - case IPV6: - return remoteChannel(windmillServiceAddress.ipv6(), windmillServiceRpcChannelTimeoutSec); case GCP_SERVICE_ADDRESS: return remoteChannel( windmillServiceAddress.gcpServiceAddress(), windmillServiceRpcChannelTimeoutSec); @@ -67,7 +64,8 @@ public static ManagedChannel remoteChannel( windmillServiceRpcChannelTimeoutSec); default: throw new UnsupportedOperationException( - "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported WindmillServiceAddresses."); + "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS are supported" + + " WindmillServiceAddresses."); } } @@ -105,17 +103,6 @@ public static Channel remoteChannel( } } - public static ManagedChannel remoteChannel( - Inet6Address directEndpoint, int windmillServiceRpcChannelTimeoutSec) { - try { - return createRemoteChannel( - NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint, DEFAULT_GRPC_PORT)), - windmillServiceRpcChannelTimeoutSec); - } catch (SSLException sslException) { - throw new WindmillChannelCreationException(directEndpoint.toString(), sslException); - } - } - @SuppressWarnings("nullness") private static ManagedChannel createRemoteChannel( NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java index aed03f33e6d6..b17631a8bd0a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java @@ -137,12 +137,7 @@ protected Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForK keyCoder.encode(key, keyStream, Coder.Context.OUTER); ByteString keyBytes = keyStream.toByteString(); // Leaving data blank means that we delete the tag. 
- commitBuilder - .addValueUpdatesBuilder() - .setTag(keyBytes) - .setStateFamily(stateFamily) - .getValueBuilder() - .setTimestamp(Long.MAX_VALUE); + commitBuilder.addValueUpdatesBuilder().setTag(keyBytes).setStateFamily(stateFamily); V cachedValue = cachedValues.remove(key); if (cachedValue != null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java index 403bb99efb4c..8a1ba2556cf2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -17,18 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide; import java.math.RoundingMode; -import java.util.Map; -import java.util.Map.Entry; -import java.util.function.Function; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,22 +29,11 @@ @Internal final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor { private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class); - private final Supplier activeWorkBudgetSupplier; - - EvenGetWorkBudgetDistributor(Supplier activeWorkBudgetSupplier) { - this.activeWorkBudgetSupplier = activeWorkBudgetSupplier; - } - - private static boolean isBelowFiftyPercentOfTarget( - GetWorkBudget remaining, GetWorkBudget target) { - return remaining.items() < roundToLong(target.items() * 0.5, RoundingMode.CEILING) - || remaining.bytes() < roundToLong(target.bytes() * 0.5, RoundingMode.CEILING); - } @Override public void distributeBudget( - ImmutableCollection budgetOwners, GetWorkBudget getWorkBudget) { - if (budgetOwners.isEmpty()) { + ImmutableCollection budgetSpenders, GetWorkBudget getWorkBudget) { + if (budgetSpenders.isEmpty()) { LOG.debug("Cannot distribute budget to no owners."); return; } @@ -61,38 +43,15 @@ public void distributeBudget( return; } - Map desiredBudgets = computeDesiredBudgets(budgetOwners, getWorkBudget); - - for (Entry streamAndDesiredBudget : desiredBudgets.entrySet()) { - GetWorkBudgetSpender getWorkBudgetSpender = streamAndDesiredBudget.getKey(); - GetWorkBudget desired = streamAndDesiredBudget.getValue(); - GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget(); - if (isBelowFiftyPercentOfTarget(remaining, desired)) { - GetWorkBudget adjustment = desired.subtract(remaining); - getWorkBudgetSpender.adjustBudget(adjustment); - } - } + GetWorkBudget budgetPerStream = computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget); + budgetSpenders.forEach(getWorkBudgetSpender -> 
getWorkBudgetSpender.setBudget(budgetPerStream)); } - private ImmutableMap computeDesiredBudgets( + private GetWorkBudget computeDesiredPerStreamBudget( ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { - GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get(); - LOG.info("Current active work budget: {}", activeWorkBudget); - // TODO: Fix possibly non-deterministic handing out of budgets. - // Rounding up here will drift upwards over the lifetime of the streams. - GetWorkBudget budgetPerStream = - GetWorkBudget.builder() - .setItems( - divide( - totalGetWorkBudget.items() - activeWorkBudget.items(), - streams.size(), - RoundingMode.CEILING)) - .setBytes( - divide( - totalGetWorkBudget.bytes() - activeWorkBudget.bytes(), - streams.size(), - RoundingMode.CEILING)) - .build(); - return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream)); + return GetWorkBudget.builder() + .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING)) + .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING)) + .build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java index 43c0d46139da..2013c9ff1cb7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java @@ -17,13 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; @Internal public final class GetWorkBudgetDistributors { - public static GetWorkBudgetDistributor distributeEvenly( - Supplier activeWorkBudgetSupplier) { - return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier); + public static GetWorkBudgetDistributor distributeEvenly() { + return new EvenGetWorkBudgetDistributor(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java index e39aa8dbc8a5..d81c7d0593f3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java @@ -51,7 +51,7 @@ public GetWorkBudgetRefresher( Supplier isBudgetRefreshPaused, Runnable redistributeBudget) { this.budgetRefreshTrigger = new AdvancingPhaser(1); this.budgetRefreshExecutor = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setNameFormat(BUDGET_REFRESH_THREAD) .setUncaughtExceptionHandler( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java index 254b2589062e..decf101a641b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java @@ -22,11 +22,9 @@ * org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget} */ public interface GetWorkBudgetSpender { - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(long items, long bytes); - default void adjustBudget(GetWorkBudget adjustment) { - adjustBudget(adjustment.items(), adjustment.bytes()); + default void setBudget(GetWorkBudget budget) { + setBudget(budget.items(), budget.bytes()); } - - GetWorkBudget remainingBudget(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 9a3e6eb6b099..c74874c465a6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -70,7 +70,7 @@ */ @Internal @ThreadSafe -public final class StreamingWorkScheduler { +public class StreamingWorkScheduler { private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); private final DataflowWorkerHarnessOptions options; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index b3f7467cdbd3..90ffb3d3fbcf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -245,18 +245,10 @@ public void halfClose() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op. 
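  // The fake stream neither tracks nor enforces a budget, so the new value is simply dropped.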
} - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setItems(request.getMaxItems()) - .setBytes(request.getMaxBytes()) - .build(); - } - @Override public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { while (done.getCount() > 0) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java index 2d5a8d8266ae..37c5ad261280 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java @@ -366,7 +366,6 @@ public void testExtractPerWorkerMetricUpdates_populatedMetrics() { .setMetricsNamespace("BigQuerySink") .setMetricValues(Collections.singletonList(expectedCounter)); - // Expected histogram metric List bucketCounts = Collections.singletonList(1L); Linear linearOptions = new Linear().setNumberOfBuckets(10).setWidth(10.0).setStart(0.0); @@ -393,6 +392,44 @@ public void testExtractPerWorkerMetricUpdates_populatedMetrics() { assertThat(updates, containsInAnyOrder(histograms, counters)); } + @Test + public void testExtractPerWorkerMetricUpdatesKafka_populatedMetrics() { + StreamingStepMetricsContainer.setEnablePerWorkerMetrics(true); + + MetricName histogramMetricName = MetricName.named("KafkaSink", "histogram"); + HistogramData.LinearBuckets linearBuckets = HistogramData.LinearBuckets.of(0, 10, 10); + c2.getPerWorkerHistogram(histogramMetricName, linearBuckets).update(5.0); + + Iterable updates = + StreamingStepMetricsContainer.extractPerWorkerMetricUpdates(registry); + + // Expected histogram metric + List bucketCounts = Collections.singletonList(1L); + + Linear linearOptions = new Linear().setNumberOfBuckets(10).setWidth(10.0).setStart(0.0); + BucketOptions bucketOptions = new BucketOptions().setLinear(linearOptions); + + DataflowHistogramValue linearHistogram = + new DataflowHistogramValue() + .setCount(1L) + .setBucketOptions(bucketOptions) + .setBucketCounts(bucketCounts); + + MetricValue expectedHistogram = + new MetricValue() + .setMetric("histogram") + .setMetricLabels(new HashMap<>()) + .setValueHistogram(linearHistogram); + + PerStepNamespaceMetrics histograms = + new PerStepNamespaceMetrics() + .setOriginalStep("s2") + .setMetricsNamespace("KafkaSink") + .setMetricValues(Collections.singletonList(expectedHistogram)); + + assertThat(updates, containsInAnyOrder(histograms)); + } + @Test public void testExtractPerWorkerMetricUpdates_emptyMetrics() { StreamingStepMetricsContainer.setEnablePerWorkerMetrics(true); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java index ff114ef2f078..c1e5000f03da 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker; import static 
org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.theInstance; @@ -153,6 +154,21 @@ private static class TestStatefulDoFn extends DoFn, Void> { public void processElement(ProcessContext c) {} } + private static class TestStatefulDoFnWithWindowExpiration + extends DoFn, Void> { + + public static final String STATE_ID = "state-id"; + + @StateId(STATE_ID) + private final StateSpec> spec = StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void processElement(ProcessContext c) {} + + @OnWindowExpiration + public void onWindowExpiration() {} + } + private static final TupleTag MAIN_OUTPUT = new TupleTag<>("1"); private UserParDoFnFactory factory = UserParDoFnFactory.createDefault(); @@ -373,6 +389,92 @@ public void testCleanupRegistered() throws Exception { firstWindow.maxTimestamp().plus(Duration.millis(1L))); } + /** + * Regression test for global window + OnWindowExpiration + allowed lateness > max allowed time + */ + @Test + public void testCleanupTimerForGlobalWindowWithAllowedLateness() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + CounterSet counters = new CounterSet(); + DoFn initialFn = new TestStatefulDoFnWithWindowExpiration(); + Duration allowedLateness = Duration.standardDays(2); + CloudObject cloudObject = + getCloudObject( + initialFn, WindowingStrategy.globalDefault().withAllowedLateness(allowedLateness)); + + StateInternals stateInternals = InMemoryStateInternals.forKey("dummy"); + + TimerInternals timerInternals = mock(TimerInternals.class); + + DataflowStepContext stepContext = mock(DataflowStepContext.class); + when(stepContext.timerInternals()).thenReturn(timerInternals); + DataflowStepContext userStepContext = mock(DataflowStepContext.class); + when(stepContext.namespacedToUser()).thenReturn(userStepContext); + when(stepContext.stateInternals()).thenReturn(stateInternals); + when(userStepContext.stateInternals()).thenReturn((StateInternals) stateInternals); + + DataflowExecutionContext executionContext = + mock(DataflowExecutionContext.class); + TestOperationContext operationContext = TestOperationContext.create(counters); + when(executionContext.getStepContext(operationContext)).thenReturn(stepContext); + when(executionContext.getSideInputReader(any(), any(), any())) + .thenReturn(NullSideInputReader.empty()); + + ParDoFn parDoFn = + factory.create( + options, + cloudObject, + Collections.emptyList(), + MAIN_OUTPUT, + ImmutableMap.of(MAIN_OUTPUT, 0), + executionContext, + operationContext); + + Receiver rcvr = new OutputReceiver(); + parDoFn.startBundle(rcvr); + + GlobalWindow globalWindow = GlobalWindow.INSTANCE; + parDoFn.processElement( + WindowedValue.of("foo", new Instant(1), globalWindow, PaneInfo.NO_FIRING)); + + assertThat( + globalWindow.maxTimestamp().plus(allowedLateness), + greaterThan(BoundedWindow.TIMESTAMP_MAX_VALUE)); + verify(stepContext) + .setStateCleanupTimer( + SimpleParDoFn.CLEANUP_TIMER_ID, + globalWindow, + GlobalWindow.Coder.INSTANCE, + BoundedWindow.TIMESTAMP_MAX_VALUE, + BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.millis(1))); + + StateNamespace globalWindowNamespace = + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, globalWindow); + StateTag> tag = + StateTags.tagForSpec( + TestStatefulDoFnWithWindowExpiration.STATE_ID, StateSpecs.value(StringUtf8Coder.of())); + + when(userStepContext.getNextFiredTimer((Coder) 
GlobalWindow.Coder.INSTANCE)).thenReturn(null); + when(stepContext.getNextFiredTimer((Coder) GlobalWindow.Coder.INSTANCE)) + .thenReturn( + TimerData.of( + SimpleParDoFn.CLEANUP_TIMER_ID, + globalWindowNamespace, + BoundedWindow.TIMESTAMP_MAX_VALUE, + BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.millis(1)), + TimeDomain.EVENT_TIME)) + .thenReturn(null); + + // Set up non-empty state. We don't mock + verify calls to clear() but instead + check that state is actually empty. We don't care how it is accomplished. + stateInternals.state(globalWindowNamespace, tag).write("first"); + + // And this should clean up the state for the global window + parDoFn.processTimers(); + + assertThat(stateInternals.state(globalWindowNamespace, tag).read(), nullValue()); + } + @Test public void testCleanupWorks() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java index 4f035c88774c..c71001fbeee7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -30,27 +31,29 @@ @RunWith(JUnit4.class) public class WeightBoundedQueueTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final int MAX_WEIGHT = 10; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Test public void testPut_hasCapacity() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue = 1; queue.put(insertedValue); - assertEquals(insertedValue, queue.queuedElementsWeight()); + assertEquals(insertedValue, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); assertEquals(insertedValue, (int) queue.poll()); } @Test public void testPut_noCapacity() throws InterruptedException { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); // Insert value that takes all the capacity into the queue. queue.put(MAX_WEIGHT); @@ -71,7 +74,7 @@ public void testPut_noCapacity() throws InterruptedException { // Should only see the first value in the queue, since the queue is at capacity. thread2 // should be blocked. - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); // Poll the queue, pulling off the only value inside and freeing up the capacity in the queue.
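These queue tests exercise a bounded queue whose capacity is enforced by a weight-based semaphore that can be shared across several queues, so a put() blocks whenever the shared weight is exhausted, even by another queue. The sketch below is a minimal, self-contained illustration of that idea; SimpleWeightedSemaphore and SimpleWeightedQueue are hypothetical stand-ins, not the WeightedSemaphore and WeightedBoundedQueue classes under test.

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.function.ToIntFunction;

// Hypothetical sketch of a shared, weight-based gate (not the Beam implementation).
final class SimpleWeightedSemaphore<T> {
  private final int maxWeight;
  private final Semaphore permits;
  private final ToIntFunction<T> weigher;

  SimpleWeightedSemaphore(int maxWeight, ToIntFunction<T> weigher) {
    this.maxWeight = maxWeight;
    this.permits = new Semaphore(maxWeight);
    this.weigher = weigher;
  }

  // Blocks until the element's (clamped) weight fits under maxWeight.
  void acquire(T value) throws InterruptedException {
    permits.acquire(weightOf(value));
  }

  void release(T value) {
    permits.release(weightOf(value));
  }

  int currentWeight() {
    return maxWeight - permits.availablePermits();
  }

  private int weightOf(T value) {
    return Math.min(maxWeight, weigher.applyAsInt(value));
  }
}

// Hypothetical queue that charges its elements against a (possibly shared) weighted semaphore.
final class SimpleWeightedQueue<T> {
  private final Queue<T> elements = new ConcurrentLinkedQueue<>();
  private final SimpleWeightedSemaphore<T> semaphore;

  SimpleWeightedQueue(SimpleWeightedSemaphore<T> semaphore) {
    this.semaphore = semaphore;
  }

  // put() blocks while the shared semaphore is out of weight, which is the behavior
  // testPut_noCapacity and testPut_sharedWeigher rely on.
  void put(T value) throws InterruptedException {
    semaphore.acquire(value);
    elements.add(value);
  }

  // poll() releases the element's weight so a blocked put() can proceed.
  T poll() {
    T value = elements.poll();
    if (value != null) {
      semaphore.release(value);
    }
    return value;
  }
}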
@@ -80,14 +83,15 @@ public void testPut_noCapacity() throws InterruptedException { // Wait for the putThread which was previously blocked due to the queue being at capacity. putThread.join(); - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); } @Test public void testPoll() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue1 = 1; int insertedValue2 = 2; @@ -95,7 +99,7 @@ public void testPoll() { queue.put(insertedValue1); queue.put(insertedValue2); - assertEquals(insertedValue1 + insertedValue2, queue.queuedElementsWeight()); + assertEquals(insertedValue1 + insertedValue2, weightedSemaphore.currentWeight()); assertEquals(2, queue.size()); assertEquals(insertedValue1, (int) queue.poll()); assertEquals(1, queue.size()); @@ -104,7 +108,8 @@ public void testPoll() { @Test public void testPoll_withTimeout() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int pollWaitTimeMillis = 10000; int insertedValue1 = 1; @@ -132,7 +137,8 @@ public void testPoll_withTimeout() throws InterruptedException { @Test public void testPoll_withTimeout_timesOut() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int defaultPollResult = -10; int pollWaitTimeMillis = 100; int insertedValue1 = 1; @@ -144,13 +150,17 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { Thread pollThread = new Thread( () -> { - int polled; + @Nullable Integer polled; try { polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS); - pollResult.set(polled); + if (polled != null) { + pollResult.set(polled); + } } catch (InterruptedException e) { throw new RuntimeException(e); } + + assertNull(polled); }); pollThread.start(); @@ -164,7 +174,8 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { @Test public void testPoll_emptyQueue() { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); assertNull(queue.poll()); } @@ -172,7 +183,8 @@ public void testPoll_emptyQueue() { @Test public void testTake() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); AtomicInteger value = new AtomicInteger(); // Should block until value is available @@ -194,4 +206,39 @@ public void testTake() throws InterruptedException { assertEquals(MAX_WEIGHT, value.get()); } + + @Test + public void testPut_sharedWeigher() throws InterruptedException { + WeightedSemaphore weigher = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue1 = WeightedBoundedQueue.create(weigher); + WeightedBoundedQueue queue2 = 
WeightedBoundedQueue.create(weigher); + + // Insert value that takes all the weight into the queue1. + queue1.put(MAX_WEIGHT); + + // Try to insert a value into the queue2. This will block since there is no capacity in the + // weigher. + Thread putThread = new Thread(() -> queue2.put(MAX_WEIGHT)); + putThread.start(); + // Should only see the first value in the queue, since the queue is at capacity. putThread + // should be blocked. The weight should be the same however, since queue1 and queue2 are sharing + // the weigher. + Thread.sleep(100); + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue1.size()); + assertEquals(0, queue2.size()); + + // Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher. + queue1.poll(); + + // Wait for the putThread which was previously blocked due to the weigher being at capacity. + putThread.join(); + + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue2.size()); + queue2.poll(); + assertEquals(0, queue2.size()); + assertEquals(0, weigher.currentWeight()); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java index 9fa17588c94d..3a0ae7bb2084 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java @@ -47,7 +47,6 @@ @RunWith(JUnit4.class) public class StreamingEngineComputationConfigFetcherTest { - private final WorkUnitClient mockDataflowServiceClient = mock(WorkUnitClient.class, new Returns(Optional.empty())); private StreamingEngineComputationConfigFetcher streamingEngineConfigFetcher; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index ed8815c48e76..0092fcc7bcd1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -30,9 +30,7 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashSet; -import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; @@ -46,7 +44,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; @@ -71,7 +68,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.junit.After; import org.junit.Before; @@ -92,7 +88,6 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString()) .build()); - private static final long CLIENT_ID = 1L; private static final String JOB_ID = "jobId"; private static final String PROJECT_ID = "projectId"; private static final String WORKER_ID = "workerId"; @@ -101,6 +96,7 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) + .setClientId(1L) .build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @@ -134,7 +130,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) - .setClientId(CLIENT_ID) + .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) .build(); @@ -174,7 +170,7 @@ public void cleanUp() { stubFactory.shutdown(); } - private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( + private FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( GetWorkBudget getWorkBudget, GetWorkBudgetDistributor getWorkBudgetDistributor, WorkItemScheduler workItemScheduler) { @@ -186,7 +182,6 @@ private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( stubFactory, getWorkBudgetDistributor, dispatcherClient, - CLIENT_ID, ignored -> mock(WorkCommitter.class), new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); } @@ -201,7 +196,7 @@ public void testStreamsStartCorrectly() throws InterruptedException { spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(items).setBytes(bytes).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -219,16 +214,14 @@ public void testStreamsStartCorrectly() throws InterruptedException { getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); - assertEquals(2, currentConnections.windmillConnections().size()); - assertEquals(2, currentConnections.windmillStreams().size()); + assertEquals(2, currentBackends.windmillStreams().size()); Set workerTokens = - currentConnections.windmillConnections().values().stream() - .map(WindmillConnection::backendWorkerToken) + currentBackends.windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) 
.collect(Collectors.toSet()); assertTrue(workerTokens.contains(workerToken)); @@ -252,27 +245,6 @@ public void testStreamsStartCorrectly() throws InterruptedException { verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); } - @Test - public void testScheduledBudgetRefresh() throws InterruptedException { - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(2)); - fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( - GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), - getWorkBudgetDistributor, - noOpProcessWorkItemFn()); - - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata( - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(1) - .addWorkEndpoints(metadataResponseEndpoint("workerToken")) - .putAllGlobalDataEndpoints(DEFAULT) - .build()); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any()); - } - @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { @@ -280,7 +252,7 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor(metadataCount)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -309,32 +281,28 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() WorkerMetadataResponse.Endpoint.newBuilder() .setBackendWorkerToken(workerToken3) .build()) - .putAllGlobalDataEndpoints(DEFAULT) .build(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); - assertEquals(1, currentConnections.windmillConnections().size()); - assertEquals(1, currentConnections.windmillStreams().size()); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); + assertEquals(1, currentBackends.windmillStreams().size()); Set workerTokens = - fanOutStreamingEngineWorkProvider.getCurrentConnections().windmillConnections().values() - .stream() - .map(WindmillConnection::backendWorkerToken) + fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertFalse(workerTokens.contains(workerToken)); assertFalse(workerTokens.contains(workerToken2)); + assertTrue(currentBackends.globalDataStreams().isEmpty()); } @Test public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - String workerToken3 = "workerToken3"; WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() @@ -354,42 +322,24 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); - WorkerMetadataResponse thirdWorkerMetadata = - 
WorkerMetadataResponse.newBuilder() - .setMetadataVersion(3) - .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder() - .setBackendWorkerToken(workerToken3) - .build()) - .putAllGlobalDataEndpoints(DEFAULT) - .build(); - - List workerMetadataResponses = - Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, thirdWorkerMetadata); TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size())); + spy(new TestGetWorkBudgetDistributor(1)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); getWorkerMetadataReady.await(); - // Make sure we are injecting the metadata from smallest to largest. - workerMetadataResponses.stream() - .sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion)) - .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata); - - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size())) - .distributeBudget(any(), any()); - } + fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + getWorkBudgetDistributor.expectNumDistributions(1); + fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - private void waitForWorkerMetadataToBeConsumed( - TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException { - getWorkBudgetDistributor.waitForBudgetDistribution(); + verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } private static class GetWorkerMetadataTestStub @@ -434,21 +384,24 @@ private void injectWorkerMetadata(WorkerMetadataResponse response) { } private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { - private final CountDownLatch getWorkBudgetDistributorTriggered; + private CountDownLatch getWorkBudgetDistributorTriggered; private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } - @SuppressWarnings("ReturnValueIgnored") - private void waitForBudgetDistribution() throws InterruptedException { - getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + private boolean waitForBudgetDistribution() throws InterruptedException { + return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + } + + private void expectNumDistributions(int numBudgetDistributionsExpected) { + this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } @Override public void distributeBudget( ImmutableCollection streams, GetWorkBudget getWorkBudget) { - streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes())); + streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes())); getWorkBudgetDistributorTriggered.countDown(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java new file mode 100644 index 000000000000..5a2df4baae61 --- /dev/null 
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.WorkerUncaughtExceptionHandler; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.util.common.worker.JvmRuntime; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class SingleSourceWorkerHarnessTest { + private static final Logger LOG = LoggerFactory.getLogger(SingleSourceWorkerHarnessTest.class); + private final WorkCommitter workCommitter = mock(WorkCommitter.class); + private final GetDataClient getDataClient = mock(GetDataClient.class); + private final HeartbeatSender heartbeatSender = mock(HeartbeatSender.class); + private final Runnable waitForResources = () -> {}; + private final Function> computationStateFetcher = + ignored -> Optional.empty(); + private final StreamingWorkScheduler streamingWorkScheduler = mock(StreamingWorkScheduler.class); + + private SingleSourceWorkerHarness createWorkerHarness( + SingleSourceWorkerHarness.GetWorkSender getWorkSender, JvmRuntime runtime) { + // In non-test scenario this is set in DataflowWorkerHarnessHelper.initializeLogging(...). 
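    // Installing the handler here lets these tests observe, through FakeJvmRuntime below, that
    // an uncaught failure in the harness dispatch loop halts the JVM.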
+ Thread.setDefaultUncaughtExceptionHandler(new WorkerUncaughtExceptionHandler(runtime, LOG)); + return SingleSourceWorkerHarness.builder() + .setWorkCommitter(workCommitter) + .setGetDataClient(getDataClient) + .setHeartbeatSender(heartbeatSender) + .setWaitForResources(waitForResources) + .setStreamingWorkScheduler(streamingWorkScheduler) + .setComputationStateFetcher(computationStateFetcher) + .setGetWorkSender(getWorkSender) + .build(); + } + + @Test + public void testDispatchLoop_unexpectedFailureKillsJvm_appliance() { + SingleSourceWorkerHarness.GetWorkSender getWorkSender = + SingleSourceWorkerHarness.GetWorkSender.forAppliance( + () -> { + throw new RuntimeException("something bad happened"); + }); + + FakeJvmRuntime fakeJvmRuntime = new FakeJvmRuntime(); + createWorkerHarness(getWorkSender, fakeJvmRuntime).start(); + assertTrue(fakeJvmRuntime.waitForRuntimeDeath(5, TimeUnit.SECONDS)); + fakeJvmRuntime.assertJvmTerminated(); + } + + @Test + public void testDispatchLoop_unexpectedFailureKillsJvm_streamingEngine() { + SingleSourceWorkerHarness.GetWorkSender getWorkSender = + SingleSourceWorkerHarness.GetWorkSender.forStreamingEngine( + workItemReceiver -> { + throw new RuntimeException("something bad happened"); + }); + + FakeJvmRuntime fakeJvmRuntime = new FakeJvmRuntime(); + createWorkerHarness(getWorkSender, fakeJvmRuntime).start(); + assertTrue(fakeJvmRuntime.waitForRuntimeDeath(5, TimeUnit.SECONDS)); + fakeJvmRuntime.assertJvmTerminated(); + } + + private static class FakeJvmRuntime implements JvmRuntime { + private final CountDownLatch haltedLatch = new CountDownLatch(1); + private volatile int exitStatus = 0; + + @Override + public void halt(int status) { + exitStatus = status; + haltedLatch.countDown(); + } + + public boolean waitForRuntimeDeath(long timeout, TimeUnit unit) { + try { + return haltedLatch.await(timeout, unit); + } catch (InterruptedException e) { + return false; + } + } + + private void assertJvmTerminated() { + assertThat(exitStatus).isEqualTo(WorkerUncaughtExceptionHandler.JVM_TERMINATED_STATUS_CODE); + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index dc6cc5641055..32d1f5738086 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -193,7 +193,7 @@ public void testCloseAllStreams_doesNotCloseUnstartedStreams() { WindmillStreamSender windmillStreamSender = newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verifyNoInteractions(streamFactory); } @@ -230,7 +230,7 @@ public void testCloseAllStreams_closesAllStreams() { mockStreamFactory); windmillStreamSender.startStreams(); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verify(mockGetWorkStream).shutdown(); verify(mockGetDataStream).shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 546a2883e3b2..c05a4dd340dd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -121,6 +121,7 @@ public void setUp() throws IOException { private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { return StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setOnCommitComplete(onCommitComplete) .build(); @@ -342,6 +343,7 @@ public void testMultipleCommitSendersSingleStream() { Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); workCommitter = StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setNumCommitSenders(5) .setOnCommitComplete(completeCommits::add) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java new file mode 100644 index 000000000000..fd2b30238836 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER = + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setClientId(1L) + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + private static void assertHeader( + Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget expectedInitialBudget) { + assertTrue(getWorkRequest.hasRequest()); + 
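    // The first message on the stream must carry the full GetWorkRequest header; budget changes
    // made after the stream is open are sent as StreamingGetWorkRequestExtension messages instead.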
assertFalse(getWorkRequest.hasRequestExtension()); + assertThat(getWorkRequest.getRequest()) + .isEqualTo( + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(expectedInitialBudget.items()) + .setMaxBytes(expectedInitialBudget.bytes()) + .build()); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + // Header and extension. 
+ assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget()); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Header and extension. + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(1); + assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget()); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + Set scheduledWorkItems = new HashSet<>(); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> + scheduledWorkItems.add(work)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. 
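The GetWork stream tests here (including the assertions that continue just below) all pin down a single rule: the stream sends a StreamingGetWorkRequestExtension only for the positive shortfall between the requested budget and the budget still outstanding on the stream, and sends nothing once the stream has been shut down. A minimal sketch of that decision, using illustrative names rather than GrpcDirectGetWorkStream's actual internals:

```java
import java.util.Optional;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;

final class BudgetExtensionSketch {
  // Sketch only: the real stream tracks this state internally; the names here are assumptions.
  static Optional<GetWorkBudget> extensionToSend(
      boolean streamShutdown, GetWorkBudget outstanding, GetWorkBudget target) {
    if (streamShutdown) {
      return Optional.empty(); // setBudget() is a no-op after shutdown().
    }
    GetWorkBudget shortfall = target.subtract(outstanding);
    if (shortfall.items() <= 0 && shortfall.bytes() <= 0) {
      return Optional.empty(); // Outstanding budget is already high; only the header goes out.
    }
    // Otherwise the shortfall becomes the extension's maxItems/maxBytes.
    return Optional.of(shortfall);
  }
}
```

Under that rule the existing-budget test expects exactly newBudget.subtract(initialBudget), and the Long.MAX_VALUE cases expect no extension at all, which is what the assertions above and below verify.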
+ assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER); + stream.startThrottleTimer(); + assertTrue(throttleTimer.throttled()); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); + assertFalse(throttleTimer.throttled()); + } + + private static class GetWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetWorkRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) { + checkNotNull(responseObserver).onNext(responseChunk); + } + } + + private static class TestGetWorkRequestObserver + implements StreamObserver { + private final List requests = + Collections.synchronizedList(new ArrayList<>()); + private final CountDownLatch waitForRequests; + private @Nullable volatile StreamObserver + responseObserver; + + public TestGetWorkRequestObserver(CountDownLatch waitForRequests) { + this.waitForRequests = waitForRequests; + } + + @Override + public void onNext(Windmill.StreamingGetWorkRequest request) { + requests.add(request); + waitForRequests.countDown(); + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + + List sent() { + return requests; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java index 3f746d91a868..c04456906ea2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java @@ -34,7 +34,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.hamcrest.Matcher; import org.junit.Test; @@ -55,9 +54,6 @@ public static class RespectsJobSettingTest { public void createsNewStubWhenIsolatedChannelsConfigIsChanged() { DataflowWorkerHarnessOptions 
options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); // Create first time with Isolated channels disabled @@ -91,27 +87,18 @@ public static class RespectsPipelineOptionsTest { public static Collection data() { List list = new ArrayList<>(); for (Boolean pipelineOption : new Boolean[] {true, false}) { - list.add(new Object[] {/*experimentEnabled=*/ false, pipelineOption}); - list.add(new Object[] {/*experimentEnabled=*/ true, pipelineOption}); + list.add(new Object[] {pipelineOption}); } return list; } @Parameter(0) - public Boolean experimentEnabled; - - @Parameter(1) public Boolean pipelineOption; @Test public void ignoresIsolatedChannelsConfigWithPipelineOption() { DataflowWorkerHarnessOptions options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - if (experimentEnabled) { - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); - } options.setUseWindmillIsolatedChannels(pipelineOption); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java index d06ed0f526c7..8d2623c382e9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java @@ -30,6 +30,8 @@ import java.io.Closeable; import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; import java.nio.charset.StandardCharsets; import java.util.AbstractMap; import java.util.AbstractMap.SimpleEntry; @@ -305,6 +307,26 @@ private K userKeyFromProtoKey(ByteString tag, Coder keyCoder) throws IOEx return keyCoder.decode(keyBytes.newInput(), Context.OUTER); } + private static void assertBuildable( + Windmill.WorkItemCommitRequest.Builder commitWorkRequestBuilder) { + Windmill.WorkItemCommitRequest.Builder clone = commitWorkRequestBuilder.clone(); + if (!clone.hasKey()) { + clone.setKey(ByteString.EMPTY); // key is required to build + } + if (!clone.hasWorkToken()) { + clone.setWorkToken(1357924680L); // workToken is required to build + } + + try { + clone.build(); + } catch (Exception e) { + StringWriter sw = new StringWriter(); + e.printStackTrace(new PrintWriter(sw)); + fail( + "Failed to build commitRequest from: " + commitWorkRequestBuilder + "\n" + sw.toString()); + } + } + @Test public void testMapAddBeforeGet() throws Exception { StateTag> addr = @@ -647,6 +669,8 @@ public void testMapAddPersist() throws Exception { .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) .collect(Collectors.toList()), Matchers.containsInAnyOrder(new SimpleEntry<>(tag1, 1), new SimpleEntry<>(tag2, 2))); + + assertBuildable(commitBuilder); } @Test @@ -670,6 +694,8 @@ public void testMapRemovePersist() throws Exception { .map(tv -> fromTagValue(tv, 
StringUtf8Coder.of(), VarIntCoder.of())) .collect(Collectors.toList()), Matchers.containsInAnyOrder(new SimpleEntry<>(tag1, null), new SimpleEntry<>(tag2, null))); + + assertBuildable(commitBuilder); } @Test @@ -695,6 +721,8 @@ public void testMapClearPersist() throws Exception { assertEquals( protoKeyFromUserKey(null, StringUtf8Coder.of()), commitBuilder.getTagValuePrefixDeletes(0).getTagPrefix()); + + assertBuildable(commitBuilder); } @Test @@ -736,6 +764,8 @@ public void testMapComplexPersist() throws Exception { commitBuilder = Windmill.WorkItemCommitRequest.newBuilder(); assertEquals(0, commitBuilder.getTagValuePrefixDeletesCount()); assertEquals(0, commitBuilder.getValueUpdatesCount()); + + assertBuildable(commitBuilder); } @Test @@ -953,6 +983,8 @@ public void testMultimapRemovePersistPut() { multimapState.put(key, 5); assertThat(multimapState.get(key).read(), Matchers.containsInAnyOrder(4, 5)); + + assertBuildable(commitBuilder); } @Test @@ -1766,6 +1798,8 @@ public void testMultimapPutAndPersist() { builder, new MultimapEntryUpdate(key1, Arrays.asList(1, 2), false), new MultimapEntryUpdate(key2, Collections.singletonList(2), false)); + + assertBuildable(commitBuilder); } @Test @@ -1799,6 +1833,8 @@ public void testMultimapRemovePutAndPersist() { builder, new MultimapEntryUpdate(key1, Arrays.asList(1, 2), true), new MultimapEntryUpdate(key2, Collections.singletonList(4), true)); + + assertBuildable(commitBuilder); } @Test @@ -1825,6 +1861,8 @@ public void testMultimapRemoveAndPersist() { builder, new MultimapEntryUpdate(key1, Collections.emptyList(), true), new MultimapEntryUpdate(key2, Collections.emptyList(), true)); + + assertBuildable(commitBuilder); } @Test @@ -1856,6 +1894,8 @@ public void testMultimapPutRemoveClearAndPersist() { Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); assertEquals(0, builder.getUpdatesCount()); assertTrue(builder.getDeleteAll()); + + assertBuildable(commitBuilder); } @Test @@ -1894,6 +1934,8 @@ false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); assertTagMultimapUpdates( builder, new MultimapEntryUpdate(key1, Collections.singletonList(4), false)); + + assertBuildable(commitBuilder); } @Test @@ -1938,6 +1980,8 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) ByteArrayCoder.of().decode(entryUpdate.getEntryName().newInput(), Context.OUTER); assertArrayEquals(key1, decodedKey); assertTrue(entryUpdate.getDeleteAll()); + + assertBuildable(commitBuilder); } @Test @@ -2053,6 +2097,8 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) Windmill.WorkItemCommitRequest.Builder commitBuilder = Windmill.WorkItemCommitRequest.newBuilder(); underTest.persist(commitBuilder); + + assertBuildable(commitBuilder); } @Test @@ -2253,6 +2299,8 @@ public void testOrderedListAddPersist() throws Exception { assertEquals("hello", updates.getInserts(0).getEntries(0).getValue().toStringUtf8()); assertEquals(1000, updates.getInserts(0).getEntries(0).getSortKey()); assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); + + assertBuildable(commitBuilder); } @Test @@ -2284,6 +2332,8 @@ public void testOrderedListClearPersist() throws Exception { assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); assertEquals(IdTracker.NEW_RANGE_MIN_ID + 1, updates.getInserts(0).getEntries(1).getId()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2331,6 
+2381,8 @@ public void testOrderedListDeleteRangePersist() { assertEquals(4000, updates.getInserts(0).getEntries(1).getSortKey()); assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); assertEquals(IdTracker.NEW_RANGE_MIN_ID + 1, updates.getInserts(0).getEntries(1).getId()); + + assertBuildable(commitBuilder); } @Test @@ -2539,6 +2591,8 @@ public void testOrderedListPersistEmpty() throws Exception { assertEquals(1, updates.getDeletesCount()); assertEquals(WindmillOrderedList.MIN_TS_MICROS, updates.getDeletes(0).getRange().getStart()); assertEquals(WindmillOrderedList.MAX_TS_MICROS, updates.getDeletes(0).getRange().getLimit()); + + assertBuildable(commitBuilder); } @Test @@ -2653,6 +2707,8 @@ public void testBagAddPersist() throws Exception { assertEquals("hello", bagUpdates.getValues(0).toStringUtf8()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2678,6 +2734,8 @@ public void testBagClearPersist() throws Exception { assertEquals("world", tagBag.getValues(0).toStringUtf8()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2693,6 +2751,8 @@ public void testBagPersistEmpty() throws Exception { // 1 bag update = the clear assertEquals(1, commitBuilder.getBagUpdatesCount()); + + assertBuildable(commitBuilder); } @Test @@ -2806,6 +2866,8 @@ public void testCombiningAddPersist() throws Exception { 11, CoderUtils.decodeFromByteArray(accumCoder, bagUpdates.getValues(0).toByteArray())[0]); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2835,6 +2897,8 @@ public void testCombiningAddPersistWithCompact() throws Exception { assertTrue(bagUpdates.getDeleteAll()); assertEquals( 111, CoderUtils.decodeFromByteArray(accumCoder, bagUpdates.getValues(0).toByteArray())[0]); + + assertBuildable(commitBuilder); } @Test @@ -2862,6 +2926,8 @@ public void testCombiningClearPersist() throws Exception { 11, CoderUtils.decodeFromByteArray(accumCoder, tagBag.getValues(0).toByteArray())[0]); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2990,6 +3056,8 @@ public void testWatermarkPersistEarliest() throws Exception { assertEquals(TimeUnit.MILLISECONDS.toMicros(1000), watermarkHold.getTimestamps(0)); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3016,6 +3084,8 @@ public void testWatermarkPersistLatestEmpty() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3042,6 +3112,8 @@ public void testWatermarkPersistLatestWindmillWins() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3068,6 +3140,8 @@ public void testWatermarkPersistLatestLocalAdditionsWin() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3091,6 +3165,8 @@ public void testWatermarkPersistEndOfWindow() throws Exception { // Blind adds should not need to read the future. 
Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3116,6 +3192,8 @@ public void testWatermarkClearPersist() throws Exception { assertEquals(TimeUnit.MILLISECONDS.toMicros(1000), clearAndUpdate.getTimestamps(0)); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3133,6 +3211,8 @@ public void testWatermarkPersistEmpty() throws Exception { // 1 bag update corresponds to deletion. There shouldn't be a bag update adding items. assertEquals(1, commitBuilder.getWatermarkHoldsCount()); + + assertBuildable(commitBuilder); } @Test @@ -3200,6 +3280,8 @@ public void testValueSetPersist() throws Exception { assertTrue(valueUpdate.isInitialized()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3220,6 +3302,8 @@ public void testValueClearPersist() throws Exception { assertEquals(0, valueUpdate.getValue().getData().size()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3234,6 +3318,8 @@ public void testValueNoChangePersist() throws Exception { assertEquals(0, commitBuilder.getValueUpdatesCount()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java index 3cda4559c100..c76d5a584184 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -17,9 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -40,169 +38,79 @@ public class EvenGetWorkBudgetDistributorTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { - return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); - } + private static GetWorkBudgetSpender createGetWorkBudgetOwner() { + // Lambdas are final and cannot be spied. 
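The helper that continues just below wraps an anonymous GetWorkBudgetSpender in a Mockito spy (lambdas cannot be spied), and the rewritten tests that follow no longer reason about each stream's remaining budget: distributeEvenly() is simply expected to give every stream the per-dimension ceiling of total/streamCount (10 items and 19 bytes over 3 streams become 4 items and 7 bytes apiece) and to do nothing for an empty stream list or a zero budget. A sketch of that split, assuming GetWorkBudgetSpender comes from this test's package; the class and method layout is illustrative, not the distributor's actual code:

```java
import java.util.List;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetSpender;

final class EvenDistributionSketch {
  /** Gives every spender an even, rounded-up share of the total budget. */
  static void distributeEvenly(List<GetWorkBudgetSpender> streams, GetWorkBudget total) {
    if (streams.isEmpty() || (total.items() == 0 && total.bytes() == 0)) {
      return; // Nothing to hand out; the spenders are never touched.
    }
    long itemsPerStream = (long) Math.ceil(total.items() / (streams.size() * 1.0));
    long bytesPerStream = (long) Math.ceil(total.bytes() / (streams.size() * 1.0));
    for (GetWorkBudgetSpender stream : streams) {
      stream.setBudget(itemsPerStream, bytesPerStream);
    }
  }
}
```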
+ return spy( + new GetWorkBudgetSpender() { - private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { - return createBudgetDistributor( - GetWorkBudget.builder() - .setItems(activeWorkItemsAndBytes) - .setBytes(activeWorkItemsAndBytes) - .build()); + @Override + public void setBudget(long items, long bytes) {} + }); } @Test public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { - createBudgetDistributor(1L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); } @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())); - createBudgetDistributor(1L) + GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner(); + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget(ImmutableList.of(getWorkBudgetSpender), GetWorkBudget.noBudget()); verifyNoInteractions(getWorkBudgetSpender); } @Test - public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(10L).setBytes(10L).build())); - createBudgetDistributor(0L) - .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(5L).setBytes(5L).build())); - createBudgetDistributor(10L) + public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { + int totalStreams = 10; + long totalItems = 10L; + long totalBytes = 100L; + List streams = new ArrayList<>(); + for (int i = 0; i < totalStreams; i++) { + streams.add(createGetWorkBudgetOwner()); + } + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(20L).setBytes(20L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - 
GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq( - totalGetWorkBudget.items() - - streamRemainingBudget.items() - - activeWorkItemsAndBytes), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); + ImmutableList.copyOf(streams), + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq( - totalGetWorkBudget.bytes() - - streamRemainingBudget.bytes() - - activeWorkItemsAndBytes)); + streams.forEach( + stream -> + verify(stream, times(1)) + .setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build()))); } @Test - public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { - long totalItemsAndBytes = 10L; + public void testDistributeBudget_distributesFairlyWhenNotEven() { + long totalItems = 10L; + long totalBytes = 19L; List streams = new ArrayList<>(); - for (int i = 0; i < totalItemsAndBytes; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < 3; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), - GetWorkBudget.builder() - .setItems(totalItemsAndBytes) - .setBytes(totalItemsAndBytes) - .build()); + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + .setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build()))); } @Test - public void testDistributeBudget_distributesFairlyWhenNotEven() { + public void testDistributeBudget_distributesBudgetEvenly() { long totalItemsAndBytes = 10L; List streams = 
new ArrayList<>(); - for (int i = 0; i < 3; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < totalItemsAndBytes; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), GetWorkBudget.builder() @@ -210,24 +118,10 @@ public void testDistributeBudget_distributesFairlyWhenNotEven() { .setBytes(totalItemsAndBytes) .build()); - long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0)); + long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); - } - - private GetWorkBudgetSpender createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget getWorkBudget) { - return spy( - new GetWorkBudgetSpender() { - @Override - public void adjustBudget(long itemsDelta, long bytesDelta) {} - - @Override - public GetWorkBudget remainingBudget() { - return getWorkBudget; - } - }); + .setBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); } } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java index 72fa991c1f73..470692e75103 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java @@ -140,6 +140,8 @@ public RemoteEnvironment createEnvironment(Environment environment, String worke try { fnHarness.get(); } catch (Throwable t) { + // Print stacktrace to stderr. 
Could be useful if underlying error not surfaced earlier + t.printStackTrace(); executor.shutdownNow(); } }); diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java index 874748d7b975..49120d38f1f1 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java @@ -248,7 +248,7 @@ public void launchSdkHarness(PipelineOptions options) throws Exception { } }); InstructionRequestHandler controlClient = - clientPool.getSource().take(WORKER_ID, java.time.Duration.ofSeconds(2)); + clientPool.getSource().take(WORKER_ID, java.time.Duration.ofSeconds(10)); this.controlClient = SdkHarnessClient.usingFnApiClient(controlClient, dataServer.getService()); } diff --git a/runners/portability/java/build.gradle b/runners/portability/java/build.gradle index b684299c3174..a82759b4e4a0 100644 --- a/runners/portability/java/build.gradle +++ b/runners/portability/java/build.gradle @@ -253,7 +253,7 @@ tasks.register("validatesRunnerSickbay", Test) { } task ulrDockerValidatesRunner { - dependsOn createUlrValidatesRunnerTask("ulrDockerValidatesRunnerTests", "DOCKER", ":sdks:java:container:java8:docker") + dependsOn createUlrValidatesRunnerTask("ulrDockerValidatesRunnerTests", "DOCKER", ":sdks:java:container:${project.ext.currentJavaVersion}:docker") } task ulrLoopbackValidatesRunner { diff --git a/runners/prism/build.gradle b/runners/prism/build.gradle index 711a1aa2dd75..1009b9856e71 100644 --- a/runners/prism/build.gradle +++ b/runners/prism/build.gradle @@ -42,6 +42,9 @@ ext.set('buildTarget', buildTarget) def buildTask = tasks.named("build") { // goPrepare is a task registered in applyGoNature. dependsOn("goPrepare") + // Allow Go to manage the caching, not gradle. + outputs.cacheIf { false } + outputs.upToDateWhen { false } doLast { exec { workingDir = modDir diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index de9a30ad8189..f2dfa2bb1a28 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -16,6 +16,8 @@ * limitations under the License. */ +import groovy.json.JsonOutput + plugins { id 'org.apache.beam.module' } applyJavaNature( @@ -43,3 +45,242 @@ tasks.test { var prismBuildTask = dependsOn(':runners:prism:build') systemProperty 'prism.buildTarget', prismBuildTask.project.property('buildTarget').toString() } + +// Below is configuration to support running the Java Validates Runner tests. 
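The Gradle wiring below hands the runner, environment type, and Prism location to the test JVM as a JSON-encoded beamTestPipelineOptions system property; TestPipeline reads that property when building its options, which is how the shared ValidatesRunner tests get pointed at TestPrismRunner. A minimal sketch of the consuming side, as a hypothetical test class rather than one of the suites referenced here:

```java
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;

public class ExampleValidatesRunnerTest {
  // TestPipeline.create() picks up -DbeamTestPipelineOptions set by the task below,
  // so the runner, defaultEnvironmentType, and prismLocation all come from that JSON.
  @Rule public final transient TestPipeline p = TestPipeline.create();

  @Test
  @Category(ValidatesRunner.class)
  public void testCreateRoundTrips() {
    PCollection<Integer> output = p.apply(Create.of(1, 2, 3));
    PAssert.that(output).containsInAnyOrder(1, 2, 3);
    p.run().waitUntilFinish();
  }
}
```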
+ +configurations { + validatesRunner +} + +dependencies { + implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation library.java.hamcrest + permitUnusedDeclared library.java.hamcrest + implementation library.java.joda_time + implementation library.java.slf4j_api + implementation library.java.vendored_guava_32_1_2_jre + + testImplementation library.java.hamcrest + testImplementation library.java.junit + testImplementation library.java.mockito_core + testImplementation library.java.slf4j_jdk14 + + validatesRunner project(path: ":sdks:java:core", configuration: "shadowTest") + validatesRunner project(path: ":runners:core-java", configuration: "testRuntimeMigration") + validatesRunner project(path: project.path, configuration: "testRuntimeMigration") +} + +project.evaluationDependsOn(":sdks:java:core") +project.evaluationDependsOn(":runners:core-java") + +def sickbayTests = [ + // PortableMetrics doesn't implement "getCommitedOrNull" from Metrics + // Preventing Prism from passing these tests. + // In particular, it doesn't subclass MetricResult with an override, and + // it explicitly passes "false" to committed supported in create. + // + // There is not currently a category for excluding these _only_ in committed mode + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testAllCommittedMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedCounterMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedDistributionMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedStringSetMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedGaugeMetrics', + + // Triggers / Accumulation modes not yet implemented in prism. + // https://github.com/apache/beam/issues/31438 + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testGlobalCombineWithDefaultsAndTriggers', + 'org.apache.beam.sdk.transforms.CombineTest$BasicTests.testHotKeyCombiningWithAccumulationMode', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testNoWindowFnDoesNotReassignWindows', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testCombiningAccumulatingProcessingTime', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerEarly', + 'org.apache.beam.sdk.transforms.ParDoTest$BundleInvariantsTests.testWatermarkUpdateMidBundle', + 'org.apache.beam.sdk.transforms.ViewTest.testTriggeredLatestSingleton', + // Requires Allowed Lateness, among others.
+ 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerSetWithinAllowedLateness', + 'org.apache.beam.sdk.testing.TestStreamTest.testFirstElementLate', + 'org.apache.beam.sdk.testing.TestStreamTest.testDiscardingMode', + 'org.apache.beam.sdk.testing.TestStreamTest.testEarlyPanesOfWindow', + 'org.apache.beam.sdk.testing.TestStreamTest.testElementsAtAlmostPositiveInfinity', + 'org.apache.beam.sdk.testing.TestStreamTest.testLateDataAccumulating', + 'org.apache.beam.sdk.testing.TestStreamTest.testMultipleStreams', + 'org.apache.beam.sdk.testing.TestStreamTest.testProcessingTimeTrigger', + + // Coding error somehow: short write: reached end of stream after reading 5 bytes; 98 bytes expected + 'org.apache.beam.sdk.testing.TestStreamTest.testMultiStage', + + // Prism not firing sessions correctly (seems to be merging inappropriately) + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testSessionsCombine', + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testSessionsCombineWithContext', + + // Java side dying during execution. + // https://github.com/apache/beam/issues/32930 + 'org.apache.beam.sdk.transforms.FlattenTest.testFlattenMultipleCoders', + // Stream corruption error java side: failed:java.io.StreamCorruptedException: invalid stream header: 206E6F74 + // Likely due to prism's coder changes. + 'org.apache.beam.sdk.transforms.FlattenTest.testFlattenWithDifferentInputAndOutputCoders2', + + // java.lang.IllegalStateException: Output with tag Tag must have a schema in order to call getRowReceiver + // Ultimately because the getRowReceiver code path SDK side isn't friendly to LengthPrefix wrapping of row coders. + // https://github.com/apache/beam/issues/32931 + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWrite', + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWriteMultiOutput', + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWriteWithSchemaRegistry', + + // Technically these tests "succeed" + // the test is just complaining that an AssertionException isn't a RuntimeException + // + // java.lang.RuntimeException: test error in finalize + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInFinishBatch', + // java.lang.RuntimeException: test error in process + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInProcessElement', + // java.lang.RuntimeException: test error in initialize + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInStartBatch', + + // Only known window fns supported, not general window merging + // Custom window fns not yet implemented in prism. + // https://github.com/apache/beam/issues/31921 + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindows', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindowsKeyedCollection', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindowsWithoutCustomWindowTypes', + 'org.apache.beam.sdk.transforms.windowing.WindowingTest.testMergingWindowing', + 'org.apache.beam.sdk.transforms.windowing.WindowingTest.testNonPartitioningWindowing', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$WindowTests.testGroupByKeyMergingWindows', + + // Possibly a different error being hidden behind the main error.
+ // org.apache.beam.sdk.util.WindowedValue$ValueInGlobalWindow cannot be cast to class java.lang.String + // TODO(https://github.com/apache/beam/issues/29973) + 'org.apache.beam.sdk.transforms.ReshuffleTest.testReshufflePreservesMetadata', + // TODO(https://github.com/apache/beam/issues/31231) + 'org.apache.beam.sdk.transforms.RedistributeTest.testRedistributePreservesMetadata', + + // Prism isn't handling Java's side input views properly. + // https://github.com/apache/beam/issues/32932 + // java.lang.IllegalArgumentException: PCollection with more than one element accessed as a singleton view. + // Consider using Combine.globally().asSingleton() to combine the PCollection into a single value + 'org.apache.beam.sdk.transforms.ViewTest.testDiscardingNonSingletonSideInput', + // java.util.NoSuchElementException: Empty PCollection accessed as a singleton view. + 'org.apache.beam.sdk.transforms.ViewTest.testDiscardingNonSingletonSideInput', + // java.lang.IllegalArgumentException: Duplicate values for a + 'org.apache.beam.sdk.transforms.ViewTest.testMapSideInputWithNullValuesCatchesDuplicates', + // java.lang.IllegalArgumentException: PCollection with more than one element accessed as a singleton view.... + 'org.apache.beam.sdk.transforms.ViewTest.testNonSingletonSideInput', + // java.util.NoSuchElementException: Empty PCollection accessed as a singleton view. + 'org.apache.beam.sdk.transforms.ViewTest.testEmptySingletonSideInput', + // Prism side encoding error. + // java.lang.IllegalStateException: java.io.EOFException + 'org.apache.beam.sdk.transforms.ViewTest.testSideInputWithNestedIterables', + + // Requires Time Sorted Input + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInput', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithTestStream', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithLateDataAndAllowedLateness', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testTwoRequiresTimeSortedInputWithLateData', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithLateData', + + // Timer race condition/ordering issue in Prism. + 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testTwoTimersSettingEachOtherWithCreateAsInputUnbounded', + + // Missing output due to timer skew. + 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testProcessElementSkew', + + // TestStream + BundleFinalization. + // Tests seem to assume individual element bundles from test stream, but prism will aggregate them, preventing + // a subsequent firing. Tests ultimately hang until timeout. + // Either a test problem, or a misunderstanding in prism of how test stream must work. + // Biased to test problem, due to how they are constructed. + 'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalization', + 'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalizationWithSideInputs', + + // Filtered by PortableRunner tests.
+ // Teardown not called in exceptions + // https://github.com/apache/beam/issues/20372 + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful', +] + +/** + * Runs Java ValidatesRunner tests against the Prism Runner + * with the specified environment type. + */ +def createPrismValidatesRunnerTask = { name, environmentType -> + Task vrTask = tasks.create(name: name, type: Test, group: "Verification") { + description "PrismRunner Java $environmentType ValidatesRunner suite" + classpath = configurations.validatesRunner + + var prismBuildTask = dependsOn(':runners:prism:build') + systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ + "--runner=TestPrismRunner", + "--experiments=beam_fn_api", + "--defaultEnvironmentType=${environmentType}", + "--prismLogLevel=warn", + "--prismLocation=${prismBuildTask.project.property('buildTarget').toString()}", + "--enableWebUI=false", + ]) + testClassesDirs = files(project(":sdks:java:core").sourceSets.test.output.classesDirs) + useJUnit { + includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + // Should be run only in a properly configured SDK harness environment + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' + excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' + + // Not yet implemented in Prism + // https://github.com/apache/beam/issues/32211 + excludeCategories 'org.apache.beam.sdk.testing.UsesOnWindowExpiration' + // https://github.com/apache/beam/issues/32929 + excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState' + + // Not supported in Portable Java SDK yet. + // https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState + excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' + } + filter { + // Hangs forever with prism. Put here instead of sickbay to allow sickbay runs to terminate. 
+ // https://github.com/apache/beam/issues/32222 + excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerOrderingWithCreate' + + for (String test : sickbayTests) { + excludeTestsMatching test + } + } + } + return vrTask +} + +tasks.register("validatesRunnerSickbay", Test) { + group = "Verification" + description "Validates Prism local runner (Sickbay Tests)" + + var prismBuildTask = dependsOn(':runners:prism:build') + systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ + "--runner=TestPrismRunner", + "--experiments=beam_fn_api", + "--enableWebUI=false", + "--prismLogLevel=warn", + "--prismLocation=${prismBuildTask.project.property('buildTarget').toString()}" + ]) + + classpath = configurations.validatesRunner + testClassesDirs = files(project(":sdks:java:core").sourceSets.test.output.classesDirs) + + filter { + for (String test : sickbayTests) { + includeTestsMatching test + } + } +} + +task prismDockerValidatesRunner { + Task vrTask = createPrismValidatesRunnerTask("prismDockerValidatesRunnerTests", "DOCKER") + vrTask.dependsOn ":sdks:java:container:${project.ext.currentJavaVersion}:docker" +} + +task prismLoopbackValidatesRunner { + dependsOn createPrismValidatesRunnerTask("prismLoopbackValidatesRunnerTests", "LOOPBACK") +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java index fda5db923a7f..111d937fcbf6 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java @@ -31,6 +31,8 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.slf4j.Logger; @@ -48,6 +50,7 @@ abstract class PrismExecutor { static final String IDLE_SHUTDOWN_TIMEOUT = "-idle_shutdown_timeout=%s"; static final String JOB_PORT_FLAG_TEMPLATE = "-job_port=%s"; static final String SERVE_HTTP_FLAG_TEMPLATE = "-serve_http=%s"; + static final String LOG_LEVEL_FLAG_TEMPLATE = "-log_level=%s"; protected @MonotonicNonNull Process process; protected ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -157,6 +160,16 @@ abstract static class Builder { abstract Builder setArguments(List arguments); + Builder addArguments(List arguments) { + Optional> original = getArguments(); + if (!original.isPresent()) { + return this.setArguments(arguments); + } + List newArguments = + Stream.concat(original.get().stream(), arguments.stream()).collect(Collectors.toList()); + return this.setArguments(newArguments); + } + abstract Optional> getArguments(); abstract PrismExecutor autoBuild(); diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java index 27aea3f64df0..b32f03e78e6a 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java @@ -110,6 +110,12 @@ String resolveSource() { String resolve() throws IOException { String from = resolveSource(); + // If the location is set, and it's not an http request 
or a zip, + // use the binary directly. + if (!from.startsWith("http") && !from.endsWith("zip") && Files.exists(Paths.get(from))) { + return from; + } + String fromFileName = getNameWithoutExtension(from); Path to = Paths.get(userHome(), PRISM_BIN_PATH, fromFileName); diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java index 9b280d0a70d4..ceec1ad8268a 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java @@ -59,4 +59,10 @@ public interface PrismPipelineOptions extends PortablePipelineOptions { String getIdleShutdownTimeout(); void setIdleShutdownTimeout(String idleShutdownTimeout); + + @Description("Sets the log level for Prism. Can be set to 'debug', 'info', 'warn', or 'error'.") + @Default.String("warn") + String getPrismLogLevel(); + + void setPrismLogLevel(String prismLogLevel); } diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java index 6099db4b63ee..ac1e68237faf 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java @@ -101,13 +101,17 @@ PrismExecutor startPrism() throws IOException { String idleShutdownTimeoutFlag = String.format( PrismExecutor.IDLE_SHUTDOWN_TIMEOUT, prismPipelineOptions.getIdleShutdownTimeout()); + String logLevelFlag = + String.format( + PrismExecutor.LOG_LEVEL_FLAG_TEMPLATE, prismPipelineOptions.getPrismLogLevel()); String endpoint = "localhost:" + port; prismPipelineOptions.setJobEndpoint(endpoint); String command = locator.resolve(); PrismExecutor executor = PrismExecutor.builder() .setCommand(command) - .setArguments(Arrays.asList(portFlag, serveHttpFlag, idleShutdownTimeoutFlag)) + .setArguments( + Arrays.asList(portFlag, serveHttpFlag, idleShutdownTimeoutFlag, logLevelFlag)) .build(); executor.execute(); checkState(executor.isAlive()); diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java index eb497f0a4c43..a81e3e24ee69 100644 --- a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java @@ -59,7 +59,7 @@ public void executeWithStreamRedirectThenStop() throws IOException { sleep(3000L); executor.stop(); String output = outputStream.toString(StandardCharsets.UTF_8.name()); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + assertThat(output).contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:8073"); } @Test @@ -71,7 +71,8 @@ public void executeWithFileOutputThenStop() throws IOException { executor.stop(); try (Stream stream = Files.lines(log.toPath(), StandardCharsets.UTF_8)) { String output = stream.collect(Collectors.joining("\n")); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + assertThat(output) + .contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:8073"); } } @@ -79,21 +80,23 @@ public void executeWithFileOutputThenStop() throws IOException { public void 
executeWithCustomArgumentsThenStop() throws IOException { PrismExecutor executor = underTest() - .setArguments(Collections.singletonList("-" + JOB_PORT_FLAG_NAME + "=5555")) + .addArguments(Collections.singletonList("-" + JOB_PORT_FLAG_NAME + "=5555")) .build(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); executor.execute(outputStream); sleep(3000L); executor.stop(); String output = outputStream.toString(StandardCharsets.UTF_8.name()); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:5555"); + assertThat(output).contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:5555"); } @Test public void executeWithPortFinderThenStop() throws IOException {} private PrismExecutor.Builder underTest() { - return PrismExecutor.builder().setCommand(getLocalPrismBuildOrIgnoreTest()); + return PrismExecutor.builder() + .setCommand(getLocalPrismBuildOrIgnoreTest()) + .setArguments(Collections.singletonList("--log_kind=text")); // disable color control chars } private void sleep(long millis) { diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java index 095d3c9bde61..fa5ba6d37203 100644 --- a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java @@ -134,7 +134,10 @@ public void givenFilePrismLocationOption_thenResolves() throws IOException { PrismLocator underTest = new PrismLocator(options); String got = underTest.resolve(); - assertThat(got).contains(DESTINATION_DIRECTORY.toString()); + // Local file overrides should use the local binary in place, not copy + // to the cache. Doing so prevents using a locally built version. 
+ assertThat(got).doesNotContain(DESTINATION_DIRECTORY.toString()); + assertThat(got).contains(options.getPrismLocation()); Path gotPath = Paths.get(got); assertThat(Files.exists(gotPath)).isTrue(); } diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java index e0bcbed1577c..68674d202cdb 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -59,6 +60,8 @@ public class TestSamzaRunnerWithTransformMetrics { @Test public void testSamzaRunnerWithDefaultMetrics() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); SamzaPipelineOptions options = PipelineOptionsFactory.create().as(SamzaPipelineOptions.class); InMemoryMetricsReporter inMemoryMetricsReporter = new InMemoryMetricsReporter(); options.setMetricsReporters(ImmutableList.of(inMemoryMetricsReporter)); diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java index 8670d9a46eac..73454cc95421 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.samza.runtime; +import static org.junit.Assume.assumeTrue; + import java.io.Serializable; import java.util.Arrays; import org.apache.beam.sdk.coders.KvCoder; @@ -35,11 +37,19 @@ import org.apache.beam.sdk.values.TimestampedValue; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; /** Tests for GroupByKeyOp. 
*/ public class GroupByKeyOpTest implements Serializable { + + @BeforeClass + public static void beforeClass() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); + } + @Rule public final transient TestPipeline pipeline = TestPipeline.fromOptions( diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java index 004162600179..9409efbcf394 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import java.io.File; import java.io.IOException; @@ -75,6 +76,7 @@ import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStorageEngineFactory; import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStore; import org.apache.samza.system.SystemStreamPartition; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -91,6 +93,12 @@ public class SamzaStoreStateInternalsTest implements Serializable { TestPipeline.fromOptions( PipelineOptionsFactory.fromArgs("--runner=TestSamzaRunner").create()); + @BeforeClass + public static void beforeClass() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); + } + @Test public void testMapStateIterator() { final String stateId = "foo"; diff --git a/runners/spark/job-server/container/Dockerfile b/runners/spark/job-server/container/Dockerfile index ec4a123f2b9d..f5639430a33b 100644 --- a/runners/spark/job-server/container/Dockerfile +++ b/runners/spark/job-server/container/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. ############################################################################### -FROM openjdk:8 +FROM eclipse-temurin:11 MAINTAINER "Apache Beam " RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libltdl7 diff --git a/runners/spark/job-server/spark_job_server.gradle b/runners/spark/job-server/spark_job_server.gradle index 5ed5f4277bf4..90109598ed64 100644 --- a/runners/spark/job-server/spark_job_server.gradle +++ b/runners/spark/job-server/spark_job_server.gradle @@ -301,3 +301,7 @@ createCrossLanguageValidatesRunnerTask( "--endpoint localhost:${jobPort}", ], ) + +shadowJar { + outputs.upToDateWhen { false } +} diff --git a/sdks/go.mod b/sdks/go.mod index 91369e9eb70f..ff711cbe91b0 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -20,21 +20,21 @@ // directory. 
module github.com/apache/beam/sdks/v2 -go 1.21 +go 1.21.0 require ( - cloud.google.com/go/bigquery v1.63.0 + cloud.google.com/go/bigquery v1.63.1 cloud.google.com/go/bigtable v1.33.0 - cloud.google.com/go/datastore v1.19.0 + cloud.google.com/go/datastore v1.20.0 cloud.google.com/go/profiler v0.4.1 - cloud.google.com/go/pubsub v1.43.0 + cloud.google.com/go/pubsub v1.45.1 cloud.google.com/go/spanner v1.70.0 - cloud.google.com/go/storage v1.44.0 - github.com/aws/aws-sdk-go-v2 v1.32.2 - github.com/aws/aws-sdk-go-v2/config v1.27.43 - github.com/aws/aws-sdk-go-v2/credentials v1.17.41 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.32 - github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3 + cloud.google.com/go/storage v1.45.0 + github.com/aws/aws-sdk-go-v2 v1.32.4 + github.com/aws/aws-sdk-go-v2/config v1.28.0 + github.com/aws/aws-sdk-go-v2/credentials v1.17.42 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 + github.com/aws/aws-sdk-go-v2/service/s3 v1.66.3 github.com/aws/smithy-go v1.22.0 github.com/docker/go-connections v0.5.0 github.com/dustin/go-humanize v1.0.1 @@ -44,7 +44,7 @@ require ( github.com/johannesboyne/gofakes3 v0.0.0-20221110173912-32fb85c5aed6 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 - github.com/nats-io/nats-server/v2 v2.10.18 + github.com/nats-io/nats-server/v2 v2.10.22 github.com/nats-io/nats.go v1.37.0 github.com/proullon/ramsql v0.1.4 github.com/spf13/cobra v1.8.1 @@ -56,10 +56,10 @@ require ( golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.23.0 golang.org/x/sync v0.8.0 - golang.org/x/sys v0.26.0 + golang.org/x/sys v0.27.0 golang.org/x/text v0.19.0 - google.golang.org/api v0.199.0 - google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 + google.golang.org/api v0.203.0 + google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 google.golang.org/grpc v1.67.1 google.golang.org/protobuf v1.35.1 gopkg.in/yaml.v2 v2.4.0 @@ -69,12 +69,13 @@ require ( require ( github.com/avast/retry-go/v4 v4.6.0 github.com/fsouza/fake-gcs-server v1.49.2 + github.com/golang-cz/devslog v0.0.11 golang.org/x/exp v0.0.0-20231006140011-7918f672742d ) require ( cel.dev/expr v0.16.1 // indirect - cloud.google.com/go/auth v0.9.5 // indirect + cloud.google.com/go/auth v0.9.9 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect cloud.google.com/go/monitoring v1.21.1 // indirect dario.cat/mergo v1.0.0 // indirect @@ -116,12 +117,12 @@ require ( go.opentelemetry.io/otel/sdk v1.29.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.29.0 // indirect go.opentelemetry.io/otel/trace v1.29.0 // indirect - golang.org/x/time v0.6.0 // indirect + golang.org/x/time v0.7.0 // indirect google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a // indirect ) require ( - cloud.google.com/go v0.115.1 // indirect + cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/compute/metadata v0.5.2 // indirect cloud.google.com/go/iam v1.2.1 // indirect cloud.google.com/go/longrunning v0.6.1 // indirect @@ -131,18 +132,18 @@ require ( github.com/apache/thrift v0.17.0 // indirect github.com/aws/aws-sdk-go v1.34.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 // 
indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.32.3 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -166,7 +167,7 @@ require ( github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -193,6 +194,6 @@ require ( golang.org/x/mod v0.20.0 // indirect golang.org/x/tools v0.24.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 // indirect ) diff --git a/sdks/go.sum b/sdks/go.sum index 82de5fa9b95a..c24cb10126c8 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -38,8 +38,8 @@ cloud.google.com/go v0.104.0/go.mod h1:OO6xxXdJyvuJPcEPBLN9BJPD+jep5G1+2U5B5gkRY cloud.google.com/go v0.105.0/go.mod h1:PrLgOJNe5nfE9UMxKxgXj4mD3voiP+YQ6gdt6KMFOKM= cloud.google.com/go v0.107.0/go.mod h1:wpc2eNrD7hXUTy8EKS10jkxpZBjASrORK7goS+3YX2I= cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= -cloud.google.com/go v0.115.1 h1:Jo0SM9cQnSkYfp44+v+NQXHpcHqlnRJk2qxh6yvxxxQ= -cloud.google.com/go v0.115.1/go.mod h1:DuujITeaufu3gL68/lOFIirVNJwQeyf5UXyi+Wbgknc= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= cloud.google.com/go/accessapproval v1.4.0/go.mod h1:zybIuC3KpDOvotz59lFe5qxRZx6C75OtwbisN56xYB4= cloud.google.com/go/accessapproval v1.5.0/go.mod h1:HFy3tuiGvMdcd/u+Cu5b9NkO1pEICJ46IR82PoUdplw= cloud.google.com/go/accessapproval v1.6.0/go.mod h1:R0EiYnwV5fsRFiKZkPHr6mwyk2wxUJ30nL4j2pcFY2E= @@ -101,8 +101,8 @@ cloud.google.com/go/assuredworkloads v1.7.0/go.mod h1:z/736/oNmtGAyU47reJgGN+KVo 
cloud.google.com/go/assuredworkloads v1.8.0/go.mod h1:AsX2cqyNCOvEQC8RMPnoc0yEarXQk6WEKkxYfL6kGIo= cloud.google.com/go/assuredworkloads v1.9.0/go.mod h1:kFuI1P78bplYtT77Tb1hi0FMxM0vVpRC7VVoJC3ZoT0= cloud.google.com/go/assuredworkloads v1.10.0/go.mod h1:kwdUQuXcedVdsIaKgKTp9t0UJkE5+PAVNhdQm4ZVq2E= -cloud.google.com/go/auth v0.9.5 h1:4CTn43Eynw40aFVr3GpPqsQponx2jv0BQpjvajsbbzw= -cloud.google.com/go/auth v0.9.5/go.mod h1:Xo0n7n66eHyOWWCnitop6870Ilwo3PiZyodVkkH1xWM= +cloud.google.com/go/auth v0.9.9 h1:BmtbpNQozo8ZwW2t7QJjnrQtdganSdmqeIBxHxNkEZQ= +cloud.google.com/go/auth v0.9.9/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY= cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc= cloud.google.com/go/automl v1.5.0/go.mod h1:34EjfoFGMZ5sgJ9EoLsRtdPSNZLcfflJR39VbVNS2M0= @@ -133,8 +133,8 @@ cloud.google.com/go/bigquery v1.47.0/go.mod h1:sA9XOgy0A8vQK9+MWhEQTY6Tix87M/Zur cloud.google.com/go/bigquery v1.48.0/go.mod h1:QAwSz+ipNgfL5jxiaK7weyOhzdoAy1zFm0Nf1fysJac= cloud.google.com/go/bigquery v1.49.0/go.mod h1:Sv8hMmTFFYBlt/ftw2uN6dFdQPzBlREY9yBh7Oy7/4Q= cloud.google.com/go/bigquery v1.50.0/go.mod h1:YrleYEh2pSEbgTBZYMJ5SuSr0ML3ypjRB1zgf7pvQLU= -cloud.google.com/go/bigquery v1.63.0 h1:yQFuJXdDukmBkiUUpjX0i1CtHLFU62HqPs/VDvSzaZo= -cloud.google.com/go/bigquery v1.63.0/go.mod h1:TQto6OR4kw27bqjNTGkVk1Vo5PJlTgxvDJn6YEIZL/E= +cloud.google.com/go/bigquery v1.63.1 h1:/6syiWrSpardKNxdvldS5CUTRJX1iIkSPXCjLjiGL+g= +cloud.google.com/go/bigquery v1.63.1/go.mod h1:ufaITfroCk17WTqBhMpi8CRjsfHjMX07pDrQaRKKX2o= cloud.google.com/go/bigtable v1.33.0 h1:2BDaWLRAwXO14DJL/u8crbV2oUbMZkIa2eGq8Yao1bk= cloud.google.com/go/bigtable v1.33.0/go.mod h1:HtpnH4g25VT1pejHRtInlFPnN5sjTxbQlsYBjh9t5l0= cloud.google.com/go/billing v1.4.0/go.mod h1:g9IdKBEFlItS8bTtlrZdVLWSSdSyFUZKXNS02zKMOZY= @@ -210,8 +210,8 @@ cloud.google.com/go/datacatalog v1.8.0/go.mod h1:KYuoVOv9BM8EYz/4eMFxrr4DUKhGIOX cloud.google.com/go/datacatalog v1.8.1/go.mod h1:RJ58z4rMp3gvETA465Vg+ag8BGgBdnRPEMMSTr5Uv+M= cloud.google.com/go/datacatalog v1.12.0/go.mod h1:CWae8rFkfp6LzLumKOnmVh4+Zle4A3NXLzVJ1d1mRm0= cloud.google.com/go/datacatalog v1.13.0/go.mod h1:E4Rj9a5ZtAxcQJlEBTLgMTphfP11/lNaAshpoBgemX8= -cloud.google.com/go/datacatalog v1.22.0 h1:7e5/0B2LYbNx0BcUJbiCT8K2wCtcB5993z/v1JeLIdc= -cloud.google.com/go/datacatalog v1.22.0/go.mod h1:4Wff6GphTY6guF5WphrD76jOdfBiflDiRGFAxq7t//I= +cloud.google.com/go/datacatalog v1.22.1 h1:i0DyKb/o7j+0vgaFtimcRFjYsD6wFw1jpnODYUyiYRs= +cloud.google.com/go/datacatalog v1.22.1/go.mod h1:MscnJl9B2lpYlFoxRjicw19kFTwEke8ReKL5Y/6TWg8= cloud.google.com/go/dataflow v0.6.0/go.mod h1:9QwV89cGoxjjSR9/r7eFDqqjtvbKxAK2BaYU6PVk9UM= cloud.google.com/go/dataflow v0.7.0/go.mod h1:PX526vb4ijFMesO1o202EaUmouZKBpjHsTlCtB4parQ= cloud.google.com/go/dataflow v0.8.0/go.mod h1:Rcf5YgTKPtQyYz8bLYhFoIV/vP39eL7fWNcSOyFfLJE= @@ -240,8 +240,8 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7 cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/datastore v1.10.0/go.mod h1:PC5UzAmDEkAmkfaknstTYbNpgE49HAgW2J1gcgUfmdM= cloud.google.com/go/datastore v1.11.0/go.mod h1:TvGxBIHCS50u8jzG+AW/ppf87v1of8nwzFNgEZU1D3c= -cloud.google.com/go/datastore v1.19.0 h1:p5H3bUQltOa26GcMRAxPoNwoqGkq5v8ftx9/ZBB35MI= -cloud.google.com/go/datastore v1.19.0/go.mod h1:KGzkszuj87VT8tJe67GuB+qLolfsOt6bZq/KFuWaahc= +cloud.google.com/go/datastore v1.20.0 
h1:NNpXoyEqIJmZFc0ACcwBEaXnmscUpcG4NkKnbCePmiM= +cloud.google.com/go/datastore v1.20.0/go.mod h1:uFo3e+aEpRfHgtp5pp0+6M0o147KoPaYNaPAKpfh8Ew= cloud.google.com/go/datastream v1.2.0/go.mod h1:i/uTP8/fZwgATHS/XFu0TcNUhuA0twZxxQ3EyCUQMwo= cloud.google.com/go/datastream v1.3.0/go.mod h1:cqlOX8xlyYF/uxhiKn6Hbv6WjwPPuI9W2M9SAXwaLLQ= cloud.google.com/go/datastream v1.4.0/go.mod h1:h9dpzScPhDTs5noEMQVWP8Wx8AFBRyS0s8KWPx/9r0g= @@ -348,8 +348,8 @@ cloud.google.com/go/kms v1.8.0/go.mod h1:4xFEhYFqvW+4VMELtZyxomGSYtSQKzM178ylFW4 cloud.google.com/go/kms v1.9.0/go.mod h1:qb1tPTgfF9RQP8e1wq4cLFErVuTJv7UsSC915J8dh3w= cloud.google.com/go/kms v1.10.0/go.mod h1:ng3KTUtQQU9bPX3+QGLsflZIHlkbn8amFAMY63m8d24= cloud.google.com/go/kms v1.10.1/go.mod h1:rIWk/TryCkR59GMC3YtHtXeLzd634lBbKenvyySAyYI= -cloud.google.com/go/kms v1.19.0 h1:x0OVJDl6UH1BSX4THKlMfdcFWoE4ruh90ZHuilZekrU= -cloud.google.com/go/kms v1.19.0/go.mod h1:e4imokuPJUc17Trz2s6lEXFDt8bgDmvpVynH39bdrHM= +cloud.google.com/go/kms v1.20.0 h1:uKUvjGqbBlI96xGE669hcVnEMw1Px/Mvfa62dhM5UrY= +cloud.google.com/go/kms v1.20.0/go.mod h1:/dMbFF1tLLFnQV44AoI2GlotbjowyUfgVwezxW291fM= cloud.google.com/go/language v1.4.0/go.mod h1:F9dRpNFQmJbkaop6g0JhSBXCNlO90e1KWx5iDdxbWic= cloud.google.com/go/language v1.6.0/go.mod h1:6dJ8t3B+lUYfStgls25GusK04NLh3eDLQnWM3mdEbhI= cloud.google.com/go/language v1.7.0/go.mod h1:DJ6dYN/W+SQOjF8e1hLQXMF21AkH2w9wiPzPCJa2MIE= @@ -451,8 +451,8 @@ cloud.google.com/go/pubsub v1.26.0/go.mod h1:QgBH3U/jdJy/ftjPhTkyXNj543Tin1pRYcd cloud.google.com/go/pubsub v1.27.1/go.mod h1:hQN39ymbV9geqBnfQq6Xf63yNhUAhv9CZhzp5O6qsW0= cloud.google.com/go/pubsub v1.28.0/go.mod h1:vuXFpwaVoIPQMGXqRyUQigu/AX1S3IWugR9xznmcXX8= cloud.google.com/go/pubsub v1.30.0/go.mod h1:qWi1OPS0B+b5L+Sg6Gmc9zD1Y+HaM0MdUr7LsupY1P4= -cloud.google.com/go/pubsub v1.43.0 h1:s3Qx+F96J7Kwey/uVHdK3QxFLIlOvvw4SfMYw2jFjb4= -cloud.google.com/go/pubsub v1.43.0/go.mod h1:LNLfqItblovg7mHWgU5g84Vhza4J8kTxx0YqIeTzcXY= +cloud.google.com/go/pubsub v1.45.1 h1:ZC/UzYcrmK12THWn1P72z+Pnp2vu/zCZRXyhAfP1hJY= +cloud.google.com/go/pubsub v1.45.1/go.mod h1:3bn7fTmzZFwaUjllitv1WlsNMkqBgGUb3UdMhI54eCc= cloud.google.com/go/pubsublite v1.5.0/go.mod h1:xapqNQ1CuLfGi23Yda/9l4bBCKz/wC3KIJ5gKcxveZg= cloud.google.com/go/pubsublite v1.6.0/go.mod h1:1eFCS0U11xlOuMFV/0iBqw3zP12kddMeCbj/F3FSj9k= cloud.google.com/go/pubsublite v1.7.0/go.mod h1:8hVMwRXfDfvGm3fahVbtDbiLePT3gpoiJYJY+vxWxVM= @@ -561,8 +561,8 @@ cloud.google.com/go/storage v1.23.0/go.mod h1:vOEEDNFnciUMhBeT6hsJIn3ieU5cFRmzeL cloud.google.com/go/storage v1.27.0/go.mod h1:x9DOL8TK/ygDUMieqwfhdpQryTeEkhGKMi80i/iqR2s= cloud.google.com/go/storage v1.28.1/go.mod h1:Qnisd4CqDdo6BGs2AD5LLnEsmSQ80wQ5ogcBBKhU86Y= cloud.google.com/go/storage v1.29.0/go.mod h1:4puEjyTKnku6gfKoTfNOU/W+a9JyuVNxjpS5GBrB8h4= -cloud.google.com/go/storage v1.44.0 h1:abBzXf4UJKMmQ04xxJf9dYM/fNl24KHoTuBjyJDX2AI= -cloud.google.com/go/storage v1.44.0/go.mod h1:wpPblkIuMP5jCB/E48Pz9zIo2S/zD8g+ITmxKkPCITE= +cloud.google.com/go/storage v1.45.0 h1:5av0QcIVj77t+44mV4gffFC/LscFRUhto6UBMB5SimM= +cloud.google.com/go/storage v1.45.0/go.mod h1:wpPblkIuMP5jCB/E48Pz9zIo2S/zD8g+ITmxKkPCITE= cloud.google.com/go/storagetransfer v1.5.0/go.mod h1:dxNzUopWy7RQevYFHewchb29POFv3/AaBgnhqzqiK0w= cloud.google.com/go/storagetransfer v1.6.0/go.mod h1:y77xm4CQV/ZhFZH75PLEXY0ROiS7Gh6pSKrM8dJyg6I= cloud.google.com/go/storagetransfer v1.7.0/go.mod h1:8Giuj1QNb1kfLAiWM1bN6dHzfdlDAVC9rv9abHot2W4= @@ -582,8 +582,8 @@ cloud.google.com/go/trace v1.3.0/go.mod h1:FFUE83d9Ca57C+K8rDl/Ih8LwOzWIV1krKgxg cloud.google.com/go/trace 
v1.4.0/go.mod h1:UG0v8UBqzusp+z63o7FK74SdFE+AXpCLdFb1rshXG+Y= cloud.google.com/go/trace v1.8.0/go.mod h1:zH7vcsbAhklH8hWFig58HvxcxyQbaIqMarMg9hn5ECA= cloud.google.com/go/trace v1.9.0/go.mod h1:lOQqpE5IaWY0Ixg7/r2SjixMuc6lfTFeO4QGM4dQWOk= -cloud.google.com/go/trace v1.11.0 h1:UHX6cOJm45Zw/KIbqHe4kII8PupLt/V5tscZUkeiJVI= -cloud.google.com/go/trace v1.11.0/go.mod h1:Aiemdi52635dBR7o3zuc9lLjXo3BwGaChEjCa3tJNmM= +cloud.google.com/go/trace v1.11.1 h1:UNqdP+HYYtnm6lb91aNA5JQ0X14GnxkABGlfz2PzPew= +cloud.google.com/go/trace v1.11.1/go.mod h1:IQKNQuBzH72EGaXEodKlNJrWykGZxet2zgjtS60OtjA= cloud.google.com/go/translate v1.3.0/go.mod h1:gzMUwRjvOqj5i69y/LYLd8RrNQk+hOmIXTi9+nb3Djs= cloud.google.com/go/translate v1.4.0/go.mod h1:06Dn/ppvLD6WvA5Rhdp029IX2Mi3Mn7fpMRLPvXT5Wg= cloud.google.com/go/translate v1.5.0/go.mod h1:29YDSYveqqpA1CQFD7NQuP49xymq17RXNaUDdc0mNu0= @@ -689,53 +689,53 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= -github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI= -github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= +github.com/aws/aws-sdk-go-v2 v1.32.4 h1:S13INUiTxgrPueTmrm5DZ+MiAo99zYzHEFh1UNkOxNE= +github.com/aws/aws-sdk-go-v2 v1.32.4/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= -github.com/aws/aws-sdk-go-v2/config v1.27.43 h1:p33fDDihFC390dhhuv8nOmX419wjOSDQRb+USt20RrU= -github.com/aws/aws-sdk-go-v2/config v1.27.43/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc= +github.com/aws/aws-sdk-go-v2/config v1.28.0 h1:FosVYWcqEtWNxHn8gB/Vs6jOlNwSoyOCA/g/sxyySOQ= +github.com/aws/aws-sdk-go-v2/config v1.28.0/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc= github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc= -github.com/aws/aws-sdk-go-v2/credentials v1.17.41 h1:7gXo+Axmp+R4Z+AK8YFQO0ZV3L0gizGINCOWxSLY9W8= -github.com/aws/aws-sdk-go-v2/credentials v1.17.41/go.mod h1:u4Eb8d3394YLubphT4jLEwN1rLNq2wFOlT6OuxFwPzU= +github.com/aws/aws-sdk-go-v2/credentials v1.17.42 h1:sBP0RPjBU4neGpIYyx8mkU2QqLPl5u9cmdTWVzIpHkM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.42/go.mod h1:FwZBfU530dJ26rv9saAbxa9Ej3eF/AK0OAY86k13n4M= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 h1:TMH3f/SCAWdNtXXVPPu5D6wrr4G5hI1rAxbcocKfC7Q= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17/go.mod h1:1ZRXLdTpzdJb9fwTMXiLipENRxkGMTn1sfKexGllQCw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18 h1:68jFVtt3NulEzojFesM/WVarlFpCaXLKaBxDpzkQ9OQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18/go.mod h1:Fjnn5jQVIo6VyedMc0/EhPpfNlPl7dHV916O6B+49aE= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.32 
h1:C2hE+gJ40Cb4vzhFJ+tTzjvBpPloUq7XP6PD3A2Fk7g= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.32/go.mod h1:0OmMtVNp+10JFBTfmA2AIeqBDm0YthDXmE+N7poaptk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 h1:UAsR3xA31QGf79WzpG/ixT9FZvQlh5HY1NRqSHBNOCk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21/go.mod h1:JNr43NFf5L9YaG3eKTm7HQzls9J+A9YYcGI5Quh1r2Y= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 h1:6jZVETqmYCadGFvrYEQfC5fAQmlo80CeL5psbno6r0s= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21/go.mod h1:1SR0GbLlnN3QUmYaflZNiH1ql+1qrSiB2vwcJ+4UM60= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 h1:X+4YY5kZRI/cOoSMVMGTqFXHAMg1bvvay7IBcqHpybQ= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33/go.mod h1:DPynzu+cn92k5UQ6tZhX+wfTB4ah6QDU/NgdHqatmvk= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 h1:A2w6m6Tmr+BNXjDsr7M90zkWjsu4JXHwrzPg235STs4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23/go.mod h1:35EVp9wyeANdujZruvHiQUAo9E3vbhnIO1mTCAxMlY0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 h1:pgYW9FCabt2M25MoHYCfMrVY2ghiiBKYWUVXfwZs+sU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23/go.mod h1:c48kLgzO19wAu3CPkDWC28JbaJ+hfQlsdl7I2+oqIbk= github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 h1:7edmS3VOBDhK00b/MwGtGglCm7hhwNYnjJs/PgFdMQE= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21/go.mod h1:Q9o5h4HoIWG8XfzxqiuK/CGUbepCJ8uTlaE3bAbxytQ= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23 h1:1SZBDiRzzs3sNhOMVApyWPduWYGAX0imGy06XiBnCAM= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23/go.mod h1:i9TkxgbZmHVh2S0La6CAXtnyFhlCX/pJ0JsOvBAS6Mk= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 h1:4FMHqLfk0efmTqhXVRL5xYRqlEBNBiRI7N6w4jsEdd4= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2/go.mod h1:LWoqeWlK9OZeJxsROW2RqrSPvQHKTpp69r/iDjwsSaw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4 h1:aaPpoG15S2qHkWm4KlEyF01zovK1nW4BBbyXuHNSE90= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4/go.mod h1:eD9gS2EARTKgGr/W5xwgY/ik9z/zqpW+m/xOQbVxrMk= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 h1:s7NA1SOw8q/5c0wr8477yOPp0z+uBaXBnLE0XYb0POA= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2/go.mod h1:fnjjWyAW/Pj5HYOxl9LJqWtEwS7W2qgcRLWP+uWbss0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 h1:tHxQi/XHPK0ctd/wdOw0t7Xrc2OxcRCnVzv8lwWPu0c= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4/go.mod h1:4GQbF1vJzG60poZqWatZlhP31y8PGCCVTvIGPdaaYJ0= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ= 
-github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 h1:t7iUP9+4wdc5lt3E41huP+GvQZJD38WLsgVp4iOtAjg= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2/go.mod h1:/niFCtmuQNxqx9v8WAPq5qh7EH25U4BF6tjoyq9bObM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4 h1:E5ZAVOmI2apR8ADb72Q63KqwwwdW1XcMeXIlrZ1Psjg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4/go.mod h1:wezzqVUOVVdk+2Z/JzQT4NxAU0NbhRe5W8pIE72jsWI= github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI= -github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3 h1:xxHGZ+wUgZNACQmxtdvP5tgzfsxGS3vPpTP5Hy3iToE= -github.com/aws/aws-sdk-go-v2/service/s3 v1.65.3/go.mod h1:cB6oAuus7YXRZhWCc1wIwPywwZ1XwweNp2TVAEGYeB8= +github.com/aws/aws-sdk-go-v2/service/s3 v1.66.3 h1:neNOYJl72bHrz9ikAEED4VqWyND/Po0DnEx64RW6YM4= +github.com/aws/aws-sdk-go-v2/service/s3 v1.66.3/go.mod h1:TMhLIyRIyoGVlaEMAt+ITMbwskSTpcGsCPDq91/ihY0= github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 h1:bSYXVyUzoTHoKalBmwaZxs97HU9DWWI3ehHSAMa7xOk= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.2/go.mod h1:skMqY7JElusiOUjMJMOv1jJsP7YUg7DrhgqZZWuzu1U= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 h1:AhmO1fHINP9vFYUE0LHzCWg/LfUWUF+zFPEcY9QXb7o= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2/go.mod h1:o8aQygT2+MVP0NaV6kbdE1YnnIM8RRVQzoeUH45GOdI= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.3 h1:UTpsIf0loCIWEbrqdLb+0RxnTXfWh2vhw4nQmFi4nPc= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.3/go.mod h1:FZ9j3PFHHAR+w0BSEjK955w5YD2UwB/l/H0yAK3MJvI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3 h1:2YCmIXv3tmiItw0LlYf6v7gEHebLY45kBEnPezbUKyU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3/go.mod h1:u19stRyNPxGhj6dRm+Cdgu6N75qnbW7+QN0q0dsAk58= github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg= -github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 h1:CiS7i0+FUe+/YY1GvIBLLrR/XNGZ4CtM1Ll0XavNuVo= -github.com/aws/aws-sdk-go-v2/service/sts v1.32.2/go.mod h1:HtaiBI8CjYoNVde8arShXb94UbQQi9L4EMr6D+xGBwo= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.3 h1:wVnQ6tigGsRqSWDEEyH6lSAJ9OyFUsSnbaUWChuSGzs= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.3/go.mod h1:VZa9yTFyj4o10YGsmDO4gbQJUvvhY72fhumT8W4LqsE= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= @@ -853,6 +853,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-cz/devslog v0.0.11 h1:v4Yb9o0ZpuZ/D8ZrtVw1f9q5XrjnkxwHF1XmWwO8IHg= +github.com/golang-cz/devslog v0.0.11/go.mod h1:bSe5bm0A7Nyfqtijf1OMNgVJHlWEuVSXnkuASiE1vV8= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= @@ -1027,8 +1029,8 @@ 
github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= @@ -1083,8 +1085,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/nats-io/jwt/v2 v2.5.8 h1:uvdSzwWiEGWGXf+0Q+70qv6AQdvcvxrv9hPM0RiPamE= github.com/nats-io/jwt/v2 v2.5.8/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= -github.com/nats-io/nats-server/v2 v2.10.18 h1:tRdZmBuWKVAFYtayqlBB2BuCHNGAQPvoQIXOKwU3WSM= -github.com/nats-io/nats-server/v2 v2.10.18/go.mod h1:97Qyg7YydD8blKlR8yBsUlPlWyZKjA7Bp5cl3MUE9K8= +github.com/nats-io/nats-server/v2 v2.10.22 h1:Yt63BGu2c3DdMoBZNcR6pjGQwk/asrKU7VX846ibxDA= +github.com/nats-io/nats-server/v2 v2.10.22/go.mod h1:X/m1ye9NYansUXYFrbcDwUi/blHkrgHh2rgCJaakonk= github.com/nats-io/nats.go v1.37.0 h1:07rauXbVnnJvv1gfIyghFEo6lUcYRY0WXc3x7x0vUxE= github.com/nats-io/nats.go v1.37.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= @@ -1522,8 +1524,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= @@ -1558,8 +1560,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time 
v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1703,8 +1705,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.199.0 h1:aWUXClp+VFJmqE0JPvpZOK3LDQMyFKYIow4etYd9qxs= -google.golang.org/api v0.199.0/go.mod h1:ohG4qSztDJmZdjK/Ar6MhbAmb/Rpi4JHOqagsh90K28= +google.golang.org/api v0.203.0 h1:SrEeuwU3S11Wlscsn+LA1kb/Y5xT8uggJSkIhD08NAU= +google.golang.org/api v0.203.0/go.mod h1:BuOVyCSYEPwJb3npWvDnNmFI92f3GeRnHNkETneT3SI= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1844,12 +1846,12 @@ google.golang.org/genproto v0.0.0-20230323212658-478b75c54725/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230330154414-c0448cd141ea/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 h1:BulPr26Jqjnd4eYDVe+YvyR7Yc2vJGkO5/0UxD0/jZU= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:hL97c3SYopEHblzpxRL4lSs523++l8DYxGM1FQiYmb4= -google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 h1:hjSy6tcFQZ171igDaN5QHOw2n6vx40juYbC/x67CEhc= -google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:qpvKtACPCQhAdu3PyQgV4l3LMXZEtft7y8QcarRsp9I= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 h1:Df6WuGvthPzc+JiQ/G+m+sNX24kc0aTBqoDN/0yyykE= +google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53/go.mod h1:fheguH3Am2dGp1LfXkrvwqC/KlFq8F0nLq3LryOMrrE= +google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 h1:T6rh4haD3GVYsgEfWExoCZA2o2FmbNyKpTuAxbEFPTg= +google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:wp2WsuBYj6j8wUdo3ToZsdxxixbvQNAHqVJrTgi5E5M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 h1:X58yt85/IXCx0Y3ZwN6sEIKZzQtDEYaBWrDvErdXrRE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod 
h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= diff --git a/sdks/go/cmd/prism/prism.go b/sdks/go/cmd/prism/prism.go index 39c19df00dc3..5e3f42a9e5a5 100644 --- a/sdks/go/cmd/prism/prism.go +++ b/sdks/go/cmd/prism/prism.go @@ -22,9 +22,14 @@ import ( "flag" "fmt" "log" + "log/slog" + "os" + "strings" + "time" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism" + "github.com/golang-cz/devslog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -37,10 +42,59 @@ var ( idleShutdownTimeout = flag.Duration("idle_shutdown_timeout", -1, "duration that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Defaults to never shutting down.") ) +// Logging flags +var ( + logKind = flag.String("log_kind", "dev", + "Determines the format of prism's logging to stderr: valid values are 'dev', 'json', or 'text'. Default is 'dev'.") + logLevelFlag = flag.String("log_level", "info", + "Sets the minimum log level of Prism. Valid options are 'debug', 'info', 'warn', and 'error'. Default is 'info'. Debug adds prism source lines.") +) + +var logLevel = new(slog.LevelVar) + func main() { flag.Parse() ctx, cancel := context.WithCancelCause(context.Background()) + var logHandler slog.Handler + loggerOutput := os.Stderr + handlerOpts := &slog.HandlerOptions{ + Level: logLevel, + } + switch strings.ToLower(*logLevelFlag) { + case "debug": + logLevel.Set(slog.LevelDebug) + handlerOpts.AddSource = true + case "info": + logLevel.Set(slog.LevelInfo) + case "warn": + logLevel.Set(slog.LevelWarn) + case "error": + logLevel.Set(slog.LevelError) + default: + log.Fatalf("Invalid value for log_level: %v, must be 'debug', 'info', 'warn', or 'error'", *logLevelFlag) + } + switch strings.ToLower(*logKind) { + case "dev": + logHandler = + devslog.NewHandler(loggerOutput, &devslog.Options{ + TimeFormat: "[" + time.RFC3339Nano + "]", + StringerFormatter: true, + HandlerOptions: handlerOpts, + StringIndentation: false, + NewLineAfterLog: true, + MaxErrorStackTrace: 3, + }) + case "json": + logHandler = slog.NewJSONHandler(loggerOutput, handlerOpts) + case "text": + logHandler = slog.NewTextHandler(loggerOutput, handlerOpts) + default: + log.Fatalf("Invalid value for log_kind: %v, must be 'dev', 'json', or 'text'", *logKind) + } + + slog.SetDefault(slog.New(logHandler)) + cli, err := makeJobClient(ctx, prism.Options{ Port: *jobPort, diff --git a/sdks/go/container/tools/buffered_logging.go b/sdks/go/container/tools/buffered_logging.go index 445d19fabfdc..a7b84e56af3a 100644 --- a/sdks/go/container/tools/buffered_logging.go +++ b/sdks/go/container/tools/buffered_logging.go @@ -18,13 +18,15 @@ package tools import ( "context" "log" - "math" "os" "strings" "time" ) -const initialLogSize int = 255 +const ( + initialLogSize int = 255 + defaultFlushInterval time.Duration = 15 * time.Second +) // BufferedLogger is a wrapper around the FnAPI logging client meant to be used // in place of stdout and stderr in bootloader subprocesses. Not intended for @@ -41,7 +43,7 @@ type BufferedLogger struct { // NewBufferedLogger returns a new BufferedLogger type by reference.
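Aside on the prism.go hunk above: the new --log_kind and --log_level flags route everything through a process-wide slog default handler. The following is a minimal standalone sketch of that flag-to-handler pattern, not code from this patch; it assumes only the standard library (the real binary additionally offers the default 'dev' format via github.com/golang-cz/devslog) and uses illustrative defaults.

```go
package main

import (
	"flag"
	"log"
	"log/slog"
	"os"
	"strings"
)

// Flag names mirror sdks/go/cmd/prism/prism.go; the defaults here are
// illustrative ("text" instead of prism's devslog-backed "dev").
var (
	logKind  = flag.String("log_kind", "text", "output format: 'json' or 'text'")
	logLevel = flag.String("log_level", "info", "minimum level: 'debug', 'info', 'warn', or 'error'")
)

func main() {
	flag.Parse()

	level := new(slog.LevelVar) // defaults to Info
	switch strings.ToLower(*logLevel) {
	case "debug":
		level.Set(slog.LevelDebug)
	case "info":
		level.Set(slog.LevelInfo)
	case "warn":
		level.Set(slog.LevelWarn)
	case "error":
		level.Set(slog.LevelError)
	default:
		log.Fatalf("invalid log_level %q", *logLevel)
	}
	opts := &slog.HandlerOptions{Level: level}

	var h slog.Handler
	switch strings.ToLower(*logKind) {
	case "json":
		h = slog.NewJSONHandler(os.Stderr, opts)
	case "text":
		h = slog.NewTextHandler(os.Stderr, opts)
	default:
		log.Fatalf("invalid log_kind %q", *logKind)
	}

	// Everything logging through the slog default now honors both flags.
	slog.SetDefault(slog.New(h))
	slog.Info("Serving JobManagement", slog.String("endpoint", "localhost:8073"))
}
```

With the text handler, log lines come out in the level=INFO msg="Serving JobManagement" endpoint=localhost:8073 form that the updated PrismExecutorTest assertions above now expect.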
func NewBufferedLogger(logger *Logger) *BufferedLogger { - return &BufferedLogger{logger: logger, lastFlush: time.Now(), flushInterval: time.Duration(math.MaxInt64), periodicFlushContext: context.Background(), now: time.Now} + return &BufferedLogger{logger: logger, lastFlush: time.Now(), flushInterval: defaultFlushInterval, periodicFlushContext: context.Background(), now: time.Now} } // NewBufferedLoggerWithFlushInterval returns a new BufferedLogger type by reference. This type will diff --git a/sdks/go/examples/stringsplit/stringsplit.go b/sdks/go/examples/stringsplit/stringsplit.go index 266cdd99fb37..76140075b625 100644 --- a/sdks/go/examples/stringsplit/stringsplit.go +++ b/sdks/go/examples/stringsplit/stringsplit.go @@ -21,7 +21,7 @@ // 1. From a command line, navigate to the top-level beam/ directory and run // the Flink job server: // -// ./gradlew :runners:flink:1.18:job-server:runShadow -Djob-host=localhost -Dflink-master=local +// ./gradlew :runners:flink:1.19:job-server:runShadow -Djob-host=localhost -Dflink-master=local // // 2. The job server is ready to receive jobs once it outputs a log like the // following: `JobService started on localhost:8099`. Take note of the endpoint diff --git a/sdks/go/examples/wasm/README.md b/sdks/go/examples/wasm/README.md index 84d30a3c6a63..103bef88642b 100644 --- a/sdks/go/examples/wasm/README.md +++ b/sdks/go/examples/wasm/README.md @@ -68,7 +68,7 @@ cd $BEAM_HOME Expected output should include the following, from which you acquire the latest flink runner version. ```shell -'flink_versions: 1.15,1.16,1.17,1.18' +'flink_versions: 1.17,1.18,1.19' ``` #### 2. Set to the latest flink runner version i.e. 1.16 diff --git a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go index c71ead208364..06bb727178fc 100644 --- a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go +++ b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go @@ -19,12 +19,12 @@ import ( "bytes" "fmt" "log" + "log/slog" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" - "golang.org/x/exp/slog" ) // FromMonitoringInfos extracts metrics from monitored states and diff --git a/sdks/go/pkg/beam/runners/prism/internal/coders.go b/sdks/go/pkg/beam/runners/prism/internal/coders.go index eb8abe16ecf8..ffea90e79065 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/coders.go +++ b/sdks/go/pkg/beam/runners/prism/internal/coders.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -28,7 +29,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index eaaf7f831712..7b8689f95112 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -18,12 +18,12 @@ package engine import ( "bytes" "fmt" + "log/slog" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" 
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" - "golang.org/x/exp/slog" ) // StateData is a "union" between Bag state and MultiMap state to increase common code. diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index f7229853e4d3..3cfde4701a8f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "io" + "log/slog" "sort" "strings" "sync" @@ -36,7 +37,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" ) type element struct { @@ -1607,7 +1607,7 @@ func (ss *stageState) bundleReady(em *ElementManager, emNow mtime.Time) (mtime.T inputW := ss.input _, upstreamW := ss.UpstreamWatermark() if inputW == upstreamW { - slog.Debug("bundleReady: insufficient upstream watermark", + slog.Debug("bundleReady: unchanged upstream watermark", slog.String("stage", ss.ID), slog.Group("watermark", slog.Any("upstream", upstreamW), diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go index add7f769a702..2f960a04f0cb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/environments.go +++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "os" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" @@ -27,7 +28,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/proto" @@ -42,7 +42,7 @@ import ( // TODO move environment handling to the worker package. func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) error { - logger := slog.With(slog.String("envID", wk.Env)) + logger := j.Logger.With(slog.String("envID", wk.Env)) // TODO fix broken abstraction. // We're starting a worker pool here, because that's the loopback environment. // It's sort of a mess, largely because of loopback, which has @@ -56,7 +56,7 @@ func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *wor } go func() { externalEnvironment(ctx, ep, wk) - slog.Debug("environment stopped", slog.String("job", j.String())) + logger.Debug("environment stopped", slog.String("job", j.String())) }() return nil case urns.EnvDocker: @@ -129,6 +129,8 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock credEnv := fmt.Sprintf("%v=%v", gcloudCredsEnv, dockerGcloudCredsFile) envs = append(envs, credEnv) } + } else { + logger.Debug("local GCP credentials environment variable not found") } if _, _, err := cli.ImageInspectWithRaw(ctx, dp.GetContainerImage()); err != nil { // We don't have a local image, so we should pull it. 
@@ -140,6 +142,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock logger.Warn("unable to pull image and it's not local", "error", err) } } + logger.Debug("creating container", "envs", envs, "mounts", mounts) ccr, err := cli.ContainerCreate(ctx, &container.Config{ Image: dp.GetContainerImage(), @@ -169,17 +172,32 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock return fmt.Errorf("unable to start container image %v with docker for env %v, err: %w", dp.GetContainerImage(), wk.Env, err) } + logger.Debug("container started") + // Start goroutine to wait on container state. go func() { defer cli.Close() defer wk.Stop() + defer func() { + logger.Debug("container stopped") + }() - statusCh, errCh := cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning) + bgctx := context.Background() + statusCh, errCh := cli.ContainerWait(bgctx, containerID, container.WaitConditionNotRunning) select { case <-ctx.Done(): - // Can't use command context, since it's already canceled here. - err := cli.ContainerKill(context.Background(), containerID, "") + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { + logger.Error("error fetching container logs error on context cancellation", "error", err) + } + if rc != nil { + defer rc.Close() + var buf bytes.Buffer + stdcopy.StdCopy(&buf, &buf, rc) + logger.Info("container being killed", slog.Any("cause", context.Cause(ctx)), slog.Any("containerLog", buf)) + } + // Can't use command context, since it's already canceled here. + if err := cli.ContainerKill(bgctx, containerID, ""); err != nil { logger.Error("docker container kill error", "error", err) } case err := <-errCh: @@ -189,7 +207,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock case resp := <-statusCh: logger.Info("docker container has self terminated", "status_code", resp.StatusCode) - rc, err := cli.ContainerLogs(ctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { logger.Error("docker container logs error", "error", err) } diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index d7605f34f5f2..614edee47721 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "log/slog" "sort" "sync/atomic" "time" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) @@ -311,7 +311,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err) } stages[stage.ID] = stage - slog.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) + j.Logger.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) outputs := maps.Keys(stage.OutputsToCoders) sort.Strings(outputs) em.AddStage(stage.ID, []string{stage.primaryInput}, 
outputs, stage.sideInputs) @@ -322,9 +322,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers) } default: - err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) - slog.Error("Execute", err) - return err + return fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) } } @@ -344,11 +342,13 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic for { select { case <-ctx.Done(): - return context.Cause(ctx) + err := context.Cause(ctx) + j.Logger.Debug("context canceled", slog.Any("cause", err)) + return err case rb, ok := <-bundles: if !ok { err := eg.Wait() - slog.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err)) + j.Logger.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err), slog.Any("topo", topo)) return err } eg.Go(func() error { diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go index 8590fd0d4ced..be9d39ad02b7 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "io" + "log/slog" "reflect" "sort" @@ -31,7 +32,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go index 99b786d45980..e42e3e7ca666 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go @@ -20,9 +20,9 @@ import ( "context" "fmt" "io" + "log/slog" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) @@ -77,7 +77,7 @@ func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingSer case *jobpb.ArtifactResponseWrapper_ResolveArtifactResponse: err := fmt.Errorf("unexpected ResolveArtifactResponse to GetArtifact: %v", in.GetResponse()) - slog.Error("GetArtifact failure", err) + slog.Error("GetArtifact failure", slog.Any("error", err)) return err } } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index 1407feafe325..deef259a99d1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -27,6 +27,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "sort" "strings" "sync" @@ -37,7 +38,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/types/known/structpb" ) @@ -88,6 +88,8 @@ type Job struct { // Context used to terminate this job. RootCtx context.Context CancelFn context.CancelCauseFunc + // Logger for this job. 
+ Logger *slog.Logger metrics metricsStore } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index b957b99ca63d..a2840760bf7a 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -19,6 +19,7 @@ import ( "context" "errors" "fmt" + "log/slog" "sync" "sync/atomic" @@ -27,7 +28,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -92,6 +92,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * cancelFn(err) terminalOnceWrap() }, + Logger: s.logger, // TODO substitute with a configured logger. artifactEndpoint: s.Endpoint(), } // Stop the idle timer when a new job appears. diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go index 03d5b0a98369..bbbdfd1eba4f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "hash/maphash" + "log/slog" "math" "sort" "sync" @@ -28,7 +29,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/constraints" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -589,7 +589,7 @@ func (m *metricsStore) AddShortIDs(resp *fnpb.MonitoringInfosMetadataResponse) { urn := mi.GetUrn() ops, ok := mUrn2Ops[urn] if !ok { - slog.Debug("unknown metrics urn", slog.String("urn", urn)) + slog.Debug("unknown metrics urn", slog.String("shortID", short), slog.String("urn", urn), slog.String("type", mi.Type)) continue } key := ops.keyFn(urn, mi.GetLabels()) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go index 320159f54c06..bdfe2aff2dd4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -18,6 +18,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "math" "net" "os" @@ -27,7 +28,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/grpc" ) @@ -53,6 +53,7 @@ type Server struct { terminatedJobCount uint32 // Use with atomics. idleTimeout time.Duration cancelFn context.CancelCauseFunc + logger *slog.Logger // execute defines how a job is executed. execute func(*Job) @@ -71,8 +72,9 @@ func NewServer(port int, execute func(*Job)) *Server { lis: lis, jobs: make(map[string]*Job), execute: execute, + logger: slog.Default(), // TODO substitute with a configured logger. 
} - slog.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) + s.logger.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) opts := []grpc.ServerOption{ grpc.MaxRecvMsgSize(math.MaxInt32), } diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index 7de32f85b7ee..dceaa9ab8fcb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -17,6 +17,7 @@ package internal import ( "fmt" + "log/slog" "sort" "strings" @@ -26,7 +27,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go index 1be3d3e70841..650932f525c8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go @@ -18,6 +18,7 @@ package internal_test import ( "context" "fmt" + "log/slog" "net" "net/http" "net/rpc" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/stats" - "golang.org/x/exp/slog" ) // separate_test.go retains structures and tests to ensure the runner can @@ -286,7 +286,7 @@ func (ws *Watchers) Check(args *Args, unblocked *bool) error { w.mu.Lock() *unblocked = w.sentinelCount >= w.sentinelCap w.mu.Unlock() - slog.Debug("sentinel target for watcher%d is %d/%d. unblocked=%v", args.WatcherID, w.sentinelCount, w.sentinelCap, *unblocked) + slog.Debug("sentinel watcher status", slog.Int("watcher", args.WatcherID), slog.Int("sentinelCount", w.sentinelCount), slog.Int("sentinelCap", w.sentinelCap), slog.Bool("unblocked", *unblocked)) return nil } @@ -360,7 +360,7 @@ func (fn *sepHarnessBase) setup() error { sepClientOnce.Do(func() { client, err := rpc.DialHTTP("tcp", fn.LocalService) if err != nil { - slog.Error("failed to dial sentinels server", err, slog.String("endpoint", fn.LocalService)) + slog.Error("failed to dial sentinels server", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(fmt.Sprintf("dialing sentinels server %v: %v", fn.LocalService, err)) } sepClient = client @@ -385,7 +385,7 @@ func (fn *sepHarnessBase) setup() error { var unblock bool err := sepClient.Call("Watchers.Check", &Args{WatcherID: id}, &unblock) if err != nil { - slog.Error("Watchers.Check: sentinels server error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Check: sentinels server error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic("sentinel server error") } if unblock { @@ -406,7 +406,7 @@ func (fn *sepHarnessBase) block() { var ignored bool err := sepClient.Call("Watchers.Block", &Args{WatcherID: fn.WatcherID}, &ignored) if err != nil { - slog.Error("Watchers.Block error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Block error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(err) } c := sepWaitMap[fn.WatcherID] @@ -423,7 +423,7 @@ func (fn *sepHarnessBase) delay() bool { var delay bool err := sepClient.Call("Watchers.Delay", &Args{WatcherID: fn.WatcherID}, &delay) if err != nil { - slog.Error("Watchers.Delay 
error", err) + slog.Error("Watchers.Delay error", slog.Any("error", err)) panic(err) } return delay diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index f33754b2ca0a..9f00c22789b6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "runtime/debug" "sync/atomic" "time" @@ -33,7 +34,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) @@ -361,7 +361,7 @@ func portFor(wInCid string, wk *worker.W) []byte { } sourcePortBytes, err := proto.Marshal(sourcePort) if err != nil { - slog.Error("bad port", err, slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) + slog.Error("bad port", slog.Any("error", err), slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) } return sourcePortBytes } diff --git a/sdks/go/pkg/beam/runners/prism/internal/web/web.go b/sdks/go/pkg/beam/runners/prism/internal/web/web.go index 9fabe22cee3a..b14778e4462c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/web/web.go +++ b/sdks/go/pkg/beam/runners/prism/internal/web/web.go @@ -26,6 +26,7 @@ import ( "fmt" "html/template" "io" + "log/slog" "net/http" "sort" "strings" @@ -40,7 +41,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 3ccafdb81e9a..55cdb97f258c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -19,12 +19,12 @@ import ( "bytes" "context" "fmt" + "log/slog" "sync/atomic" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" - "golang.org/x/exp/slog" ) // SideInputKey is for data lookups for a given bundle. 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index f9ec03793488..1f129595abef 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -22,10 +22,9 @@ import ( "context" "fmt" "io" + "log/slog" "math" "net" - "strconv" - "strings" "sync" "sync/atomic" "time" @@ -39,7 +38,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -203,30 +201,46 @@ func (wk *W) Logging(stream fnpb.BeamFnLogging_LoggingServer) error { case codes.Canceled: return nil default: - slog.Error("logging.Recv", err, "worker", wk) + slog.Error("logging.Recv", slog.Any("error", err), slog.Any("worker", wk)) return err } } for _, l := range in.GetLogEntries() { - if l.Severity >= minsev { - // TODO: Connect to the associated Job for this worker instead of - // logging locally for SDK side logging. - file := l.GetLogLocation() - i := strings.LastIndex(file, ":") - line, _ := strconv.Atoi(file[i+1:]) - if i > 0 { - file = file[:i] - } + // TODO base this on a per pipeline logging setting. + if l.Severity < minsev { + continue + } + + // Ideally we'd be writing these to per-pipeline files, but for now re-log them on the Prism process. + // We indicate they're from the SDK, and which worker, keeping the same log severity. + // SDK specific and worker specific fields are in separate groups for legibility. - slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), l.GetMessage(), - slog.Any(slog.SourceKey, &slog.Source{ - File: file, - Line: line, - }), - slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), - slog.Any("worker", wk), - ) + attrs := []any{ + slog.String("transformID", l.GetTransformId()), // TODO: pull the unique name from the pipeline graph. 
+ slog.String("location", l.GetLogLocation()), + slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), + slog.String(slog.MessageKey, l.GetMessage()), } + if fs := l.GetCustomData().GetFields(); len(fs) > 0 { + var grp []any + for n, v := range l.GetCustomData().GetFields() { + var attr slog.Attr + switch v.Kind.(type) { + case *structpb.Value_BoolValue: + attr = slog.Bool(n, v.GetBoolValue()) + case *structpb.Value_NumberValue: + attr = slog.Float64(n, v.GetNumberValue()) + case *structpb.Value_StringValue: + attr = slog.String(n, v.GetStringValue()) + default: + attr = slog.Any(n, v.AsInterface()) + } + grp = append(grp, attr) + } + attrs = append(attrs, slog.Group("customData", grp...)) + } + + slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), "log from SDK worker", slog.Any("worker", wk), slog.Group("sdk", attrs...)) } } } @@ -298,7 +312,7 @@ func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { if b, ok := wk.activeInstructions[resp.GetInstructionId()]; ok { b.Respond(resp) } else { - slog.Debug("ctrl.Recv: %v", resp) + slog.Debug("ctrl.Recv", slog.Any("response", resp)) } wk.mu.Unlock() } @@ -355,7 +369,7 @@ func (wk *W) Data(data fnpb.BeamFnData_DataServer) error { case codes.Canceled: return default: - slog.Error("data.Recv failed", err, "worker", wk) + slog.Error("data.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -434,7 +448,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case codes.Canceled: return default: - slog.Error("state.Recv failed", err, "worker", wk) + slog.Error("state.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -584,7 +598,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { }() for resp := range responses { if err := state.Send(resp); err != nil { - slog.Error("state.Send error", err) + slog.Error("state.Send", slog.Any("error", err)) } } return nil diff --git a/sdks/go/test/integration/io/xlang/debezium/debezium_test.go b/sdks/go/test/integration/io/xlang/debezium/debezium_test.go index 24c2b513b2b2..208a062f9436 100644 --- a/sdks/go/test/integration/io/xlang/debezium/debezium_test.go +++ b/sdks/go/test/integration/io/xlang/debezium/debezium_test.go @@ -34,7 +34,7 @@ import ( ) const ( - debeziumImage = "debezium/example-postgres:latest" + debeziumImage = "quay.io/debezium/example-postgres:latest" debeziumPort = "5432/tcp" maxRetries = 5 ) diff --git a/sdks/java/core/build.gradle b/sdks/java/core/build.gradle index e150c22de62d..a8dfbf42f970 100644 --- a/sdks/java/core/build.gradle +++ b/sdks/java/core/build.gradle @@ -73,6 +73,7 @@ dependencies { antlr library.java.antlr // antlr is used to generate code from sdks/java/core/src/main/antlr/ permitUnusedDeclared library.java.antlr + permitUsedUndeclared library.java.antlr_runtime // Required to load constants from the model, e.g. 
max timestamp for global window shadow project(path: ":model:pipeline", configuration: "shadow") shadow project(path: ":model:fn-execution", configuration: "shadow") @@ -81,7 +82,6 @@ dependencies { shadow library.java.vendored_grpc_1_60_1 shadow library.java.vendored_guava_32_1_2_jre shadow library.java.byte_buddy - shadow library.java.antlr_runtime shadow library.java.commons_compress shadow library.java.commons_lang3 testImplementation library.java.mockito_inline diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java index f739a797af80..453c1cd79a42 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.fn; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.harness.JvmInitializer; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.common.ReflectHelpers; @@ -25,6 +26,8 @@ /** Helpers for executing {@link JvmInitializer} implementations. */ public class JvmInitializers { + private static final AtomicBoolean initialized = new AtomicBoolean(false); + /** * Finds all registered implementations of JvmInitializer and executes their {@code onStartup} * methods. Should be called in worker harness implementations at the very beginning of their main @@ -50,10 +53,23 @@ public static void runBeforeProcessing(PipelineOptions options) { // We load the logger in the method to minimize the amount of class loading that happens // during class initialization. Logger logger = LoggerFactory.getLogger(JvmInitializers.class); - for (JvmInitializer initializer : ReflectHelpers.loadServicesOrdered(JvmInitializer.class)) { - logger.info("Running JvmInitializer#beforeProcessing for {}", initializer); - initializer.beforeProcessing(options); - logger.info("Completed JvmInitializer#beforeProcessing for {}", initializer); + + try { + for (JvmInitializer initializer : ReflectHelpers.loadServicesOrdered(JvmInitializer.class)) { + logger.info("Running JvmInitializer#beforeProcessing for {}", initializer); + initializer.beforeProcessing(options); + logger.info("Completed JvmInitializer#beforeProcessing for {}", initializer); + } + initialized.compareAndSet(false, true); + } catch (Error e) { + if (initialized.get()) { + logger.warn( + "Error at JvmInitializer#beforeProcessing. This error is suppressed after " + + "previous success runs. 
It is expected on Embedded environment", + e); + } else { + throw e; + } } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java index e946022c4e36..aa0dea80b0a1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.fn.data; +import java.time.Duration; import java.util.HashSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Consumer; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.pipeline.v1.Endpoints; @@ -30,6 +32,8 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; @@ -49,13 +53,20 @@ */ public class BeamFnDataGrpcMultiplexer implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcMultiplexer.class); + private static final Duration POISONED_INSTRUCTION_ID_CACHE_TIMEOUT = Duration.ofMinutes(20); private final Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor; private final StreamObserver inboundObserver; private final StreamObserver outboundObserver; - private final ConcurrentMap< + private final ConcurrentHashMap< /*instructionId=*/ String, CompletableFuture>> receivers; - private final ConcurrentMap erroredInstructionIds; + private final Cache poisonedInstructionIds; + + private static class PoisonedException extends RuntimeException { + public PoisonedException() { + super("Instruction poisoned"); + } + }; public BeamFnDataGrpcMultiplexer( Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor, @@ -64,7 +75,8 @@ public BeamFnDataGrpcMultiplexer( baseOutboundObserverFactory) { this.apiServiceDescriptor = apiServiceDescriptor; this.receivers = new ConcurrentHashMap<>(); - this.erroredInstructionIds = new ConcurrentHashMap<>(); + this.poisonedInstructionIds = + CacheBuilder.newBuilder().expireAfterWrite(POISONED_INSTRUCTION_ID_CACHE_TIMEOUT).build(); this.inboundObserver = new InboundObserver(); this.outboundObserver = outboundObserverFactory.outboundObserverFor(baseOutboundObserverFactory, inboundObserver); @@ -87,11 +99,6 @@ public StreamObserver getOutboundObserver() { return outboundObserver; } - private CompletableFuture> receiverFuture( - String instructionId) { - return receivers.computeIfAbsent(instructionId, (unused) -> new CompletableFuture<>()); - } - /** * Registers a consumer for the specified instruction id. 
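The multiplexer change above replaces the unbounded erroredInstructionIds map with a Guava Cache whose entries expire POISONED_INSTRUCTION_ID_CACHE_TIMEOUT after being written, so a poisoned instruction id is only remembered for a bounded window. A minimal, standalone sketch of that expiry pattern, using plain com.google.common.cache rather than Beam's vendored Guava; the class and key names here are illustrative only:

import java.time.Duration;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

public class PoisonCacheSketch {
  public static void main(String[] args) {
    // Same shape as poisonedInstructionIds: an entry silently expires a fixed time
    // after it was written, bounding how long a poisoned id is remembered.
    Cache<String, Boolean> poisoned =
        CacheBuilder.newBuilder().expireAfterWrite(Duration.ofMinutes(20)).build();

    poisoned.put("instr-1", Boolean.TRUE);

    // getIfPresent returns the stored value while the entry is live and null once it
    // has expired or was never inserted; the diff applies the same null check in
    // registerConsumer and forwardToConsumerForInstructionId.
    System.out.println(poisoned.getIfPresent("instr-1")); // true while fresh
    System.out.println(poisoned.getIfPresent("instr-2")); // null
  }
}
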
* @@ -99,17 +106,63 @@ private CompletableFuture> receiverF * instruction ids ensuring that the receiver will only see {@link BeamFnApi.Elements} with a * single instruction id. * - * <p>The caller must {@link #unregisterConsumer unregister the consumer} when they no longer wish - to receive messages. + * <p>
The caller must either {@link #unregisterConsumer unregister the consumer} when all messages + * have been processed or {@link #poisonInstructionId(String) poison the instruction} if messages + * for the instruction should be dropped. */ public void registerConsumer( String instructionId, CloseableFnDataReceiver receiver) { - receiverFuture(instructionId).complete(receiver); + receivers.compute( + instructionId, + (unused, existing) -> { + if (existing != null) { + if (!existing.complete(receiver)) { + throw new IllegalArgumentException("Instruction id was registered twice"); + } + return existing; + } + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new IllegalArgumentException("Instruction id was poisoned"); + } + return CompletableFuture.completedFuture(receiver); + }); } - /** Unregisters a consumer. */ + /** Unregisters a previously registered consumer. */ public void unregisterConsumer(String instructionId) { - receivers.remove(instructionId); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null && !receiverFuture.isDone()) { + // The future must have been inserted by the inbound observer since registerConsumer completes + // the future. + throw new IllegalArgumentException("Unregistering consumer which was not registered."); + } + } + + /** + * Poisons an instruction id. + * + *
Any records for the instruction on the inbound observer will be dropped for the next {@link + * #POISONED_INSTRUCTION_ID_CACHE_TIMEOUT}. + */ + public void poisonInstructionId(String instructionId) { + poisonedInstructionIds.put(instructionId, Boolean.TRUE); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null) { + // Completing exceptionally has no effect if the future was already notified. In that case + // whatever registered the receiver needs to handle cancelling it. + receiverFuture.completeExceptionally(new PoisonedException()); + if (!receiverFuture.isCompletedExceptionally()) { + try { + receiverFuture.get().close(); + } catch (Exception e) { + LOG.warn("Unexpected error closing existing observer"); + } + } + } } @VisibleForTesting @@ -210,27 +263,42 @@ public void onNext(BeamFnApi.Elements value) { } private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.Elements value) { - if (erroredInstructionIds.containsKey(instructionId)) { - LOG.debug("Ignoring inbound data for failed instruction {}", instructionId); - return; - } - CompletableFuture> consumerFuture = - receiverFuture(instructionId); - if (!consumerFuture.isDone()) { - LOG.debug( - "Received data for instruction {} without consumer ready. " - + "Waiting for consumer to be registered.", - instructionId); - } CloseableFnDataReceiver consumer; try { - consumer = consumerFuture.get(); - + CompletableFuture> consumerFuture = + receivers.computeIfAbsent( + instructionId, + (unused) -> { + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new PoisonedException(); + } + LOG.debug( + "Received data for instruction {} without consumer ready. " + + "Waiting for consumer to be registered.", + instructionId); + return new CompletableFuture<>(); + }); + // The consumer may not be registered until the bundle processor is fully constructed so we + // conservatively set + // a high timeout. Poisoning will prevent this for occurring for consumers that will not be + // registered. + consumer = consumerFuture.get(3, TimeUnit.HOURS); /* * TODO: On failure we should fail any bundles that were impacted eagerly * instead of relying on the Runner harness to do all the failure handling. */ - } catch (ExecutionException | InterruptedException e) { + } catch (TimeoutException e) { + LOG.error( + "Timed out waiting to observe consumer data stream for instruction {}", + instructionId, + e); + outboundObserver.onError(e); + return; + } catch (ExecutionException | InterruptedException | PoisonedException e) { + if (e instanceof PoisonedException || e.getCause() instanceof PoisonedException) { + LOG.debug("Received data for poisoned instruction {}. 
Dropping input.", instructionId); + return; + } LOG.error( "Client interrupted during handling of data for instruction {}", instructionId, e); outboundObserver.onError(e); @@ -240,10 +308,11 @@ private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.E outboundObserver.onError(e); return; } + try { consumer.accept(value); } catch (Exception e) { - erroredInstructionIds.put(instructionId, true); + poisonInstructionId(instructionId); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java index b0d29e2295a8..2c6b61e62121 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java @@ -202,6 +202,12 @@ public WeightedList decodeFromChunkBoundaryToChunkBoundary() { T next = next(); rvals.add(next); } + // We don't support seeking backwards so release the memory of the last + // page if it is completed. + if (inbound.currentStream.available() == 0) { + inbound.position = 0; + inbound.currentStream = EMPTY_STREAM; + } // Uses the size of the ByteString as an approximation for the heap size occupied by the // page, considering an overhead of {@link BYTES_LIST_ELEMENT_OVERHEAD} for each element. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index c369eefeb65c..f35782c2b9a2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -19,10 +19,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.lang.reflect.Type; -import java.util.Comparator; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.AutoValueUtils; @@ -34,13 +31,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** A {@link SchemaProvider} for AutoValue classes. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) public class AutoValueSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on AutoValue getters. */ @VisibleForTesting @@ -51,7 +45,11 @@ public static class AbstractGetterTypeSupplier implements FieldValueTypeSupplier public List get(TypeDescriptor typeDescriptor) { // If the generated class is passed in, we want to look at the base class to find the getters. 
- TypeDescriptor targetTypeDescriptor = AutoValueUtils.getBaseAutoValueClass(typeDescriptor); + TypeDescriptor targetTypeDescriptor = + Preconditions.checkNotNull( + AutoValueUtils.getBaseAutoValueClass(typeDescriptor), + "unable to determine base AutoValue class for type {}", + typeDescriptor); List methods = ReflectUtils.getMethods(targetTypeDescriptor.getRawType()).stream() @@ -63,11 +61,10 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -92,8 +89,8 @@ private static void validateFieldNumbers(List types) } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, @@ -146,8 +143,7 @@ public SchemaUserTypeCreator schemaTypeCreator( @Override public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor, AbstractGetterTypeSupplier.INSTANCE, boundTypes); + typeDescriptor, AbstractGetterTypeSupplier.INSTANCE); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java index 8725833bc1da..6e244fefb263 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java @@ -20,6 +20,9 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.initialization.qual.UnknownInitialization; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -32,24 +35,25 @@ * significant for larger schemas) on each lookup. This wrapper caches the value returned by the * inner factory, so the schema comparison only need happen on the first lookup. 
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) -public class CachingFactory implements Factory { +public class CachingFactory implements Factory { private transient @Nullable ConcurrentHashMap, CreatedT> cache = null; - private final Factory innerFactory; + private final @NotOnlyInitialized Factory innerFactory; - public CachingFactory(Factory innerFactory) { + public CachingFactory(@UnknownInitialization Factory innerFactory) { this.innerFactory = innerFactory; } - @Override - public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + private ConcurrentHashMap, CreatedT> getCache() { if (cache == null) { cache = new ConcurrentHashMap<>(); } + return cache; + } + + @Override + public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + ConcurrentHashMap, CreatedT> cache = getCache(); CreatedT cached = cache.get(typeDescriptor); if (cached != null) { return cached; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java index fb98db8e8343..63ab56dc7609 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java @@ -19,6 +19,7 @@ import java.io.Serializable; import org.apache.beam.sdk.annotations.Internal; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -29,7 +30,7 @@ *
Implementations of this interface are generated at runtime to map object fields to Row fields. */ @Internal -public interface FieldValueGetter extends Serializable { +public interface FieldValueGetter extends Serializable { @Nullable ValueT get(ObjectT object); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 64687e6d3381..43aac6a5e20c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -24,10 +24,10 @@ import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Method; -import java.lang.reflect.Type; import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.Optional; import java.util.stream.Stream; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; @@ -42,10 +42,6 @@ /** Represents type information for a Java type that will be used to infer a Schema type. */ @AutoValue -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) @Internal public abstract class FieldValueTypeInformation implements Serializable { /** Optionally returns the field index. */ @@ -129,9 +125,12 @@ public static FieldValueTypeInformation forOneOf( } public static FieldValueTypeInformation forField( - Field field, int index, Map boundTypes) { + @Nullable TypeDescriptor typeDescriptor, Field field, int index) { TypeDescriptor type = - TypeDescriptor.of(ReflectUtils.resolveType(field.getGenericType(), boundTypes)); + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(field.getGenericType())) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(field.getGenericType())); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -139,9 +138,9 @@ public static FieldValueTypeInformation forField( .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field, boundTypes)) - .setMapKeyType(getMapKeyType(field, boundTypes)) - .setMapValueType(getMapValueType(field, boundTypes)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -189,8 +188,12 @@ public static String getNameOverride( return fieldDescription.value(); } + public static FieldValueTypeInformation forGetter(Method method, int index) { + return forGetter(null, method, index); + } + public static FieldValueTypeInformation forGetter( - Method method, int index, Map boundTypes) { + @Nullable TypeDescriptor typeDescriptor, Method method, int index) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -201,7 +204,11 @@ public static FieldValueTypeInformation forGetter( } TypeDescriptor type = - TypeDescriptor.of(ReflectUtils.resolveType(method.getGenericReturnType(), boundTypes)); + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericReturnType())) + // fall back to previous behavior + .orElseGet(() -> 
TypeDescriptor.of(method.getGenericReturnType())); + boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -210,9 +217,9 @@ public static FieldValueTypeInformation forGetter( .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type, boundTypes)) - .setMapKeyType(getMapKeyType(type, boundTypes)) - .setMapValueType(getMapValueType(type, boundTypes)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(method)) .build(); @@ -259,13 +266,21 @@ private static boolean isNullableAnnotation(Annotation annotation) { return annotation.annotationType().getSimpleName().equals("Nullable"); } + public static FieldValueTypeInformation forSetter(Method method) { + return forSetter(null, method); + } + + public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { + return forSetter(null, method, setterPrefix); + } + public static FieldValueTypeInformation forSetter( - Method method, Map boundParameters) { - return forSetter(method, "set", boundParameters); + @Nullable TypeDescriptor typeDescriptor, Method method) { + return forSetter(typeDescriptor, method, "set"); } public static FieldValueTypeInformation forSetter( - Method method, String setterPrefix, Map boundTypes) { + @Nullable TypeDescriptor typeDescriptor, Method method, String setterPrefix) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -274,8 +289,10 @@ public static FieldValueTypeInformation forSetter( } TypeDescriptor type = - TypeDescriptor.of( - ReflectUtils.resolveType(method.getGenericParameterTypes()[0], boundTypes)); + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericParameterTypes()[0])) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(method.getGenericParameterTypes()[0])); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -283,9 +300,9 @@ public static FieldValueTypeInformation forSetter( .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type, boundTypes)) - .setMapKeyType(getMapKeyType(type, boundTypes)) - .setMapValueType(getMapValueType(type, boundTypes)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -294,15 +311,9 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation getIterableComponentType( - Field field, Map boundTypes) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType()), boundTypes); - } - - static @Nullable FieldValueTypeInformation getIterableComponentType( - TypeDescriptor valueType, Map boundTypes) { + static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { // TODO: Figure out nullable elements. 
- TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType, boundTypes); + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); if (componentType == null) { return null; } @@ -312,43 +323,30 @@ private static FieldValueTypeInformation getIterableComponentType( .setNullable(false) .setType(componentType) .setRawType(componentType.getRawType()) - .setElementType(getIterableComponentType(componentType, boundTypes)) - .setMapKeyType(getMapKeyType(componentType, boundTypes)) - .setMapValueType(getMapValueType(componentType, boundTypes)) + .setElementType(getIterableComponentType(componentType)) + .setMapKeyType(getMapKeyType(componentType)) + .setMapValueType(getMapValueType(componentType)) .setOneOfTypes(Collections.emptyMap()) .build(); } - // If the Field is a map type, returns the key type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapKeyType( - Field field, Map boundTypes) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType()), boundTypes); - } - + // If the type is a map type, returns the key type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapKeyType( - TypeDescriptor typeDescriptor, Map boundTypes) { - return getMapType(typeDescriptor, 0, boundTypes); - } - - // If the Field is a map type, returns the value type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapValueType( - Field field, Map boundTypes) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1, boundTypes); + TypeDescriptor typeDescriptor) { + return getMapType(typeDescriptor, 0); } + // If the type is a map type, returns the value type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapValueType( - TypeDescriptor typeDescriptor, Map boundTypes) { - return getMapType(typeDescriptor, 1, boundTypes); + TypeDescriptor typeDescriptor) { + return getMapType(typeDescriptor, 1); } // If the Field is a map type, returns the key or value type (0 is key type, 1 is value). // Otherwise returns a null reference. 
- @SuppressWarnings("unchecked") private static @Nullable FieldValueTypeInformation getMapType( - TypeDescriptor valueType, int index, Map boundTypes) { - TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index, boundTypes); + TypeDescriptor valueType, int index) { + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); if (mapType == null) { return null; } @@ -357,9 +355,9 @@ private static FieldValueTypeInformation getIterableComponentType( .setNullable(false) .setType(mapType) .setRawType(mapType.getRawType()) - .setElementType(getIterableComponentType(mapType, boundTypes)) - .setMapKeyType(getMapKeyType(mapType, boundTypes)) - .setMapValueType(getMapValueType(mapType, boundTypes)) + .setElementType(getIterableComponentType(mapType)) + .setMapKeyType(getMapKeyType(mapType)) + .setMapValueType(getMapValueType(mapType)) .setOneOfTypes(Collections.emptyMap()) .build(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index ce5be71933b8..4e431bb45207 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -17,13 +17,12 @@ */ package org.apache.beam.sdk.schemas; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; @@ -32,10 +31,13 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -46,10 +48,7 @@ * methods which receive {@link TypeDescriptor}s instead of ordinary {@link Class}es as * arguments, which permits to support generic type signatures during schema inference */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) @Deprecated public abstract class GetterBasedSchemaProvider implements SchemaProvider { @@ -67,9 +66,9 @@ public abstract class GetterBasedSchemaProvider implements SchemaProvider { * override it if you want to use the richer type signature contained in the {@link * TypeDescriptor} not subject to the type erasure. 
*/ - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { - return fieldValueGetters(targetTypeDescriptor.getRawType(), schema); + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { + return (List) fieldValueGetters(targetTypeDescriptor.getRawType(), schema); } /** @@ -112,9 +111,10 @@ public SchemaUserTypeCreator schemaTypeCreator( return schemaTypeCreator(targetTypeDescriptor.getRawType(), schema); } - private class ToRowWithValueGetters implements SerializableFunction { + private class ToRowWithValueGetters + implements SerializableFunction { private final Schema schema; - private final Factory> getterFactory; + private final Factory>> getterFactory; public ToRowWithValueGetters(Schema schema) { this.schema = schema; @@ -122,7 +122,12 @@ public ToRowWithValueGetters(Schema schema) { // schema, return a caching factory that caches the first value seen for each class. This // prevents having to lookup the getter list each time createGetters is called. this.getterFactory = - RowValueGettersFactory.of(GetterBasedSchemaProvider.this::fieldValueGetters); + RowValueGettersFactory.of( + (Factory>>) + (typeDescriptor, schema1) -> + (List) + GetterBasedSchemaProvider.this.fieldValueGetters( + typeDescriptor, schema1)); } @Override @@ -160,13 +165,15 @@ public SerializableFunction toRowFunction(TypeDescriptor typeDesc // important to capture the schema once here, so all invocations of the toRowFunction see the // same version of the schema. If schemaFor were to be called inside the lambda below, different // workers would see different versions of the schema. - Schema schema = schemaFor(typeDescriptor); + @NonNull + Schema schema = + Verify.verifyNotNull( + schemaFor(typeDescriptor), "can't create a ToRowFunction with null schema"); return new ToRowWithValueGetters<>(schema); } @Override - @SuppressWarnings("unchecked") public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { return new FromRowUsingCreator<>(typeDescriptor, this); } @@ -181,23 +188,27 @@ public boolean equals(@Nullable Object obj) { return obj != null && this.getClass() == obj.getClass(); } - private static class RowValueGettersFactory implements Factory> { - private final Factory> gettersFactory; - private final Factory> cachingGettersFactory; + private static class RowValueGettersFactory + implements Factory>> { + private final Factory>> gettersFactory; + private final @NotOnlyInitialized Factory>> + cachingGettersFactory; - static Factory> of(Factory> gettersFactory) { - return new RowValueGettersFactory(gettersFactory).cachingGettersFactory; + static Factory>> of( + Factory>> gettersFactory) { + return new RowValueGettersFactory<>(gettersFactory).cachingGettersFactory; } - RowValueGettersFactory(Factory> gettersFactory) { + RowValueGettersFactory(Factory>> gettersFactory) { this.gettersFactory = gettersFactory; this.cachingGettersFactory = new CachingFactory<>(this); } @Override - public List create(TypeDescriptor typeDescriptor, Schema schema) { - List getters = gettersFactory.create(typeDescriptor, schema); - List rowGetters = new ArrayList<>(getters.size()); + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { + List> getters = gettersFactory.create(typeDescriptor, schema); + List> rowGetters = new ArrayList<>(getters.size()); for (int i = 0; i < getters.size(); i++) { rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType())); } @@ -209,71 +220,80 @@ static boolean needsConversion(FieldType type) 
{ return typeName.equals(TypeName.ROW) || typeName.isLogicalType() || ((typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) - && needsConversion(type.getCollectionElementType())) + && needsConversion(Verify.verifyNotNull(type.getCollectionElementType()))) || (typeName.equals(TypeName.MAP) - && (needsConversion(type.getMapKeyType()) - || needsConversion(type.getMapValueType()))); + && (needsConversion(Verify.verifyNotNull(type.getMapKeyType())) + || needsConversion(Verify.verifyNotNull(type.getMapValueType())))); } - FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { + FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { TypeName typeName = type.getTypeName(); if (!needsConversion(type)) { return base; } if (typeName.equals(TypeName.ROW)) { - return new GetRow(base, type.getRowSchema(), cachingGettersFactory); + return new GetRow(base, Verify.verifyNotNull(type.getRowSchema()), cachingGettersFactory); } else if (typeName.equals(TypeName.ARRAY)) { - FieldType elementType = type.getCollectionElementType(); + FieldType elementType = Verify.verifyNotNull(type.getCollectionElementType()); return elementType.getTypeName().equals(TypeName.ROW) ? new GetEagerCollection(base, converter(elementType)) : new GetCollection(base, converter(elementType)); } else if (typeName.equals(TypeName.ITERABLE)) { - return new GetIterable(base, converter(type.getCollectionElementType())); + return new GetIterable( + base, converter(Verify.verifyNotNull(type.getCollectionElementType()))); } else if (typeName.equals(TypeName.MAP)) { - return new GetMap(base, converter(type.getMapKeyType()), converter(type.getMapValueType())); + return new GetMap( + base, + converter(Verify.verifyNotNull(type.getMapKeyType())), + converter(Verify.verifyNotNull(type.getMapValueType()))); } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = type.getLogicalType(OneOfType.class); Schema oneOfSchema = oneOfType.getOneOfSchema(); Map values = oneOfType.getCaseEnumType().getValuesMap(); - Map converters = Maps.newHashMapWithExpectedSize(values.size()); + Map> converters = + Maps.newHashMapWithExpectedSize(values.size()); for (Map.Entry kv : values.entrySet()) { FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType(); - FieldValueGetter converter = converter(fieldType); + FieldValueGetter converter = converter(fieldType); converters.put(kv.getValue(), converter); } return new GetOneOf(base, converters, oneOfType); } else if (typeName.isLogicalType()) { - return new GetLogicalInputType(base, type.getLogicalType()); + return new GetLogicalInputType(base, Verify.verifyNotNull(type.getLogicalType())); } return base; } - FieldValueGetter converter(FieldType type) { + FieldValueGetter converter(FieldType type) { return rowValueGetter(IDENTITY, type); } - static class GetRow extends Converter { + static class GetRow + extends Converter { final Schema schema; - final Factory> factory; + final Factory>> factory; - GetRow(FieldValueGetter getter, Schema schema, Factory> factory) { + GetRow( + FieldValueGetter getter, + Schema schema, + Factory>> factory) { super(getter); this.schema = schema; this.factory = factory; } @Override - Object convert(Object value) { + Object convert(V value) { return Row.withSchema(schema).withFieldValueGetters(factory, value); } } - static class GetEagerCollection extends Converter { + static class GetEagerCollection extends Converter { final FieldValueGetter converter; - GetEagerCollection(FieldValueGetter getter, FieldValueGetter 
converter) { + GetEagerCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @@ -288,15 +308,16 @@ Object convert(Collection collection) { } } - static class GetCollection extends Converter { + static class GetCollection extends Converter { final FieldValueGetter converter; - GetCollection(FieldValueGetter getter, FieldValueGetter converter) { + GetCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Collection collection) { if (collection instanceof List) { // For performance reasons if the input is a list, make sure that we produce a list. @@ -309,45 +330,51 @@ Object convert(Collection collection) { } } - static class GetIterable extends Converter { + static class GetIterable extends Converter { final FieldValueGetter converter; - GetIterable(FieldValueGetter getter, FieldValueGetter converter) { + GetIterable(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Iterable value) { return Iterables.transform(value, converter::get); } } - static class GetMap extends Converter> { - final FieldValueGetter keyConverter; - final FieldValueGetter valueConverter; + static class GetMap + extends Converter> { + final FieldValueGetter<@NonNull K1, K2> keyConverter; + final FieldValueGetter<@NonNull V1, V2> valueConverter; GetMap( - FieldValueGetter getter, FieldValueGetter keyConverter, FieldValueGetter valueConverter) { + FieldValueGetter> getter, + FieldValueGetter<@NonNull K1, K2> keyConverter, + FieldValueGetter<@NonNull V1, V2> valueConverter) { super(getter); this.keyConverter = keyConverter; this.valueConverter = valueConverter; } @Override - Object convert(Map value) { - Map returnMap = Maps.newHashMapWithExpectedSize(value.size()); - for (Map.Entry entry : value.entrySet()) { - returnMap.put(keyConverter.get(entry.getKey()), valueConverter.get(entry.getValue())); + Map<@Nullable K2, @Nullable V2> convert(Map<@Nullable K1, @Nullable V1> value) { + Map<@Nullable K2, @Nullable V2> returnMap = Maps.newHashMapWithExpectedSize(value.size()); + for (Map.Entry<@Nullable K1, @Nullable V1> entry : value.entrySet()) { + returnMap.put( + Optional.ofNullable(entry.getKey()).map(keyConverter::get).orElse(null), + Optional.ofNullable(entry.getValue()).map(valueConverter::get).orElse(null)); } return returnMap; } } - static class GetLogicalInputType extends Converter { + static class GetLogicalInputType extends Converter { final LogicalType logicalType; - GetLogicalInputType(FieldValueGetter getter, LogicalType logicalType) { + GetLogicalInputType(FieldValueGetter getter, LogicalType logicalType) { super(getter); this.logicalType = logicalType; } @@ -359,12 +386,14 @@ Object convert(Object value) { } } - static class GetOneOf extends Converter { + static class GetOneOf extends Converter { final OneOfType oneOfType; - final Map converters; + final Map> converters; GetOneOf( - FieldValueGetter getter, Map converters, OneOfType oneOfType) { + FieldValueGetter getter, + Map> converters, + OneOfType oneOfType) { super(getter); this.converters = converters; this.oneOfType = oneOfType; @@ -373,24 +402,31 @@ static class GetOneOf extends Converter { @Override Object convert(OneOfType.Value value) { EnumerationType.Value caseType = value.getCaseType(); - FieldValueGetter converter = converters.get(caseType.getValue()); - 
checkState(converter != null, "Missing OneOf converter for case %s.", caseType); + + @NonNull + FieldValueGetter<@NonNull Object, Object> converter = + Verify.verifyNotNull( + converters.get(caseType.getValue()), + "Missing OneOf converter for case %s.", + caseType); + return oneOfType.createValue(caseType, converter.get(value.getValue())); } } - abstract static class Converter implements FieldValueGetter { - final FieldValueGetter getter; + abstract static class Converter + implements FieldValueGetter { + final FieldValueGetter getter; - public Converter(FieldValueGetter getter) { + public Converter(FieldValueGetter getter) { this.getter = getter; } - abstract Object convert(T value); + abstract Object convert(ValueT value); @Override - public @Nullable Object get(Object object) { - T value = (T) getter.get(object); + public @Nullable Object get(ObjectT object) { + ValueT value = getter.get(object); if (value == null) { return null; } @@ -398,7 +434,7 @@ public Converter(FieldValueGetter getter) { } @Override - public @Nullable Object getRaw(Object object) { + public @Nullable Object getRaw(ObjectT object) { return getter.getRaw(object); } @@ -408,16 +444,16 @@ public String name() { } } - private static final FieldValueGetter IDENTITY = - new FieldValueGetter() { + private static final FieldValueGetter<@NonNull Object, Object> IDENTITY = + new FieldValueGetter<@NonNull Object, Object>() { @Override - public @Nullable Object get(Object object) { + public Object get(@NonNull Object object) { return object; } @Override public String name() { - return null; + return "IDENTITY"; } }; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java index de31f9947c36..e7214d8f663a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java @@ -19,6 +19,7 @@ import java.util.List; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A newer version of {@link GetterBasedSchemaProvider}, which works with {@link TypeDescriptor}s, @@ -28,12 +29,12 @@ public abstract class GetterBasedSchemaProviderV2 extends GetterBasedSchemaProvider { @Override public List fieldValueGetters(Class targetClass, Schema schema) { - return fieldValueGetters(TypeDescriptor.of(targetClass), schema); + return (List) fieldValueGetters(TypeDescriptor.of(targetClass), schema); } @Override - public abstract List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema); + public abstract List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema); @Override public List fieldValueTypeInformations( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index ad71576670bf..14adf2f6603e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -19,10 +19,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.lang.reflect.Type; -import java.util.Comparator; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import 
org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -36,6 +33,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -51,10 +49,7 @@ *
TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class JavaBeanSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on getter methods. */ @VisibleForTesting @@ -69,11 +64,10 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -114,33 +108,35 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) + .map(m -> FieldValueTypeInformation.forSetter(typeDescriptor, m)) .map( t -> { - if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { + Method m = + Preconditions.checkNotNull( + t.getMethod(), JavaBeanUtils.SETTER_WITH_NULL_METHOD_ERROR); + if (m.getAnnotation(SchemaFieldNumber.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldNumber can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaFieldName.class) != null) { + if (m.getAnnotation(SchemaFieldName.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldName can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaCaseFormat.class) != null) { + if (m.getAnnotation(SchemaCaseFormat.class) != null) { throw new RuntimeException( String.format( "@SchemaCaseFormat can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } return t; }) @@ -160,10 +156,8 @@ public boolean equals(@Nullable Object obj) { @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); Schema schema = - JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor, GetterTypeSupplier.INSTANCE, boundTypes); + JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE); // If there are no creator methods, then validate that we have setters for every field. // Otherwise, we will have no way of creating instances of the class. 
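The JavaBeanSchema and AutoValueSchema suppliers above now hand the enclosing TypeDescriptor to FieldValueTypeInformation.forGetter/forSetter, so generic getters are resolved against the concrete type argument instead of the removed boundTypes map. A minimal sketch of the TypeDescriptor.resolveType behavior this relies on; the Box bean is hypothetical and used only for illustration:

import java.lang.reflect.Method;

import org.apache.beam.sdk.values.TypeDescriptor;

public class ResolveTypeSketch {
  // Hypothetical generic bean; not part of this PR.
  public static class Box<T> {
    private T value;

    public T getValue() {
      return value;
    }

    public void setValue(T value) {
      this.value = value;
    }
  }

  public static void main(String[] args) throws Exception {
    Method getter = Box.class.getMethod("getValue");

    // Raw reflection only sees the unresolved type variable T as the return type.
    System.out.println(TypeDescriptor.of(getter.getGenericReturnType()));

    // Resolving against a concrete TypeDescriptor substitutes T -> String, which is
    // what forGetter(typeDescriptor, method, index) now uses to infer the field type.
    TypeDescriptor<Box<String>> concrete = new TypeDescriptor<Box<String>>() {};
    System.out.println(concrete.resolveType(getter.getGenericReturnType()));
  }
}

Under these assumptions, a Box<String> bean infers a STRING field for value, where the raw type variable alone could not be mapped to a schema type.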
@@ -178,8 +172,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index da0f59c8ee96..9a8eef2bf2c8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -21,22 +21,22 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.lang.reflect.Type; -import java.util.Comparator; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.DefaultTypeConversionsFactory; import org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier; +import org.apache.beam.sdk.schemas.utils.JavaBeanUtils; import org.apache.beam.sdk.schemas.utils.POJOUtils; import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for Java POJO objects. @@ -51,7 +51,6 @@ *
TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({"nullness", "rawtypes"}) public class JavaFieldSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on public fields. */ @VisibleForTesting @@ -64,13 +63,11 @@ public List get(TypeDescriptor typeDescriptor) { ReflectUtils.getFields(typeDescriptor.getRawType()).stream() .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); - List types = Lists.newArrayListWithCapacity(fields.size()); - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i, boundTypes)); + types.add(FieldValueTypeInformation.forField(typeDescriptor, fields.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); // If there are no creators registered, then make sure none of the schema fields are final, @@ -79,7 +76,9 @@ public List get(TypeDescriptor typeDescriptor) { && ReflectUtils.getAnnotatedConstructor(typeDescriptor.getRawType()) == null) { Optional finalField = types.stream() - .map(FieldValueTypeInformation::getField) + .flatMap( + fvti -> + Optional.ofNullable(fvti.getField()).map(Stream::of).orElse(Stream.empty())) .filter(f -> Modifier.isFinal(f.getModifiers())) .findAny(); if (finalField.isPresent()) { @@ -115,14 +114,12 @@ private static void validateFieldNumbers(List types) @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); - return POJOUtils.schemaFromPojoClass( - typeDescriptor, JavaFieldTypeSupplier.INSTANCE, boundTypes); + return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE); } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return POJOUtils.getGetters( targetTypeDescriptor, schema, @@ -155,7 +152,7 @@ public SchemaUserTypeCreator schemaTypeCreator( ReflectUtils.getAnnotatedConstructor(targetTypeDescriptor.getRawType()); if (constructor != null) { return POJOUtils.getConstructorCreator( - targetTypeDescriptor, + (TypeDescriptor) targetTypeDescriptor, constructor, schema, JavaFieldTypeSupplier.INSTANCE, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 5af59356b174..02607d91b079 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -90,6 +90,7 @@ public String toString() { return Arrays.toString(array); } } + // A mapping between field names an indices. private final BiMap fieldIndices; @@ -830,10 +831,11 @@ public static FieldType iterable(FieldType elementType) { public static FieldType map(FieldType keyType, FieldType valueType) { if (FieldType.BYTES.equals(keyType)) { LOG.warn( - "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as intended. " - + "Since arrays do not override equals() or hashCode, comparisons will be done on reference equality only. 
" - + "ByteBuffers, when used as keys, present similar challenges because Row stores ByteBuffer as a byte array. " - + "Consider using a different type of key for more consistent and predictable behavior."); + "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as" + + " intended. Since arrays do not override equals() or hashCode, comparisons will" + + " be done on reference equality only. ByteBuffers, when used as keys, present" + + " similar challenges because Row stores ByteBuffer as a byte array. Consider" + + " using a different type of key for more consistent and predictable behavior."); } return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) @@ -1443,7 +1445,7 @@ private static Schema fromFields(List fields) { } /** Return the list of all field names. */ - public List getFieldNames() { + public List<@NonNull String> getFieldNames() { return getFields().stream().map(Schema.Field::getName).collect(Collectors.toList()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java index b7e3cdf60c18..37b4952e529c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java @@ -38,7 +38,8 @@ public interface SchemaProvider extends Serializable { * Given a type, return a function that converts that type to a {@link Row} object If no schema * exists, returns null. */ - @Nullable SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); + @Nullable + SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); /** * Given a type, returns a function that converts from a {@link Row} object to that type. If no diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java index 5d8b7aab6193..679a1fcf54fc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java @@ -76,12 +76,13 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid providers.put(typeDescriptor, schemaProvider); } - private @Nullable SchemaProvider schemaProviderFor(TypeDescriptor typeDescriptor) { + @Override + public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { TypeDescriptor type = typeDescriptor; do { SchemaProvider schemaProvider = providers.get(type); if (schemaProvider != null) { - return schemaProvider; + return schemaProvider.schemaFor(type); } Class superClass = type.getRawType().getSuperclass(); if (superClass == null || superClass.equals(Object.class)) { @@ -91,24 +92,38 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid } while (true); } - @Override - public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { - @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); - return schemaProvider != null ? schemaProvider.schemaFor(typeDescriptor) : null; - } - @Override public @Nullable SerializableFunction toRowFunction( TypeDescriptor typeDescriptor) { - @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); - return schemaProvider != null ? 
schemaProvider.toRowFunction(typeDescriptor) : null; + TypeDescriptor type = typeDescriptor; + do { + SchemaProvider schemaProvider = providers.get(type); + if (schemaProvider != null) { + return (SerializableFunction) schemaProvider.toRowFunction(type); + } + Class superClass = type.getRawType().getSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + return null; + } + type = TypeDescriptor.of(superClass); + } while (true); } @Override public @Nullable SerializableFunction fromRowFunction( TypeDescriptor typeDescriptor) { - @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); - return schemaProvider != null ? schemaProvider.fromRowFunction(typeDescriptor) : null; + TypeDescriptor type = typeDescriptor; + do { + SchemaProvider schemaProvider = providers.get(type); + if (schemaProvider != null) { + return (SerializableFunction) schemaProvider.fromRowFunction(type); + } + Class superClass = type.getRawType().getSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + return null; + } + type = TypeDescriptor.of(superClass); + } while (true); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java index c3a71bbb454b..54e2a595fa71 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java @@ -160,8 +160,7 @@ public FunctionAndType(Type outputType, Function function) { public FunctionAndType(TypeDescriptor outputType, Function function) { this( - StaticSchemaInference.fieldFromType( - outputType, new EmptyFieldValueTypeSupplier(), Collections.emptyMap()), + StaticSchemaInference.fieldFromType(outputType, new EmptyFieldValueTypeSupplier()), function); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index 74e97bad4f0f..300dce61e2ea 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -27,6 +27,7 @@ import java.lang.reflect.Parameter; import java.lang.reflect.Type; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -53,7 +54,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; @@ -63,23 +63,25 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for managing AutoValue schemas. 
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) -@Internal +@SuppressWarnings({"rawtypes"}) public class AutoValueUtils { - public static TypeDescriptor getBaseAutoValueClass(TypeDescriptor typeDescriptor) { + public static @Nullable TypeDescriptor getBaseAutoValueClass( + TypeDescriptor typeDescriptor) { // AutoValue extensions may be nested - while (typeDescriptor != null && typeDescriptor.getRawType().getName().contains("AutoValue_")) { - typeDescriptor = TypeDescriptor.of(typeDescriptor.getRawType().getSuperclass()); + @Nullable TypeDescriptor baseTypeDescriptor = typeDescriptor; + while (baseTypeDescriptor != null + && baseTypeDescriptor.getRawType().getName().contains("AutoValue_")) { + baseTypeDescriptor = + Optional.ofNullable(baseTypeDescriptor.getRawType().getSuperclass()) + .map(TypeDescriptor::of) + .orElse(null); } - return typeDescriptor; + return baseTypeDescriptor; } private static TypeDescriptor getAutoValueGenerated(TypeDescriptor typeDescriptor) { @@ -157,14 +159,18 @@ private static boolean matchConstructor( getterTypes.stream() .collect( Collectors.toMap( - f -> ReflectUtils.stripGetterPrefix(f.getMethod().getName()), + f -> + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + f.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()), Function.identity())); boolean valid = true; // Verify that constructor parameters match (name and type) the inferred schema. for (Parameter parameter : constructor.getParameters()) { FieldValueTypeInformation type = typeMap.get(parameter.getName()); - if (type == null || !type.getRawType().equals(parameter.getType())) { + if (type == null || type.getRawType() != parameter.getType()) { valid = false; break; } @@ -181,7 +187,7 @@ private static boolean matchConstructor( } name = name.substring(0, name.length() - 1); FieldValueTypeInformation type = typeMap.get(name); - if (type == null || !type.getRawType().equals(parameter.getType())) { + if (type == null || type.getRawType() != parameter.getType()) { return false; } } @@ -199,11 +205,11 @@ private static boolean matchConstructor( return null; } - Map boundTypes = ReflectUtils.getAllBoundTypes(TypeDescriptor.of(builderClass)); - Map setterTypes = Maps.newHashMap(); + Map setterTypes = new HashMap<>(); + ReflectUtils.getMethods(builderClass).stream() .filter(ReflectUtils::isSetter) - .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) + .map(m -> FieldValueTypeInformation.forSetter(TypeDescriptor.of(builderClass), m)) .forEach(fv -> setterTypes.putIfAbsent(fv.getName(), fv)); List setterMethods = @@ -211,7 +217,11 @@ private static boolean matchConstructor( List schemaTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); for (FieldValueTypeInformation type : schemaTypes) { - String autoValueFieldName = ReflectUtils.stripGetterPrefix(type.getMethod().getName()); + String autoValueFieldName = + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + type.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()); FieldValueTypeInformation setterType = setterTypes.get(autoValueFieldName); if (setterType == null) { @@ -325,7 +335,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Duplication.SINGLE, typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(parameter.getParameterizedType())), + .convert(TypeDescriptor.of(parameter.getType())), MethodInvocation.invoke(new ForLoadedMethod(setterMethod)), 
Removal.SINGLE); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index 3b2428ebb999..5297eb113a97 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.lang.reflect.Constructor; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; @@ -34,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.SortedMap; import net.bytebuddy.ByteBuddy; @@ -42,6 +44,7 @@ import net.bytebuddy.asm.AsmVisitorWrapper; import net.bytebuddy.description.method.MethodDescription.ForLoadedConstructor; import net.bytebuddy.description.method.MethodDescription.ForLoadedMethod; +import net.bytebuddy.description.type.PackageDescription; import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.description.type.TypeDescription.ForLoadedType; import net.bytebuddy.dynamic.DynamicType; @@ -78,6 +81,8 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeParameter; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -85,6 +90,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Primitives; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ClassUtils; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTimeZone; import org.joda.time.Instant; @@ -95,8 +101,6 @@ @Internal @SuppressWarnings({ "keyfor", - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" }) public class ByteBuddyUtils { private static final ForLoadedType ARRAYS_TYPE = new ForLoadedType(Arrays.class); @@ -147,7 +151,11 @@ protected String name(TypeDescription superClass) { // If the target class is in a prohibited package (java.*) then leave the original package // alone. String realPackage = - overridePackage(targetPackage) ? targetPackage : superClass.getPackage().getName(); + overridePackage(targetPackage) + ? targetPackage + : Optional.ofNullable(superClass.getPackage()) + .map(PackageDescription::getName) + .orElse(""); return realPackage + className + "$" + SUFFIX + "$" + randomString.nextString(); } @@ -202,25 +210,27 @@ static class ShortCircuitReturnNull extends IfNullElse { // Create a new FieldValueGetter subclass. 
@SuppressWarnings("unchecked") - public static DynamicType.Builder subclassGetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassGetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic getterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueGetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder>) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType); } // Create a new FieldValueSetter subclass. @SuppressWarnings("unchecked") - public static DynamicType.Builder subclassSetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassSetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic setterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueSetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType); } @@ -252,9 +262,11 @@ public TypeConversion createSetterConversions(StackManipulati // Base class used below to convert types. @SuppressWarnings("unchecked") public abstract static class TypeConversion { - public T convert(TypeDescriptor typeDescriptor) { + public T convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.isArray() - && !typeDescriptor.getComponentType().getRawType().equals(byte.class)) { + && !Preconditions.checkNotNull(typeDescriptor.getComponentType()) + .getRawType() + .equals(byte.class)) { // Byte arrays are special, so leave those alone. return convertArray(typeDescriptor); } else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { @@ -339,28 +351,32 @@ protected ConvertType(boolean returnRawTypes) { @Override protected Type convertArray(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(type.getComponentType()); + TypeDescriptor ret = + createCollectionType(Preconditions.checkNotNull(type.getComponentType())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertCollection(TypeDescriptor type) { - TypeDescriptor ret = - createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertList(TypeDescriptor type) { - TypeDescriptor ret = - createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - TypeDescriptor ret = - createIterableType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + TypeDescriptor ret = + createIterableType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? 
ret.getRawType() : ret.getType(); } @@ -402,8 +418,9 @@ protected Type convertDefault(TypeDescriptor type) { @SuppressWarnings("unchecked") private TypeDescriptor> createCollectionType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -411,8 +428,9 @@ private TypeDescriptor> createCollectionType( @SuppressWarnings("unchecked") private TypeDescriptor> createIterableType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -424,7 +442,7 @@ private TypeDescriptor> createIterableType( // This function // generates a subclass of Function that can be used to recursively transform each element of the // container. - static Class createCollectionTransformFunction( + static Class createCollectionTransformFunction( Type fromType, Type toType, Function convertElement) { // Generate a TypeDescription for the class we want to generate. TypeDescription.Generic functionGenericType = @@ -432,8 +450,8 @@ static Class createCollectionTransformFunction( Function.class, Primitives.wrap((Class) fromType), Primitives.wrap((Class) toType)) .build(); - DynamicType.Builder builder = - (DynamicType.Builder) + DynamicType.Builder> builder = + (DynamicType.Builder) BYTE_BUDDY .with(new InjectPackageStrategy((Class) fromType)) .subclass(functionGenericType) @@ -467,9 +485,11 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) { .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), + ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), getClassLoadingStrategy( - ((Class) fromType).getClassLoader() == null ? Function.class : (Class) fromType)) + ((Class) fromType).getClassLoader() == null + ? Function.class + : (Class) fromType)) .getLoaded(); } @@ -551,17 +571,17 @@ public boolean containsValue(Object value) { } @Override - public V2 get(Object key) { + public @Nullable V2 get(Object key) { return delegateMap.get(key); } @Override - public V2 put(K2 key, V2 value) { + public @Nullable V2 put(K2 key, V2 value) { return delegateMap.put(key, value); } @Override - public V2 remove(Object key) { + public @Nullable V2 remove(Object key) { return delegateMap.remove(key); } @@ -639,12 +659,12 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isComponentTypePrimitive ? Arrays.asList(ArrayUtils.toObject(value)) // : Arrays.asList(value); - TypeDescriptor componentType = type.getComponentType(); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); ForLoadedType loadedArrayType = new ForLoadedType(type.getRawType()); StackManipulation readArrayValue = readValue; // Row always expects to get an Iterable back for array types. Wrap this array into a // List using Arrays.asList before returning. 
- if (loadedArrayType.getComponentType().isPrimitive()) { + if (Preconditions.checkNotNull(loadedArrayType.getComponentType()).isPrimitive()) { // Arrays.asList doesn't take primitive arrays, so convert first using ArrayUtils.toObject. readArrayValue = new Compound( @@ -672,7 +692,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Generate a SerializableFunction to convert the element-type objects. StackManipulation stackManipulation; - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); @@ -691,11 +711,11 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - TypeDescriptor componentType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -712,10 +732,10 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - TypeDescriptor componentType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -732,10 +752,10 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - TypeDescriptor componentType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -752,8 +772,8 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0, Collections.emptyMap()); - final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1, Collections.emptyMap()); + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); Type convertedKeyType = 
getFactory().createTypeConversion(true).convert(keyType); Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType); @@ -977,16 +997,18 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isPrimitive ? toArray : ArrayUtils.toPrimitive(toArray); ForLoadedType loadedType = new ForLoadedType(type.getRawType()); + TypeDescription loadedTypeComponentType = Verify.verifyNotNull(loadedType.getComponentType()); + // The type of the array containing the (possibly) boxed values. TypeDescription arrayType = - TypeDescription.Generic.Builder.rawType(loadedType.getComponentType().asBoxed()) + TypeDescription.Generic.Builder.rawType(loadedTypeComponentType.asBoxed()) .asArray() .build() .asErasure(); - Type rowElementType = - getFactory().createTypeConversion(false).convert(type.getComponentType()); - final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(type.getComponentType()); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); + Type rowElementType = getFactory().createTypeConversion(false).convert(componentType); + final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(componentType); StackManipulation readTransformedValue = readValue; if (!arrayElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -1006,7 +1028,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Call Collection.toArray(T[[]) to extract the array. Push new T[0] on the stack // before // calling toArray. - ArrayFactory.forType(loadedType.getComponentType().asBoxed().asGenericType()) + ArrayFactory.forType(loadedTypeComponentType.asBoxed().asGenericType()) .withValues(Collections.emptyList()), MethodInvocation.invoke( COLLECTION_TYPE @@ -1023,7 +1045,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Cast the result to T[]. TypeCasting.to(arrayType)); - if (loadedType.getComponentType().isPrimitive()) { + if (loadedTypeComponentType.isPrimitive()) { // The array we extract will be an array of objects. If the pojo field is an array of // primitive types, we need to then convert to an array of unboxed objects. 
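
> For context on the primitive-array handling in this hunk: the generated bytecode boxes primitive arrays with `ArrayUtils.toObject` before exposing them as a `List` (which is what `Row` expects), and unboxes with `ArrayUtils.toPrimitive` when writing back into a primitive POJO field. A plain-Java sketch of those two conversions (the sketch class is illustrative; `ArrayUtils` is the same commons-lang3 utility the generated code invokes):

```java
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;

// Plain-Java illustration of the boxing conversions the generated bytecode performs:
// primitive arrays cannot be passed to Arrays.asList directly, so they are boxed first,
// and boxed arrays coming out of a Row are unboxed again for primitive fields.
public class PrimitiveArrayConversionSketch {
  public static void main(String[] args) {
    int[] primitives = {1, 2, 3};

    // Box: int[] -> Integer[] so the values can be wrapped as a List.
    Integer[] boxed = ArrayUtils.toObject(primitives);
    List<Integer> asList = Arrays.asList(boxed);
    System.out.println(asList); // [1, 2, 3]

    // Unbox: Integer[] -> int[] when writing back into a primitive array field.
    int[] roundTripped = ArrayUtils.toPrimitive(boxed);
    System.out.println(Arrays.toString(roundTripped)); // [1, 2, 3]
  }
}
```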
stackManipulation = @@ -1042,12 +1064,9 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); - final TypeDescriptor iterableElementType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + final TypeDescriptor iterableElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(iterableElementType); if (!iterableElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = new ForLoadedType( @@ -1065,12 +1084,9 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); - final TypeDescriptor collectionElementType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + final TypeDescriptor collectionElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(collectionElementType); if (!collectionElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -1089,12 +1105,9 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); - final TypeDescriptor collectionElementType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + final TypeDescriptor collectionElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(collectionElementType); StackManipulation readTrasformedValue = readValue; if (!collectionElementType.hasUnresolvedParameters()) { @@ -1122,18 +1135,12 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - Type rowKeyType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getMapType(type, 0, Collections.emptyMap())); - final TypeDescriptor keyElementType = - ReflectUtils.getMapType(type, 0, Collections.emptyMap()); - Type rowValueType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getMapType(type, 1, Collections.emptyMap())); - final TypeDescriptor valueElementType = - ReflectUtils.getMapType(type, 1, Collections.emptyMap()); + final TypeDescriptor keyElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 0)); + final TypeDescriptor valueElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 1)); + Type rowKeyType = getFactory().createTypeConversion(false).convert(keyElementType); + Type rowValueType = getFactory().createTypeConversion(false).convert(valueElementType); StackManipulation readTrasformedValue = readValue; if (!keyElementType.hasUnresolvedParameters() @@ -1348,12 +1355,12 @@ protected StackManipulation convertDefault(TypeDescriptor type) { * constructor. 
*/ static class ConstructorCreateInstruction extends InvokeUserCreateInstruction { - private final Constructor constructor; + private final Constructor constructor; ConstructorCreateInstruction( List fields, - Class targetClass, - Constructor constructor, + Class targetClass, + Constructor constructor, TypeConversionsFactory typeConversionsFactory) { super( fields, @@ -1391,7 +1398,7 @@ static class StaticFactoryMethodInstruction extends InvokeUserCreateInstruction StaticFactoryMethodInstruction( List fields, - Class targetClass, + Class targetClass, Method creator, TypeConversionsFactory typeConversionsFactory) { super( @@ -1415,14 +1422,14 @@ protected StackManipulation afterPushingParameters() { static class InvokeUserCreateInstruction implements Implementation { protected final List fields; - protected final Class targetClass; + protected final Class targetClass; protected final List parameters; protected final Map fieldMapping; private final TypeConversionsFactory typeConversionsFactory; protected InvokeUserCreateInstruction( List fields, - Class targetClass, + Class targetClass, List parameters, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; @@ -1440,11 +1447,15 @@ protected InvokeUserCreateInstruction( // actual Java field or method names. FieldValueTypeInformation fieldValue = checkNotNull(fields.get(i)); fieldsByLogicalName.put(fieldValue.getName(), i); - if (fieldValue.getField() != null) { - fieldsByJavaClassMember.put(fieldValue.getField().getName(), i); - } else if (fieldValue.getMethod() != null) { - String name = ReflectUtils.stripGetterPrefix(fieldValue.getMethod().getName()); - fieldsByJavaClassMember.put(name, i); + Field field = fieldValue.getField(); + if (field != null) { + fieldsByJavaClassMember.put(field.getName(), i); + } else { + Method method = fieldValue.getMethod(); + if (method != null) { + String name = ReflectUtils.stripGetterPrefix(method.getName()); + fieldsByJavaClassMember.put(name, i); + } } } @@ -1491,14 +1502,14 @@ public ByteCodeAppender appender(final Target implementationTarget) { Parameter parameter = parameters.get(i); ForLoadedType convertedType = new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(parameter.getParameterizedType()))); + (Class) convertType.convert(TypeDescriptor.of(parameter.getType()))); // The instruction to read the parameter. Use the fieldMapping to reorder parameters as // necessary. 
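
> The rewritten mapping loop in the previous hunk replaces repeated `fieldValue.getField()` / `fieldValue.getMethod()` calls with a single read into a local variable before the null check. With the blanket `"nullness"` suppression removed, that is the shape the checker can verify: it cannot assume a second getter call returns the same non-null value. A simplified sketch of the pattern, using a hypothetical `Member` holder in place of `FieldValueTypeInformation`:

```java
import java.util.HashMap;
import java.util.Map;

// Simplified illustration of the "read once into a local, then null-check" pattern.
// "Member" is a hypothetical stand-in for FieldValueTypeInformation; its getters may return null.
public class NullCheckLocalSketch {
  static class Member {
    private final String fieldName;  // null when the member is getter-based
    private final String methodName; // null when the member is field-based

    Member(String fieldName, String methodName) {
      this.fieldName = fieldName;
      this.methodName = methodName;
    }

    String getFieldName() { return fieldName; }
    String getMethodName() { return methodName; }
  }

  public static void main(String[] args) {
    Member[] members = {new Member("count", null), new Member(null, "getName")};
    Map<String, Integer> byJavaName = new HashMap<>();

    for (int i = 0; i < members.length; ++i) {
      // Reading the nullable value once makes the subsequent null check airtight; a nullness
      // checker rejects the repeated-call form because the getter could, in principle,
      // return something different the second time.
      String field = members[i].getFieldName();
      if (field != null) {
        byJavaName.put(field, i);
      } else {
        String method = members[i].getMethodName();
        if (method != null) {
          byJavaName.put(method, i);
        }
      }
    }
    System.out.println(byJavaName); // e.g. {count=0, getName=1}
  }
}
```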
StackManipulation readParameter = new StackManipulation.Compound( MethodVariableAccess.REFERENCE.loadFrom(1), - IntegerConstant.forValue(fieldMapping.get(i)), + IntegerConstant.forValue(Preconditions.checkNotNull(fieldMapping.get(i))), ArrayAccess.REFERENCE.load(), TypeCasting.to(convertedType)); stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java index e98a0b9495cf..7f2403035d97 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java @@ -22,7 +22,6 @@ import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Type; -import java.util.Collections; import java.util.ServiceLoader; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -37,7 +36,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.JavaFieldSchema.JavaFieldTypeSupplier; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; @@ -58,7 +56,6 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -@Internal public class ConvertHelpers { private static class SchemaInformationProviders { private static final ServiceLoader INSTANCE = @@ -151,8 +148,7 @@ public static SerializableFunction getConvertPrimitive( TypeDescriptor outputTypeDescriptor, TypeConversionsFactory typeConversionsFactory) { FieldType expectedFieldType = - StaticSchemaInference.fieldFromType( - outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE, Collections.emptyMap()); + StaticSchemaInference.fieldFromType(outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE); if (!expectedFieldType.equals(fieldType)) { throw new IllegalArgumentException( "Element argument type " diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 83f6b5c928d8..ee4868ddb2b6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -22,10 +22,11 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.lang.reflect.Type; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -43,7 +44,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; @@ -56,26 +56,32 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; import 
org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; /** A set of utilities to generate getter and setter classes for JavaBean objects. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) -@Internal +@SuppressWarnings({"rawtypes"}) public class JavaBeanUtils { + + private static final String X_WITH_NULL_METHOD_ERROR_FMT = + "a %s FieldValueTypeInformation object has a null method field"; + public static final String GETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "getter"); + public static final String SETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "setter"); + /** Create a {@link Schema} for a Java Bean class. */ public static Schema schemaFromJavaBeanClass( - TypeDescriptor typeDescriptor, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return StaticSchemaInference.schemaFromClass( - typeDescriptor, fieldValueTypeSupplier, boundTypes); + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); } private static final String CONSTRUCTOR_HELP_STRING = - "In order to infer a Schema from a Java Bean, it must have a constructor annotated with @SchemaCreate, or it must have a compatible setter for every getter used as a Schema field."; + "In order to infer a Schema from a Java Bean, it must have a constructor annotated with" + + " @SchemaCreate, or it must have a compatible setter for every getter used as a Schema" + + " field."; // Make sure that there are matching setters and getters. public static void validateJavaBean( @@ -94,23 +100,26 @@ public static void validateJavaBean( for (FieldValueTypeInformation type : getters) { FieldValueTypeInformation setterType = setterMap.get(type.getName()); + Method m = Preconditions.checkNotNull(type.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); if (setterType == null) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a getter for field '%s', but does not contain a matching setter. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a getter for field '%s', but does not contain a matching" + + " setter. %s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.getType().equals(setterType.getType())) { throw new RuntimeException( String.format( "Java Bean '%s' contains a setter for field '%s' that has a mismatching type. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.isNullable() == setterType.isNullable()) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable attribute. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable" + + " attribute. 
%s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } } } @@ -132,36 +141,41 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map, List> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); /** * Return the list of {@link FieldValueGetter}s for a Java Bean class * *
The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { - return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> JavaBeanUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } - public static FieldValueGetter createGetter( - FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + public static + FieldValueGetter createGetter( + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementGetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -169,9 +183,8 @@ public static FieldValueGetter createGetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -184,10 +197,11 @@ public static FieldValueGetter createGetter( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation typeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(typeInformation.getName())) @@ -221,12 +235,14 @@ public static List getSetters( }); } - public static FieldValueSetter createSetter( + public static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementSetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -234,9 +250,8 @@ public static 
FieldValueSetter createSetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -249,10 +264,11 @@ public static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation fieldValueTypeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation fieldValueTypeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(fieldValueTypeInformation.getName())) @@ -364,6 +380,11 @@ public static SchemaUserTypeCreator createStaticCreator( } } + public static > Comparator comparingNullFirst( + Function keyExtractor) { + return Comparator.comparing(keyExtractor, Comparator.nullsFirst(Comparator.naturalOrder())); + } + // Implements a method to read a public getter out of an object. private static class InvokeGetterInstruction implements Implementation { private final FieldValueTypeInformation typeInformation; @@ -392,7 +413,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Method param is offset 1 (offset 0 is the this parameter). MethodVariableAccess.REFERENCE.loadFrom(1), // Invoke the getter - MethodInvocation.invoke(new ForLoadedMethod(typeInformation.getMethod()))); + MethodInvocation.invoke( + new ForLoadedMethod( + Preconditions.checkNotNull( + typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR)))); StackManipulation stackManipulation = new StackManipulation.Compound( @@ -434,7 +458,9 @@ public ByteCodeAppender appender(final Target implementationTarget) { // The instruction to read the field. StackManipulation readField = MethodVariableAccess.REFERENCE.loadFrom(2); - Method method = fieldValueTypeInformation.getMethod(); + Method method = + Preconditions.checkNotNull( + fieldValueTypeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); boolean setterMethodReturnsVoid = method.getReturnType().equals(Void.TYPE); // Read the object onto the stack. 
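
> Much of this PR follows the same fail-fast recipe: a shared message constant (`GETTER_WITH_NULL_METHOD_ERROR` / `SETTER_WITH_NULL_METHOD_ERROR`, defined in `JavaBeanUtils` earlier in this diff) plus an explicit `Preconditions.checkNotNull` wherever a `Method` is read from `FieldValueTypeInformation`. A standalone sketch of that shape; Beam's code uses the vendored Guava `Preconditions`, while this example substitutes `java.util.Objects.requireNonNull`, which serves the same purpose here:

```java
import java.lang.reflect.Method;
import java.util.Objects;

// Sketch of the fail-fast pattern used once the blanket "nullness" suppression is removed:
// a shared, descriptive message constant plus an explicit non-null check at every use site.
public class CheckNotNullSketch {
  private static final String X_WITH_NULL_METHOD_ERROR_FMT =
      "a %s FieldValueTypeInformation object has a null method field";
  static final String GETTER_WITH_NULL_METHOD_ERROR =
      String.format(X_WITH_NULL_METHOD_ERROR_FMT, "getter");

  // A getter-backed member should always carry a Method; the check documents and enforces that.
  static String getterName(Method maybeMethod) {
    Method m = Objects.requireNonNull(maybeMethod, GETTER_WITH_NULL_METHOD_ERROR);
    return m.getName();
  }

  public static void main(String[] args) throws NoSuchMethodException {
    Method lengthGetter = String.class.getMethod("length");
    System.out.println(getterName(lengthGetter)); // length

    try {
      getterName(null);
    } catch (NullPointerException e) {
      System.out.println(e.getMessage()); // prints the descriptive error message above
    }
  }
}
```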
StackManipulation stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 1e60c9312cb3..8e33d321a1c6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -49,7 +49,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; @@ -63,23 +62,20 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.NonNull; /** A set of utilities to generate getter and setter classes for POJOs. */ @SuppressWarnings({ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) }) -@Internal public class POJOUtils { public static Schema schemaFromPojoClass( - TypeDescriptor typeDescriptor, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return StaticSchemaInference.schemaFromClass( - typeDescriptor, fieldValueTypeSupplier, boundTypes); + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); } // Static ByteBuddy instance used by all helpers. @@ -99,38 +95,40 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the getters ordered by their position in the schema. 
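
> Both `JavaBeanUtils` and `POJOUtils` keep their generated getters in a concurrent map keyed by `TypeDescriptorWithSchema`, so ByteBuddy class generation runs at most once per (type, schema) pair; the casts added in these hunks only bridge the shared cache's wildcard types and the per-call type parameters. A minimal sketch of that `computeIfAbsent` caching shape with plain JDK types (`CacheKey`, `getGetters`, and the "generation" step are stand-ins, not Beam API):

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

// Sketch of the computeIfAbsent caching used for generated getters: expensive generation
// happens once per (type, schema) key; later lookups reuse the cached list.
public class GetterCacheSketch {
  // Stand-in for TypeDescriptorWithSchema: a value class usable as a map key.
  static final class CacheKey {
    final Class<?> type;
    final String schemaId;

    CacheKey(Class<?> type, String schemaId) {
      this.type = type;
      this.schemaId = schemaId;
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof CacheKey)) {
        return false;
      }
      CacheKey other = (CacheKey) o;
      return type.equals(other.type) && schemaId.equals(other.schemaId);
    }

    @Override
    public int hashCode() {
      return Objects.hash(type, schemaId);
    }
  }

  private static final Map<CacheKey, List<String>> CACHED_GETTERS = new ConcurrentHashMap<>();

  static List<String> getGetters(Class<?> type, String schemaId) {
    return CACHED_GETTERS.computeIfAbsent(
        new CacheKey(type, schemaId),
        key -> {
          System.out.println("generating getters for " + key.type.getSimpleName());
          return Arrays.asList("getFoo", "getBar"); // stand-in for ByteBuddy generation
        });
  }

  public static void main(String[] args) {
    getGetters(String.class, "schema-v1"); // prints "generating ..."
    getGetters(String.class, "schema-v1"); // cache hit, nothing printed
  }
}
```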
- return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - List getters = - types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - if (getters.size() != schema.getFieldCount()) { - throw new RuntimeException( - "Was not able to generate getters for schema: " - + schema - + " class: " - + typeDescriptor); - } - return getters; - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + List> getters = + types.stream() + .>map( + t -> POJOUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + if (getters.size() != schema.getFieldCount()) { + throw new RuntimeException( + "Was not able to generate getters for schema: " + + schema + + " class: " + + typeDescriptor); + } + return (List) getters; + }); } // The list of constructors for a class is cached, so we only create the classes the first time // getConstructor is called. - public static final Map CACHED_CREATORS = + public static final Map, SchemaUserTypeCreator> CACHED_CREATORS = Maps.newConcurrentMap(); public static SchemaUserTypeCreator getSetFieldCreator( @@ -155,7 +153,9 @@ private static SchemaUserTypeCreator createSetFieldCreator( TypeConversionsFactory typeConversionsFactory) { // Get the list of class fields ordered by schema. List fields = - types.stream().map(FieldValueTypeInformation::getField).collect(Collectors.toList()); + types.stream() + .map(type -> Preconditions.checkNotNull(type.getField())) + .collect(Collectors.toList()); try { DynamicType.Builder builder = BYTE_BUDDY @@ -180,14 +180,16 @@ private static SchemaUserTypeCreator createSetFieldCreator( | InvocationTargetException e) { throw new RuntimeException( String.format( - "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must have a zero-argument constructor, or a constructor annotated with @SchemaCreate.", + "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must" + + " have a zero-argument constructor, or a constructor annotated with" + + " @SchemaCreate.", clazz, schema)); } } - public static SchemaUserTypeCreator getConstructorCreator( - TypeDescriptor typeDescriptor, - Constructor constructor, + public static SchemaUserTypeCreator getConstructorCreator( + TypeDescriptor typeDescriptor, + Constructor constructor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { @@ -196,13 +198,13 @@ public static SchemaUserTypeCreator getConstructorCreator( c -> { List types = fieldValueTypeSupplier.get(typeDescriptor, schema); - return createConstructorCreator( + return POJOUtils.createConstructorCreator( typeDescriptor.getRawType(), constructor, schema, types, typeConversionsFactory); }); } public static SchemaUserTypeCreator createConstructorCreator( - Class clazz, + Class clazz, Constructor constructor, Schema schema, List types, @@ -296,17 +298,16 @@ public static SchemaUserTypeCreator createStaticCreator( * } * */ - @SuppressWarnings("unchecked") - static @Nullable FieldValueGetter createGetter( + static FieldValueGetter<@NonNull ObjectT, ValueT> createGetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - DynamicType.Builder builder = 
+ Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getGenericType()))); + .convert(TypeDescriptor.of(field.getType()))); builder = implementGetterMethods(builder, field, typeInformation.getName(), typeConversionsFactory); try { @@ -327,11 +328,12 @@ public static SchemaUserTypeCreator createStaticCreator( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - Field field, - String name, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + Field field, + String name, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) @@ -342,24 +344,25 @@ private static DynamicType.Builder implementGetterMethods( // The list of setters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_SETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_SETTERS = Maps.newConcurrentMap(); - public static List getSetters( - TypeDescriptor typeDescriptor, + public static List> getSetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the setters, ordered by their position in the schema. - return CACHED_SETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createSetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_SETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> createSetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } /** @@ -381,14 +384,14 @@ public static List getSetters( @SuppressWarnings("unchecked") private static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - DynamicType.Builder builder = + Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getGenericType()))); + .convert(TypeDescriptor.of(field.getType()))); builder = implementSetterMethods(builder, field, typeConversionsFactory); try { return builder @@ -408,10 +411,11 @@ private static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - Field field, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + Field field, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) 
.method(ElementMatchers.named("name")) @@ -496,7 +500,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readField) - .convert(TypeDescriptor.of(field.getGenericType())), + .convert(TypeDescriptor.of(field.getType())), // Now update the field and return void. FieldAccess.forField(new ForLoadedField(field)).write(), MethodReturn.VOID); @@ -510,11 +514,11 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Implements a method to construct an object. static class SetFieldCreateInstruction implements Implementation { private final List fields; - private final Class pojoClass; + private final Class pojoClass; private final TypeConversionsFactory typeConversionsFactory; SetFieldCreateInstruction( - List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { + List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; this.pojoClass = pojoClass; this.typeConversionsFactory = typeConversionsFactory; @@ -551,8 +555,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Field field = fields.get(i); ForLoadedType convertedType = - new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(field.getGenericType()))); + new ForLoadedType((Class) convertType.convert(TypeDescriptor.of(field.getType()))); // The instruction to read the parameter. StackManipulation readParameter = @@ -569,7 +572,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(field.getGenericType())), + .convert(TypeDescriptor.of(field.getType())), // Now update the field. FieldAccess.forField(new ForLoadedField(field)).write()); stackManipulation = new StackManipulation.Compound(stackManipulation, updateField); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index 32cfa5689193..423fea4c3845 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -26,7 +26,6 @@ import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; -import java.lang.reflect.TypeVariable; import java.security.InvalidParameterException; import java.util.Arrays; import java.util.Collection; @@ -36,7 +35,6 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; -import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; @@ -211,8 +209,7 @@ public static String stripSetterPrefix(String method) { } /** For an array T[] or a subclass of Iterable, return a TypeDescriptor describing T. 
*/ - public static @Nullable TypeDescriptor getIterableComponentType( - TypeDescriptor valueType, Map boundTypes) { + public static @Nullable TypeDescriptor getIterableComponentType(TypeDescriptor valueType) { TypeDescriptor componentType = null; if (valueType.isArray()) { Type component = valueType.getComponentType().getType(); @@ -226,7 +223,7 @@ public static String stripSetterPrefix(String method) { ParameterizedType ptype = (ParameterizedType) collection.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); checkArgument(params.length == 1); - componentType = TypeDescriptor.of(resolveType(params[0], boundTypes)); + componentType = TypeDescriptor.of(params[0]); } else { throw new RuntimeException("Collection parameter is not parameterized!"); } @@ -234,15 +231,14 @@ public static String stripSetterPrefix(String method) { return componentType; } - public static TypeDescriptor getMapType( - TypeDescriptor valueType, int index, Map boundTypes) { + public static TypeDescriptor getMapType(TypeDescriptor valueType, int index) { TypeDescriptor mapType = null; if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) { TypeDescriptor> map = valueType.getSupertype(Map.class); if (map.getType() instanceof ParameterizedType) { ParameterizedType ptype = (ParameterizedType) map.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - mapType = TypeDescriptor.of(resolveType(params[index], boundTypes)); + mapType = TypeDescriptor.of(params[index]); } else { throw new RuntimeException("Map type is not parameterized! " + map); } @@ -255,49 +251,4 @@ public static TypeDescriptor boxIfPrimitive(TypeDescriptor typeDescriptor) { ? TypeDescriptor.of(Primitives.wrap(typeDescriptor.getRawType())) : typeDescriptor; } - - /** - * If this (or a base class)is a paremeterized type, return a map of all TypeVariable->Type - * bindings. This allows us to resolve types in any contained fields or methods. - */ - public static Map getAllBoundTypes(TypeDescriptor typeDescriptor) { - Map boundParameters = Maps.newHashMap(); - TypeDescriptor currentType = typeDescriptor; - do { - if (currentType.getType() instanceof ParameterizedType) { - ParameterizedType parameterizedType = (ParameterizedType) currentType.getType(); - TypeVariable[] typeVariables = currentType.getRawType().getTypeParameters(); - Type[] typeArguments = parameterizedType.getActualTypeArguments(); - ; - if (typeArguments.length != typeVariables.length) { - throw new RuntimeException("Unmatching arguments lengths in type " + typeDescriptor); - } - for (int i = 0; i < typeVariables.length; ++i) { - boundParameters.put(typeVariables[i], typeArguments[i]); - } - } - Type superClass = currentType.getRawType().getGenericSuperclass(); - if (superClass == null || superClass.equals(Object.class)) { - break; - } - currentType = TypeDescriptor.of(superClass); - } while (true); - return boundParameters; - } - - public static Type resolveType(Type type, Map boundTypes) { - TypeDescriptor typeDescriptor = TypeDescriptor.of(type); - if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Iterable.class)) - || typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { - // Don't resolve these as we special case map and interable. 
- return type; - } - - if (type instanceof TypeVariable) { - TypeVariable typeVariable = (TypeVariable) type; - return Preconditions.checkArgumentNotNull(boundTypes.get(typeVariable)); - } else { - return type; - } - } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java index 275bc41be53d..196ee6f86593 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java @@ -19,7 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import java.lang.reflect.Type; +import java.lang.reflect.ParameterizedType; import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.Arrays; @@ -29,12 +29,10 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; -import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.ReadableInstant; @@ -44,7 +42,6 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -@Internal public class StaticSchemaInference { public static List sortBySchema( List types, Schema schema) { @@ -88,17 +85,14 @@ enum MethodType { * public getter methods, or special annotations on the class. */ public static Schema schemaFromClass( - TypeDescriptor typeDescriptor, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>(), boundTypes); + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>()); } private static Schema schemaFromClass( TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas, - Map boundTypes) { + Map, Schema> alreadyVisitedSchemas) { if (alreadyVisitedSchemas.containsKey(typeDescriptor)) { Schema existingSchema = alreadyVisitedSchemas.get(typeDescriptor); if (existingSchema == null) { @@ -112,7 +106,7 @@ private static Schema schemaFromClass( Schema.Builder builder = Schema.builder(); for (FieldValueTypeInformation type : fieldValueTypeSupplier.get(typeDescriptor)) { Schema.FieldType fieldType = - fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes); + fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas); Schema.Field f = type.isNullable() ? Schema.Field.nullable(type.getName(), fieldType) @@ -129,18 +123,15 @@ private static Schema schemaFromClass( /** Map a Java field type to a Beam Schema FieldType. 
*/ public static Schema.FieldType fieldFromType( - TypeDescriptor type, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>(), boundTypes); + TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) { + return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>()); } // TODO(https://github.com/apache/beam/issues/21567): support type inference for logical types private static Schema.FieldType fieldFromType( TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas, - Map boundTypes) { + Map, Schema> alreadyVisitedSchemas) { FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType()); if (primitiveType != null) { return primitiveType; @@ -161,25 +152,27 @@ private static Schema.FieldType fieldFromType( } else { // Otherwise this is an array type. return FieldType.array( - fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); + fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas)); } } else if (type.isSubtypeOf(TypeDescriptor.of(Map.class))) { - FieldType keyType = - fieldFromType( - ReflectUtils.getMapType(type, 0, boundTypes), - fieldValueTypeSupplier, - alreadyVisitedSchemas, - boundTypes); - FieldType valueType = - fieldFromType( - ReflectUtils.getMapType(type, 1, boundTypes), - fieldValueTypeSupplier, - alreadyVisitedSchemas, - boundTypes); - checkArgument( - keyType.getTypeName().isPrimitiveType(), - "Only primitive types can be map keys. type: " + keyType.getTypeName()); - return FieldType.map(keyType, valueType); + TypeDescriptor> map = type.getSupertype(Map.class); + if (map.getType() instanceof ParameterizedType) { + ParameterizedType ptype = (ParameterizedType) map.getType(); + java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); + checkArgument(params.length == 2); + FieldType keyType = + fieldFromType( + TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas); + FieldType valueType = + fieldFromType( + TypeDescriptor.of(params[1]), fieldValueTypeSupplier, alreadyVisitedSchemas); + checkArgument( + keyType.getTypeName().isPrimitiveType(), + "Only primitive types can be map keys. type: " + keyType.getTypeName()); + return FieldType.map(keyType, valueType); + } else { + throw new RuntimeException("Cannot infer schema from unparameterized map."); + } } else if (type.isSubtypeOf(TypeDescriptor.of(CharSequence.class))) { return FieldType.STRING; } else if (type.isSubtypeOf(TypeDescriptor.of(ReadableInstant.class))) { @@ -187,22 +180,26 @@ private static Schema.FieldType fieldFromType( } else if (type.isSubtypeOf(TypeDescriptor.of(ByteBuffer.class))) { return FieldType.BYTES; } else if (type.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { - FieldType elementType = - fieldFromType( - Preconditions.checkArgumentNotNull( - ReflectUtils.getIterableComponentType(type, boundTypes)), - fieldValueTypeSupplier, - alreadyVisitedSchemas, - boundTypes); - // TODO: should this be AbstractCollection? - if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { - return FieldType.array(elementType); + TypeDescriptor> iterable = type.getSupertype(Iterable.class); + if (iterable.getType() instanceof ParameterizedType) { + ParameterizedType ptype = (ParameterizedType) iterable.getType(); + java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); + checkArgument(params.length == 1); + // TODO: should this be AbstractCollection? 
+ if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { + return FieldType.array( + fieldFromType( + TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); + } else { + return FieldType.iterable( + fieldFromType( + TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); + } } else { - return FieldType.iterable(elementType); + throw new RuntimeException("Cannot infer schema from unparameterized collection."); } } else { - return FieldType.row( - schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); + return FieldType.row(schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas)); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index 8da3cf71af9f..159f92cd5e87 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableLikeCoder; import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; @@ -82,6 +83,81 @@ public static Iterables iterables() { return new Iterables<>(); } + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with a given {@link + * PCollection} resulting in a {@link PCollection} containing all the elements of both {@link + * PCollection}s as its output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and + * {@code other} and then applying {@link #pCollections()}, but has the advantage that it can be + * more easily used inline. + * + *

Both {@code PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param other the other PCollection to flatten with the input + * @param <T> the type of the elements in the input and output {@code PCollection}s. + */ + public static <T> PTransform<PCollection<T>, PCollection<T>> with(PCollection<T> other) { + return new FlattenWithPCollection<>(other); + } + + /** Implementation of {@link #with(PCollection)}. */ + private static class FlattenWithPCollection<T> + extends PTransform<PCollection<T>, PCollection<T>> { + // We only need to access this at pipeline construction time. + private final transient PCollection<T> other; + + public FlattenWithPCollection(PCollection<T> other) { + this.other = other; + } + + @Override + public PCollection<T> expand(PCollection<T> input) { + return PCollectionList.of(input).and(other).apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + } + + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with the output of + * another {@link PTransform} resulting in a {@link PCollection} containing all the elements of + * both the input {@link PCollection}s and the output of the given {@link PTransform} as its + * output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and the + * output of {@code other} and then applying {@link #pCollections()}, but has the advantage that + * it can be more easily used inline. + * + *

Both {@code PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param <T> the type of the elements in the input and output {@code PCollection}s. + * @param other a PTransform whose output should be flattened with the input + */ + public static <T> PTransform<PCollection<T>, PCollection<T>> with( + PTransform<PBegin, PCollection<T>> other) { + return new PTransform<PCollection<T>, PCollection<T>>() { + @Override + public PCollection<T> expand(PCollection<T> input) { + return PCollectionList.of(input) + .and(input.getPipeline().apply(other)) + .apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + }; + } + /** * A {@link PTransform} that flattens a {@link PCollectionList} into a {@link PCollection} * containing all the elements of all the {@link PCollection}s in its input. Implements {@link diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java new file mode 100644 index 000000000000..492a1cc84f74 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.util.function.Consumer; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; + +/** + * A PTransform that returns its input, but also applies its input to an auxiliary PTransform, akin + * to the shell {@code tee} command, which is named after the T-splitter used in plumbing. + * + *
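As a usage sketch, the two {@code Flatten.with} overloads added above can be applied inline. The pipeline setup, element values, and transform names below are illustrative assumptions; only the {@code Flatten.with} API itself comes from this change, and the snippet assumes a runner such as the DirectRunner is on the classpath.

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.values.PCollection;

public class FlattenWithSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();

    PCollection<String> first = p.apply("First", Create.of("a", "b"));
    PCollection<String> second = p.apply("Second", Create.of("c", "d"));

    // Overload 1: flatten with an already-constructed PCollection.
    PCollection<String> both = first.apply(Flatten.with(second));

    // Overload 2: flatten with the output of a root PTransform, applied inline.
    PCollection<String> all = both.apply(Flatten.with(Create.of("e", "f")));

    p.run().waitUntilFinish();
  }
}
```

The second overload is what keeps the chain fully linear: the extra input never has to be bound to a local variable before flattening.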

This can be useful to write out or otherwise process an intermediate transform without + * breaking the linear flow of a chain of transforms, e.g. + * + *


+ * {@literal PCollection<T>} input = ... ;
+ * {@literal PCollection<T>} result =
+ *     {@literal input.apply(...)}
+ *     ...
+ *     {@literal input.apply(Tee.of(someSideTransform))}
+ *     ...
+ *     {@literal input.apply(...)};
+ * 
+ * + * @param <T> the element type of the input PCollection + */ +public class Tee<T> extends PTransform<PCollection<T>, PCollection<T>> { + private final PTransform<PCollection<T>, ?> consumer; + + /** + * Returns a new Tee PTransform that will apply an auxiliary transform to the input as well as pass + * it on. + * + * @param consumer An additional PTransform that should process the input PCollection. Its output + * will be ignored. + * @param <T> the type of the elements in the input {@code PCollection}. + */ + public static <T> Tee<T> of(PTransform<PCollection<T>, ?> consumer) { + return new Tee<>(consumer); + } + + /** + * Returns a new Tee PTransform that will apply an auxiliary transform to the input as well as pass + * it on. + * + * @param consumer An arbitrary {@link Consumer} that will be wrapped in a PTransform and applied + * to the input. Its output will be ignored. + * @param <T> the type of the elements in the input {@code PCollection}. + */ + public static <T> Tee<T> of(Consumer<PCollection<T>> consumer) { + return of( + new PTransform<PCollection<T>, PCollectionTuple>() { + @Override + public PCollectionTuple expand(PCollection<T> input) { + consumer.accept(input); + return PCollectionTuple.empty(input.getPipeline()); + } + }); + } + + private Tee(PTransform<PCollection<T>, ?> consumer) { + this.consumer = consumer; + } + + @Override + public PCollection<T> expand(PCollection<T> input) { + input.apply(consumer); + return input; + } + + @Override + protected String getKindString() { + return "Tee(" + consumer.getName() + ")"; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java index cd38da100a79..0999f2ad0771 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java @@ -45,9 +45,6 @@ *
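A short usage sketch for the new Tee transform, here using the Consumer overload so the side branch can be written inline. TextIO, the output path, and the MapElements step are illustrative assumptions and not part of this change.

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.Tee;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;

public class TeeSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();

    PCollection<String> words = p.apply(Create.of("tee", "splits", "pipes"));

    PCollection<String> upper =
        words
            // Side branch: persist the raw words; the branch's output is ignored and the
            // main chain continues with the unchanged input.
            .apply(Tee.of(in -> in.apply(TextIO.write().to("/tmp/raw-words"))))
            .apply(MapElements.into(TypeDescriptors.strings()).via((String w) -> w.toUpperCase()));

    p.run().waitUntilFinish();
  }
}
```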
  • Return {@link CompletableFuture} only to the producer of a future value. * */ -@SuppressWarnings({ - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) public class MoreFutures { /** @@ -99,22 +96,18 @@ public static boolean isCancelled(CompletionStage future) { */ public static CompletionStage supplyAsync( ThrowingSupplier supplier, ExecutorService executorService) { - CompletableFuture result = new CompletableFuture<>(); - - CompletionStage wrapper = - CompletableFuture.runAsync( - () -> { - try { - result.complete(supplier.get()); - } catch (InterruptedException e) { - result.completeExceptionally(e); - Thread.currentThread().interrupt(); - } catch (Throwable t) { - result.completeExceptionally(t); - } - }, - executorService); - return wrapper.thenCompose(nothing -> result); + return CompletableFuture.supplyAsync( + () -> { + try { + return supplier.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CompletionException(e); + } catch (Throwable t) { + throw new CompletionException(t); + } + }, + executorService); } /** @@ -132,23 +125,18 @@ public static CompletionStage supplyAsync(ThrowingSupplier supplier) { */ public static CompletionStage runAsync( ThrowingRunnable runnable, ExecutorService executorService) { - CompletableFuture result = new CompletableFuture<>(); - - CompletionStage wrapper = - CompletableFuture.runAsync( - () -> { - try { - runnable.run(); - result.complete(null); - } catch (InterruptedException e) { - result.completeExceptionally(e); - Thread.currentThread().interrupt(); - } catch (Throwable t) { - result.completeExceptionally(t); - } - }, - executorService); - return wrapper.thenCompose(nothing -> result); + return CompletableFuture.runAsync( + () -> { + try { + runnable.run(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CompletionException(e); + } catch (Throwable t) { + throw new CompletionException(t); + } + }, + executorService); } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java index aeb76492bb6d..c2d945bbaac1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java @@ -44,6 +44,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSortedSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; +import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for working with with {@link Class Classes} and {@link Method Methods}. */ @SuppressWarnings({"nullness", "keyfor"}) // TODO(https://github.com/apache/beam/issues/20497) @@ -216,7 +217,7 @@ public static Iterable loadServicesOrdered(Class iface) { * which by default would use the proposed {@code ClassLoader}, which can be null. The fallback is * as follows: context ClassLoader, class ClassLoader and finally the system ClassLoader. 
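For the MoreFutures simplification above, a sketch of how a checked exception thrown by the supplier is expected to reach the caller, assuming the rewritten supplyAsync: the CompletionException created inside the task is unwrapped by CompletableFuture, so the original exception shows up as the cause. The IOException and the executor setup are illustrative.

```java
import java.io.IOException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.beam.sdk.util.MoreFutures;

public class MoreFuturesSketch {
  public static void main(String[] args) throws InterruptedException {
    ExecutorService executor = Executors.newSingleThreadExecutor();

    // The ThrowingSupplier may throw checked exceptions directly.
    CompletionStage<String> stage =
        MoreFutures.supplyAsync(
            () -> {
              throw new IOException("boom");
            },
            executor);

    try {
      stage.toCompletableFuture().get();
    } catch (ExecutionException e) {
      // Prints the original IOException, not the CompletionException wrapper.
      System.out.println("cause: " + e.getCause());
    } finally {
      executor.shutdown();
    }
  }
}
```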
*/ - public static ClassLoader findClassLoader(final ClassLoader proposed) { + public static ClassLoader findClassLoader(@Nullable final ClassLoader proposed) { ClassLoader classLoader = proposed; if (classLoader == null) { classLoader = ReflectHelpers.class.getClassLoader(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java index cd6ab7dd414a..de1717f0a45f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java @@ -43,6 +43,9 @@ public class PipelineOptionsTranslation { new ObjectMapper() .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader())); + public static final String PIPELINE_OPTIONS_URN_PREFIX = "beam:option:"; + public static final String PIPELINE_OPTIONS_URN_SUFFIX = ":v1"; + /** Converts the provided {@link PipelineOptions} to a {@link Struct}. */ public static Struct toProto(PipelineOptions options) { Struct.Builder builder = Struct.newBuilder(); @@ -65,9 +68,9 @@ public static Struct toProto(PipelineOptions options) { while (optionsEntries.hasNext()) { Map.Entry entry = optionsEntries.next(); optionsUsingUrns.put( - "beam:option:" + PIPELINE_OPTIONS_URN_PREFIX + CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, entry.getKey()) - + ":v1", + + PIPELINE_OPTIONS_URN_SUFFIX, entry.getValue()); } @@ -92,7 +95,9 @@ public static PipelineOptions fromProto(Struct protoOptions) { mapWithoutUrns.put( CaseFormat.LOWER_UNDERSCORE.to( CaseFormat.LOWER_CAMEL, - optionKey.substring("beam:option:".length(), optionKey.length() - ":v1".length())), + optionKey.substring( + PIPELINE_OPTIONS_URN_PREFIX.length(), + optionKey.length() - PIPELINE_OPTIONS_URN_SUFFIX.length())), optionValue); } return MAPPER.readValue( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index ee3852d70bbe..591a83600561 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; import org.apache.beam.sdk.values.RowUtils.RowPosition; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.ReadableDateTime; @@ -771,6 +772,7 @@ public FieldValueBuilder withFieldValue( checkState(values.isEmpty()); return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); } + /** * Sets field values using the field names. Nested values can be set using the field selection * syntax. 
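A small illustrative sketch of the option-name/URN mapping that the new PIPELINE_OPTIONS_URN_PREFIX and PIPELINE_OPTIONS_URN_SUFFIX constants encode in PipelineOptionsTranslation. It uses plain Guava CaseFormat rather than Beam's vendored copy, and the class and helper names are made up for the example.

```java
import com.google.common.base.CaseFormat;

public class PipelineOptionUrnSketch {
  // Same values as the constants introduced in PipelineOptionsTranslation.
  private static final String PREFIX = "beam:option:";
  private static final String SUFFIX = ":v1";

  // jobName -> beam:option:job_name:v1 (mirrors the toProto direction).
  static String toUrn(String camelCaseOptionName) {
    return PREFIX
        + CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, camelCaseOptionName)
        + SUFFIX;
  }

  // beam:option:job_name:v1 -> jobName (mirrors the fromProto direction).
  static String fromUrn(String urn) {
    String snake = urn.substring(PREFIX.length(), urn.length() - SUFFIX.length());
    return CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, snake);
  }

  public static void main(String[] args) {
    System.out.println(toUrn("jobName"));                    // beam:option:job_name:v1
    System.out.println(fromUrn("beam:option:job_name:v1"));  // jobName
  }
}
```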
@@ -836,10 +838,10 @@ public int nextFieldId() { } @Internal - public Row withFieldValueGetters( - Factory> fieldValueGetterFactory, Object getterTarget) { + public <@NonNull T> Row withFieldValueGetters( + Factory>> fieldValueGetterFactory, T getterTarget) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget); } public Row build() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index 9731507fb0f6..35e0ac20d3f7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -42,13 +42,13 @@ * the appropriate fields from the POJO. */ @SuppressWarnings("rawtypes") -public class RowWithGetters extends Row { - private final Object getterTarget; - private final List getters; +public class RowWithGetters extends Row { + private final T getterTarget; + private final List> getters; private @Nullable Map cache = null; RowWithGetters( - Schema schema, Factory> getterFactory, Object getterTarget) { + Schema schema, Factory>> getterFactory, T getterTarget) { super(schema); this.getterTarget = getterTarget; this.getters = getterFactory.create(TypeDescriptor.of(getterTarget.getClass()), schema); @@ -56,7 +56,7 @@ public class RowWithGetters extends Row { @Override @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) - public @Nullable T getValue(int fieldIdx) { + public W getValue(int fieldIdx) { Field field = getSchema().getField(fieldIdx); boolean cacheField = cacheFieldType(field); @@ -64,7 +64,7 @@ public class RowWithGetters extends Row { cache = new TreeMap<>(); } - Object fieldValue; + @Nullable Object fieldValue; if (cacheField) { if (cache == null) { cache = new TreeMap<>(); @@ -72,15 +72,12 @@ public class RowWithGetters extends Row { fieldValue = cache.computeIfAbsent( fieldIdx, - new Function() { + new Function() { @Override - public Object apply(Integer idx) { - FieldValueGetter getter = getters.get(idx); + public @Nullable Object apply(Integer idx) { + FieldValueGetter getter = getters.get(idx); checkStateNotNull(getter); - @SuppressWarnings("nullness") - @NonNull - Object value = getter.get(getterTarget); - return value; + return getter.get(getterTarget); } }); } else { @@ -90,7 +87,7 @@ public Object apply(Integer idx) { if (fieldValue == null && !field.getType().getNullable()) { throw new RuntimeException("Null value set on non-nullable field " + field); } - return (T) fieldValue; + return (W) fieldValue; } private boolean cacheFieldType(Field field) { @@ -116,7 +113,7 @@ public int getFieldCount() { return rawValues; } - public List getGetters() { + public List> getGetters() { return getters; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java index 3a7a0d5a8935..37580824b558 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java @@ -280,6 +280,7 @@ public void testFailedProcessingCausesAdditionalInboundDataToBeIgnored() throws DESCRIPTOR, OutboundObserverFactory.clientDirect(), 
inboundObserver -> TestStreams.withOnNext(outboundValues::add).build()); + final AtomicBoolean closed = new AtomicBoolean(); multiplexer.registerConsumer( DATA_INSTRUCTION_ID, new CloseableFnDataReceiver() { @@ -290,7 +291,7 @@ public void flush() throws Exception { @Override public void close() throws Exception { - fail("Unexpected call"); + closed.set(true); } @Override @@ -320,6 +321,7 @@ public void accept(BeamFnApi.Elements input) throws Exception { dataInboundValues, Matchers.contains( BeamFnApi.Elements.newBuilder().addData(data.setTransformId("A").build()).build())); + assertTrue(closed.get()); } @Test diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java index 49fd2bfe2259..d0ee623dea7c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java @@ -28,7 +28,6 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.Map; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; @@ -40,7 +39,6 @@ import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; import org.joda.time.DateTime; import org.joda.time.Instant; @@ -888,151 +886,4 @@ public void testSchema_SchemaFieldDescription() throws NoSuchSchemaException { assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("lng"), schema.getField("lng")); assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("str"), schema.getField("str")); } - - @AutoValue - @DefaultSchema(AutoValueSchema.class) - abstract static class ParameterizedAutoValue { - abstract W getValue1(); - - abstract T getValue2(); - - abstract V getValue3(); - - abstract X getValue4(); - } - - @Test - public void testAutoValueWithTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @DefaultSchema(AutoValueSchema.class) - abstract static class ParameterizedAutoValueSubclass - extends ParameterizedAutoValue { - abstract T getValue5(); - } - - @Test - public void testAutoValueWithInheritedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .addInt16Field("value5") - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @AutoValue - @DefaultSchema(AutoValueSchema.class) - abstract static class NestedParameterizedCollectionAutoValue { - abstract Iterable getNested(); - - abstract Map getMap(); - } 
- - @Test - public void testAutoValueWithNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - NestedParameterizedCollectionAutoValue< - ParameterizedAutoValue, String>> - typeDescriptor = - new TypeDescriptor< - NestedParameterizedCollectionAutoValue< - ParameterizedAutoValue, String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", FieldType.row(expectedInnerSchema)) - .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testAutoValueWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - NestedParameterizedCollectionAutoValue< - Iterable>, String>> - typeDescriptor = - new TypeDescriptor< - NestedParameterizedCollectionAutoValue< - Iterable>, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) - .addMapField( - "map", FieldType.STRING, FieldType.iterable(FieldType.row(expectedInnerSchema))) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @AutoValue - @DefaultSchema(AutoValueSchema.class) - abstract static class NestedParameterizedAutoValue { - abstract T getNested(); - } - - @Test - public void testAutoValueWithNestedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - NestedParameterizedAutoValue< - ParameterizedAutoValue>> - typeDescriptor = - new TypeDescriptor< - NestedParameterizedAutoValue< - ParameterizedAutoValue>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder().addRowField("nested", expectedInnerSchema).build(); - assertTrue(expectedSchema.equivalent(schema)); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java new file mode 100644 index 000000000000..26e3278df025 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas; + +import static org.junit.Assert.assertEquals; + +import java.util.Map; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.junit.Test; + +public class FieldValueTypeInformationTest { + public static class GenericClass { + public T t; + + public GenericClass(T t) { + this.t = t; + } + + public T getT() { + return t; + } + + public void setT(T t) { + this.t = t; + } + } + + private final TypeDescriptor>> typeDescriptor = + new TypeDescriptor>>() {}; + private final TypeDescriptor> expectedFieldTypeDescriptor = + new TypeDescriptor>() {}; + + @Test + public void testForGetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forGetter( + typeDescriptor, GenericClass.class.getMethod("getT"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForField() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forField(typeDescriptor, GenericClass.class.getField("t"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForSetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forSetter( + typeDescriptor, GenericClass.class.getMethod("setT", Object.class)); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 2252c3aef0db..5313feb5c6c0 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -68,7 +68,6 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBeanWithAnnotations; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -626,127 +625,4 @@ public void testSetterConstructionWithRenamedFields() throws NoSuchSchemaExcepti assertEquals( registry.getFromRowFunction(BeanWithCaseFormat.class).apply(row), beanWithCaseFormat); } - - @Test - public void testBeanWithTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.SimpleParameterizedBean>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithInheritedTypeParameter() throws NoSuchSchemaException { - 
SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .addInt16Field("value5") - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - TestJavaBeans.SimpleParameterizedBean, String>> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - TestJavaBeans.SimpleParameterizedBean, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", Schema.FieldType.row(expectedInnerSchema)) - .addMapField("map", Schema.FieldType.STRING, Schema.FieldType.row(expectedInnerSchema)) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - Iterable>, - String>> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - Iterable< - TestJavaBeans.SimpleParameterizedBean>, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField( - "nested", Schema.FieldType.iterable(Schema.FieldType.row(expectedInnerSchema))) - .addMapField( - "map", - Schema.FieldType.STRING, - Schema.FieldType.iterable(Schema.FieldType.row(expectedInnerSchema))) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithNestedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestJavaBeans.NestedParameterizedBean< - TestJavaBeans.SimpleParameterizedBean>> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.NestedParameterizedBean< - TestJavaBeans.SimpleParameterizedBean>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder().addRowField("nested", expectedInnerSchema).build(); - assertTrue(expectedSchema.equivalent(schema)); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 70bc3030924b..11bef79b26f7 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -76,7 +76,6 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -782,123 +781,4 @@ public void testCircularNestedPOJOThrows() throws NoSuchSchemaException { thrown.getMessage(), containsString("TestPOJOs$FirstCircularNestedPOJO")); } - - @Test - public void testPojoWithTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.SimpleParameterizedPOJO>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithInheritedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .addInt16Field("value5") - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - TestPOJOs.SimpleParameterizedPOJO, String>> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - TestPOJOs.SimpleParameterizedPOJO, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", FieldType.row(expectedInnerSchema)) - .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - Iterable>, - String>> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - Iterable>, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) - .addMapField( - "map", FieldType.STRING, 
FieldType.iterable(FieldType.row(expectedInnerSchema))) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithNestedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestPOJOs.NestedParameterizedPOJO< - TestPOJOs.SimpleParameterizedPOJO>> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.NestedParameterizedPOJO< - TestPOJOs.SimpleParameterizedPOJO>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder().addRowField("nested", expectedInnerSchema).build(); - assertTrue(expectedSchema.equivalent(schema)); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java index e0a45c2c82fe..7e9cf9a894b9 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java @@ -34,7 +34,6 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -54,6 +53,7 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.PrimitiveMapBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBean; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.junit.Test; @@ -66,9 +66,7 @@ public class JavaBeanUtilsTest { public void testNullable() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -77,9 +75,7 @@ public void testNullable() { public void testSimpleBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(SIMPLE_BEAN_SCHEMA, schema); } @@ -87,9 +83,7 @@ public void testSimpleBean() { public void testNestedBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_BEAN_SCHEMA, schema); } @@ -97,9 +91,7 @@ public void testNestedBean() { public void testPrimitiveArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_BEAN_SCHEMA, schema); } @@ -107,9 +99,7 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - 
GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_BEAN_SCHEMA, schema); } @@ -117,9 +107,7 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_BEAN_SCHEMA, schema); } @@ -127,9 +115,7 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_BEAN_SCHEMA, schema); } @@ -137,9 +123,7 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_BEAN_SCHEMA, schema); } @@ -159,11 +143,11 @@ public void testGeneratedSimpleGetters() { simpleBean.setBigDecimal(new BigDecimal(42)); simpleBean.setStringBuilder(new StringBuilder("stringBuilder")); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, SIMPLE_BEAN_SCHEMA, - new JavaBeanSchema.GetterTypeSupplier(), + new GetterTypeSupplier(), new DefaultTypeConversionsFactory()); assertEquals(12, getters.size()); assertEquals("str", getters.get(0).name()); @@ -237,7 +221,7 @@ public void testGeneratedSimpleBoxedGetters() { bean.setaLong(44L); bean.setaBoolean(true); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, BEAN_WITH_BOXED_FIELDS_SCHEMA, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java index 46c098dddaeb..378cdc06805f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java @@ -35,7 +35,6 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -53,6 +52,7 @@ import org.apache.beam.sdk.schemas.utils.TestPOJOs.PrimitiveMapPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.SimplePOJO; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Test; @@ -72,9 +72,7 @@ public class POJOUtilsTest { public void testNullables() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -83,9 +81,7 @@ public void testNullables() { public void testSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - 
JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); assertEquals(SIMPLE_POJO_SCHEMA, schema); } @@ -93,9 +89,7 @@ public void testSimplePOJO() { public void testNestedPOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_SCHEMA, schema); } @@ -104,8 +98,7 @@ public void testNestedPOJOWithSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_WITH_SIMPLE_POJO_SCHEMA, schema); } @@ -113,9 +106,7 @@ public void testNestedPOJOWithSimplePOJO() { public void testPrimitiveArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_POJO_SCHEMA, schema); } @@ -123,9 +114,7 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_POJO_SCHEMA, schema); } @@ -133,9 +122,7 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_POJO_SCHEMA, schema); } @@ -143,9 +130,7 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_POJO_SCHEMA, schema); } @@ -153,9 +138,7 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_POJO_SCHEMA, schema); } @@ -176,7 +159,7 @@ public void testGeneratedSimpleGetters() { new BigDecimal(42), new StringBuilder("stringBuilder")); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -202,7 +185,7 @@ public void testGeneratedSimpleGetters() { @Test public void testGeneratedSimpleSetters() { SimplePOJO simplePojo = new SimplePOJO(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -241,7 +224,7 @@ public void testGeneratedSimpleSetters() { public void testGeneratedSimpleBoxedGetters() { POJOWithBoxedFields pojo = new POJOWithBoxedFields((byte) 41, (short) 42, 43, 44L, true); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -257,7 +240,7 @@ public void testGeneratedSimpleBoxedGetters() { @Test public void testGeneratedSimpleBoxedSetters() { 
POJOWithBoxedFields pojo = new POJOWithBoxedFields(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -280,7 +263,7 @@ public void testGeneratedSimpleBoxedSetters() { @Test public void testGeneratedByteBufferSetters() { POJOWithByteArray pojo = new POJOWithByteArray(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BYTE_ARRAY_SCHEMA, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java index cbc976144971..b5ad6f989d9e 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java @@ -1397,95 +1397,4 @@ public void setValue(@Nullable Float value) { Schema.Field.nullable("value", FieldType.FLOAT) .withDescription("This value is the value stored in the object as a float.")) .build(); - - @DefaultSchema(JavaBeanSchema.class) - public static class SimpleParameterizedBean { - @Nullable private W value1; - @Nullable private T value2; - @Nullable private V value3; - @Nullable private X value4; - - public W getValue1() { - return value1; - } - - public void setValue1(W value1) { - this.value1 = value1; - } - - public T getValue2() { - return value2; - } - - public void setValue2(T value2) { - this.value2 = value2; - } - - public V getValue3() { - return value3; - } - - public void setValue3(V value3) { - this.value3 = value3; - } - - public X getValue4() { - return value4; - } - - public void setValue4(X value4) { - this.value4 = value4; - } - } - - @DefaultSchema(JavaBeanSchema.class) - public static class SimpleParameterizedBeanSubclass - extends SimpleParameterizedBean { - @Nullable private T value5; - - public SimpleParameterizedBeanSubclass() {} - - public T getValue5() { - return value5; - } - - public void setValue5(T value5) { - this.value5 = value5; - } - } - - @DefaultSchema(JavaBeanSchema.class) - public static class NestedParameterizedCollectionBean { - private Iterable nested; - private Map map; - - public Iterable getNested() { - return nested; - } - - public Map getMap() { - return map; - } - - public void setNested(Iterable nested) { - this.nested = nested; - } - - public void setMap(Map map) { - this.map = map; - } - } - - @DefaultSchema(JavaBeanSchema.class) - public static class NestedParameterizedBean { - private T nested; - - public T getNested() { - return nested; - } - - public void setNested(T nested) { - this.nested = nested; - } - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java index ce7409365d09..789de02adee8 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java @@ -495,125 +495,6 @@ public int hashCode() { .addStringField("stringBuilder") .build(); - @DefaultSchema(JavaFieldSchema.class) - public static class SimpleParameterizedPOJO { - public W value1; - public T value2; - public V value3; - public X value4; - - public SimpleParameterizedPOJO() {} - - public SimpleParameterizedPOJO(W value1, T value2, V value3, X value4) { - this.value1 = value1; - this.value2 = value2; - this.value3 = value3; - this.value4 = value4; - } - - @Override - public boolean equals(Object o) { - if 
(this == o) { - return true; - } - if (!(o instanceof SimpleParameterizedPOJO)) { - return false; - } - SimpleParameterizedPOJO that = (SimpleParameterizedPOJO) o; - return Objects.equals(value1, that.value1) - && Objects.equals(value2, that.value2) - && Objects.equals(value3, that.value3) - && Objects.equals(value4, that.value4); - } - - @Override - public int hashCode() { - return Objects.hash(value1, value2, value3, value4); - } - } - - @DefaultSchema(JavaFieldSchema.class) - public static class SimpleParameterizedPOJOSubclass - extends SimpleParameterizedPOJO { - public T value5; - - public SimpleParameterizedPOJOSubclass() {} - - public SimpleParameterizedPOJOSubclass(T value5) { - this.value5 = value5; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof SimpleParameterizedPOJOSubclass)) { - return false; - } - SimpleParameterizedPOJOSubclass that = (SimpleParameterizedPOJOSubclass) o; - return Objects.equals(value5, that.value5); - } - - @Override - public int hashCode() { - return Objects.hash(value4); - } - } - - @DefaultSchema(JavaFieldSchema.class) - public static class NestedParameterizedCollectionPOJO { - public Iterable nested; - public Map map; - - public NestedParameterizedCollectionPOJO(Iterable nested, Map map) { - this.nested = nested; - this.map = map; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof NestedParameterizedCollectionPOJO)) { - return false; - } - NestedParameterizedCollectionPOJO that = (NestedParameterizedCollectionPOJO) o; - return Objects.equals(nested, that.nested) && Objects.equals(map, that.map); - } - - @Override - public int hashCode() { - return Objects.hash(nested, map); - } - } - - @DefaultSchema(JavaFieldSchema.class) - public static class NestedParameterizedPOJO { - public T nested; - - public NestedParameterizedPOJO(T nested) { - this.nested = nested; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof NestedParameterizedPOJO)) { - return false; - } - NestedParameterizedPOJO that = (NestedParameterizedPOJO) o; - return Objects.equals(nested, that.nested); - } - - @Override - public int hashCode() { - return Objects.hash(nested); - } - } /** A POJO containing a nested class. 
* */ @DefaultSchema(JavaFieldSchema.class) public static class NestedPOJO { @@ -1006,7 +887,7 @@ public boolean equals(@Nullable Object o) { if (this == o) { return true; } - if (!(o instanceof PojoWithIterable)) { + if (!(o instanceof PojoWithNestedArray)) { return false; } PojoWithIterable that = (PojoWithIterable) o; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java index 282a41bed0dc..7a02d95a5046 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java @@ -402,6 +402,32 @@ public void testFlattenWithDifferentInputAndOutputCoders2() { ///////////////////////////////////////////////////////////////////////////// + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPCollection() { + PCollection output = + p.apply(Create.of(LINES)) + .apply("FlattenWithLines1", Flatten.with(p.apply("Create1", Create.of(LINES)))) + .apply("FlattenWithLines2", Flatten.with(p.apply("Create2", Create.of(LINES2)))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPTransform() { + PCollection output = + p.apply(Create.of(LINES)) + .apply("Create1", Flatten.with(Create.of(LINES))) + .apply("Create2", Flatten.with(Create.of(LINES2))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + ///////////////////////////////////////////////////////////////////////////// + @Test @Category(NeedsRunner.class) public void testEqualWindowFnPropagation() { @@ -470,6 +496,7 @@ public void testIncompatibleWindowFnPropagationFailure() { public void testFlattenGetName() { Assert.assertEquals("Flatten.Iterables", Flatten.iterables().getName()); Assert.assertEquals("Flatten.PCollections", Flatten.pCollections().getName()); + Assert.assertEquals("Flatten.With", Flatten.with((PCollection) null).getName()); } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java new file mode 100644 index 000000000000..ee3a00c46caa --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
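For context, the new FlattenTest cases above exercise both Flatten.with overloads; a minimal standalone pipeline using them (element values are made up) might look like this.

import java.util.Arrays;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.values.PCollection;

public class FlattenWithExample {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // Flatten.with(PCollection): merge an already-applied PCollection into the input.
    PCollection<String> merged =
        p.apply("Base", Create.of(Arrays.asList("a", "b")))
            .apply(
                "MergePCollection",
                Flatten.with(p.apply("Other", Create.of(Arrays.asList("c")))));

    // Flatten.with(PTransform): Flatten applies the root transform itself and merges its output.
    merged.apply("MergePTransform", Flatten.with(Create.of(Arrays.asList("d", "e"))));

    p.run().waitUntilFinish();
  }
}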
+ */ +package org.apache.beam.sdk.transforms; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimaps; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for Tee. */ +@RunWith(JUnit4.class) +public class TeeTest { + + @Rule public final transient TestPipeline p = TestPipeline.create(); + + @Test + @Category(NeedsRunner.class) + public void testTee() { + List elements = Arrays.asList("a", "b", "c"); + CollectToMemory collector = new CollectToMemory<>(); + PCollection output = p.apply(Create.of(elements)).apply(Tee.of(collector)); + + PAssert.that(output).containsInAnyOrder(elements); + p.run().waitUntilFinish(); + + // Here we assert that this "sink" had the correct side effects. + assertThat(collector.get(), containsInAnyOrder(elements.toArray(new String[3]))); + } + + private static class CollectToMemory extends PTransform, PCollection> { + + private static final Multimap ALL_ELEMENTS = + Multimaps.synchronizedMultimap(HashMultimap.create()); + + UUID uuid = UUID.randomUUID(); + + @Override + public PCollection expand(PCollection input) { + return input.apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + ALL_ELEMENTS.put(uuid, c.element()); + } + })); + } + + @SuppressWarnings("unchecked") + public Collection get() { + return (Collection) ALL_ELEMENTS.get(uuid); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java index 296a53f48e80..fd178f8e7649 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Objects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -226,5 +227,19 @@ public long getNum() { public String getStr() { return this.str; } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Pojo)) { + return false; + } + Pojo pojo = (Pojo) o; + return num == pojo.num && Objects.equals(str, pojo.str); + } + + @Override + public int hashCode() { + return Objects.hash(num, str); + } } } diff --git a/sdks/java/expansion-service/build.gradle b/sdks/java/expansion-service/build.gradle index 4dd8c8968ed9..a25583870acf 100644 --- a/sdks/java/expansion-service/build.gradle +++ b/sdks/java/expansion-service/build.gradle @@ -57,3 +57,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } + +compileJava { + outputs.upToDateWhen { false } +} \ No 
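Similarly, a hedged sketch of Tee as exercised by TeeTest above: the side branch is applied for its side effect while the main output keeps flowing. LogElements here is a stand-in for CollectToMemory, not Beam code.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Tee;
import org.apache.beam.sdk.values.PCollection;

public class TeeExample {

  // Stand-in side branch: logs each element and re-emits it.
  static class LogElements extends PTransform<PCollection<String>, PCollection<String>> {
    @Override
    public PCollection<String> expand(PCollection<String> input) {
      return input.apply(
          ParDo.of(
              new DoFn<String, String>() {
                @ProcessElement
                public void processElement(ProcessContext c) {
                  System.out.println("saw: " + c.element());
                  c.output(c.element());
                }
              }));
    }
  }

  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
    p.apply(Create.of("a", "b", "c"))
        // The side branch runs for its side effect; the main output still contains a, b, c.
        .apply(Tee.of(new LogElements()));
    p.run().waitUntilFinish();
  }
}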
newline at end of file diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index 150fe9729573..9c5b5a0ad136 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -60,7 +60,6 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.PortablePipelineOptions; -import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -535,7 +534,7 @@ private static void invokeSetter(ConfigT config, @Nullable Object valu } private @MonotonicNonNull Map registeredTransforms; - private final PipelineOptions pipelineOptions; + private final PipelineOptions commandLineOptions; private final @Nullable String loopbackAddress; public ExpansionService() { @@ -551,7 +550,7 @@ public ExpansionService(PipelineOptions opts) { } public ExpansionService(PipelineOptions opts, @Nullable String loopbackAddress) { - this.pipelineOptions = opts; + this.commandLineOptions = opts; this.loopbackAddress = loopbackAddress; } @@ -587,12 +586,15 @@ private Map loadRegisteredTransforms() { request.getTransform().getSpec().getUrn()); LOG.debug("Full transform: {}", request.getTransform()); Set existingTransformIds = request.getComponents().getTransformsMap().keySet(); - Pipeline pipeline = - createPipeline(PipelineOptionsTranslation.fromProto(request.getPipelineOptions())); + + PipelineOptions pipelineOptionsFromRequest = + PipelineOptionsTranslation.fromProto(request.getPipelineOptions()); + Pipeline pipeline = createPipeline(pipelineOptionsFromRequest); + boolean isUseDeprecatedRead = - ExperimentalOptions.hasExperiment(pipelineOptions, "use_deprecated_read") + ExperimentalOptions.hasExperiment(commandLineOptions, "use_deprecated_read") || ExperimentalOptions.hasExperiment( - pipelineOptions, "beam_fn_api_use_deprecated_read"); + commandLineOptions, "beam_fn_api_use_deprecated_read"); if (!isUseDeprecatedRead) { ExperimentalOptions.addExperiment( pipeline.getOptions().as(ExperimentalOptions.class), "beam_fn_api"); @@ -629,7 +631,7 @@ private Map loadRegisteredTransforms() { if (transformProvider == null) { if (getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP).equals(urn)) { AllowList allowList = - pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); + commandLineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); assert allowList != null; transformProvider = new JavaClassLookupTransformProvider(allowList); } else if (getUrn(SCHEMA_TRANSFORM).equals(urn)) { @@ -671,7 +673,7 @@ private Map loadRegisteredTransforms() { RunnerApi.Environment defaultEnvironment = Environments.createOrGetDefaultEnvironment( pipeline.getOptions().as(PortablePipelineOptions.class)); - if (pipelineOptions.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { + if (commandLineOptions.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { PortablePipelineOptions externalOptions = PipelineOptionsFactory.create().as(PortablePipelineOptions.class); externalOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_EXTERNAL); @@ 
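As background for the hasExperiment/addExperiment calls above, a small standalone sketch of the experiments API (the flag value here is just an example):

import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class ExperimentsExample {
  public static void main(String[] args) {
    PipelineOptions options =
        PipelineOptionsFactory.fromArgs("--experiments=use_deprecated_read").create();

    // True, because the experiment was passed on the command line.
    boolean deprecatedRead = ExperimentalOptions.hasExperiment(options, "use_deprecated_read");

    // Experiments can also be added programmatically, as ExpansionService does for beam_fn_api.
    ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "beam_fn_api");

    System.out.println(
        deprecatedRead + " " + ExperimentalOptions.hasExperiment(options, "beam_fn_api"));
  }
}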
-723,35 +725,34 @@ private Map loadRegisteredTransforms() { } protected Pipeline createPipeline(PipelineOptions requestOptions) { - // TODO: [https://github.com/apache/beam/issues/21064]: implement proper validation - PipelineOptions effectiveOpts = PipelineOptionsFactory.create(); - PortablePipelineOptions portableOptions = effectiveOpts.as(PortablePipelineOptions.class); - PortablePipelineOptions specifiedOptions = pipelineOptions.as(PortablePipelineOptions.class); - Optional.ofNullable(specifiedOptions.getDefaultEnvironmentType()) - .ifPresent(portableOptions::setDefaultEnvironmentType); - Optional.ofNullable(specifiedOptions.getDefaultEnvironmentConfig()) - .ifPresent(portableOptions::setDefaultEnvironmentConfig); - List filesToStage = specifiedOptions.getFilesToStage(); + // We expect the ExpansionRequest to contain a valid set of options to be used for this + // expansion. + // Additionally, we override selected options using options values set via command line or + // ExpansionService wide overrides. + + PortablePipelineOptions requestPortablePipelineOptions = + requestOptions.as(PortablePipelineOptions.class); + PortablePipelineOptions commandLinePortablePipelineOptions = + commandLineOptions.as(PortablePipelineOptions.class); + Optional.ofNullable(commandLinePortablePipelineOptions.getDefaultEnvironmentType()) + .ifPresent(requestPortablePipelineOptions::setDefaultEnvironmentType); + Optional.ofNullable(commandLinePortablePipelineOptions.getDefaultEnvironmentConfig()) + .ifPresent(requestPortablePipelineOptions::setDefaultEnvironmentConfig); + List filesToStage = commandLinePortablePipelineOptions.getFilesToStage(); if (filesToStage != null) { - effectiveOpts.as(PortablePipelineOptions.class).setFilesToStage(filesToStage); + requestPortablePipelineOptions + .as(PortablePipelineOptions.class) + .setFilesToStage(filesToStage); } - effectiveOpts + requestPortablePipelineOptions .as(ExperimentalOptions.class) - .setExperiments(pipelineOptions.as(ExperimentalOptions.class).getExperiments()); - effectiveOpts.setRunner(NotRunnableRunner.class); - effectiveOpts + .setExperiments(commandLineOptions.as(ExperimentalOptions.class).getExperiments()); + requestPortablePipelineOptions.setRunner(NotRunnableRunner.class); + requestPortablePipelineOptions .as(ExpansionServiceOptions.class) .setExpansionServiceConfig( - pipelineOptions.as(ExpansionServiceOptions.class).getExpansionServiceConfig()); - // TODO(https://github.com/apache/beam/issues/20090): Figure out the correct subset of options - // to propagate. 
- if (requestOptions.as(StreamingOptions.class).getUpdateCompatibilityVersion() != null) { - effectiveOpts - .as(StreamingOptions.class) - .setUpdateCompatibilityVersion( - requestOptions.as(StreamingOptions.class).getUpdateCompatibilityVersion()); - } - return Pipeline.create(effectiveOpts); + commandLineOptions.as(ExpansionServiceOptions.class).getExpansionServiceConfig()); + return Pipeline.create(requestOptions); } @Override diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java index 1c8d515d5c85..9ee0c2c1797b 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.expansion.service; +import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_PREFIX; +import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_SUFFIX; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.contains; @@ -49,6 +51,8 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -58,15 +62,20 @@ import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Struct; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Value; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Resources; import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.Matchers; +import org.junit.Assert; import org.junit.Test; /** Tests for {@link ExpansionService}. 
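The override pattern used in the new createPipeline above (apply a command-line value only when one was actually set) can be shown in isolation; the two options objects here stand in for requestOptions and commandLineOptions.

import java.util.Optional;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.PortablePipelineOptions;

public class OptionOverrideExample {
  public static void main(String[] args) {
    // Stand-in for the options carried by the ExpansionRequest.
    PortablePipelineOptions requestOptions =
        PipelineOptionsFactory.create().as(PortablePipelineOptions.class);

    // Stand-in for options given on the expansion service command line.
    PortablePipelineOptions commandLineOptions =
        PipelineOptionsFactory.create().as(PortablePipelineOptions.class);
    commandLineOptions.setDefaultEnvironmentType("LOOPBACK");

    // Override the request-supplied value only when a command-line value is present.
    Optional.ofNullable(commandLineOptions.getDefaultEnvironmentType())
        .ifPresent(requestOptions::setDefaultEnvironmentType);

    System.out.println(requestOptions.getDefaultEnvironmentType()); // LOOPBACK
  }
}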
*/ @@ -76,6 +85,7 @@ public class ExpansionServiceTest { private static final String TEST_URN = "test:beam:transforms:count"; + private static final String TEST_OPTIONS_URN = "test:beam:transforms:test_options"; private static final String TEST_NAME = "TestName"; @@ -98,9 +108,59 @@ public class ExpansionServiceTest { @AutoService(ExpansionService.ExpansionServiceRegistrar.class) public static class TestTransformRegistrar implements ExpansionService.ExpansionServiceRegistrar { + static final String EXPECTED_STRING_VALUE = "abcde"; + static final Boolean EXPECTED_BOOLEAN_VALUE = true; + static final Integer EXPECTED_INTEGER_VALUE = 12345; + @Override public Map knownTransforms() { - return ImmutableMap.of(TEST_URN, (spec, options) -> Count.perElement()); + return ImmutableMap.of( + TEST_URN, (spec, options) -> Count.perElement(), + TEST_OPTIONS_URN, + (spec, options) -> + new TestOptionsTransform( + EXPECTED_STRING_VALUE, EXPECTED_BOOLEAN_VALUE, EXPECTED_INTEGER_VALUE)); + } + } + + public interface TestOptions extends PipelineOptions { + String getStringOption(); + + void setStringOption(String value); + + Boolean getBooleanOption(); + + void setBooleanOption(Boolean value); + + Integer getIntegerOption(); + + void setIntegerOption(Integer value); + } + + public static class TestOptionsTransform + extends PTransform, PCollection> { + String expectedStringValue; + + Boolean expectedBooleanValue; + + Integer expectedIntegerValue; + + public TestOptionsTransform( + String expectedStringValue, Boolean expectedBooleanValue, Integer expectedIntegerValue) { + this.expectedStringValue = expectedStringValue; + this.expectedBooleanValue = expectedBooleanValue; + this.expectedIntegerValue = expectedIntegerValue; + } + + @Override + public PCollection expand(PCollection input) { + TestOptions testOption = input.getPipeline().getOptions().as(TestOptions.class); + + Assert.assertEquals(expectedStringValue, testOption.getStringOption()); + Assert.assertEquals(expectedBooleanValue, testOption.getBooleanOption()); + Assert.assertEquals(expectedIntegerValue, testOption.getIntegerOption()); + + return input; } } @@ -146,6 +206,58 @@ public void testConstruct() { } } + @Test + public void testConstructWithPipelineOptions() { + PipelineOptionsFactory.register(TestOptions.class); + Pipeline p = Pipeline.create(); + p.apply(Impulse.create()); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + String inputPcollId = + Iterables.getOnlyElement( + Iterables.getOnlyElement(pipelineProto.getComponents().getTransformsMap().values()) + .getOutputsMap() + .values()); + + Struct optionsStruct = + Struct.newBuilder() + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "string_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setStringValue(TestTransformRegistrar.EXPECTED_STRING_VALUE) + .build()) + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "boolean_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setBoolValue(TestTransformRegistrar.EXPECTED_BOOLEAN_VALUE) + .build()) + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "integer_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setNumberValue(TestTransformRegistrar.EXPECTED_INTEGER_VALUE) + .build()) + .build(); + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setPipelineOptions(optionsStruct) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + 
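The TestOptions interface above follows the standard PipelineOptions pattern; registering and parsing such an interface outside the expansion service would look roughly like this (flag values are arbitrary).

import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class CustomOptionsExample {

  // Same getter/setter-pair shape as TestOptions in ExpansionServiceTest.
  public interface MyOptions extends PipelineOptions {
    String getStringOption();

    void setStringOption(String value);

    Integer getIntegerOption();

    void setIntegerOption(Integer value);
  }

  public static void main(String[] args) {
    PipelineOptionsFactory.register(MyOptions.class);
    MyOptions options =
        PipelineOptionsFactory.fromArgs("--stringOption=abcde", "--integerOption=12345")
            .as(MyOptions.class);
    System.out.println(options.getStringOption() + " " + options.getIntegerOption());
  }
}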
.setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(TEST_OPTIONS_URN)) + .putInputs("input", inputPcollId)) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + + // Verify it has the right input. + assertThat(expandedTransform.getInputsMap().values(), contains(inputPcollId)); + + // Verify it has the right output. + assertThat(expandedTransform.getOutputsMap().keySet(), contains("output")); + } + @Test public void testConstructGenerateSequenceWithRegistration() { ExternalTransforms.ExternalConfigurationPayload payload = diff --git a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java index 4b6538157fd0..78ba610ad4d1 100644 --- a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java +++ b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java @@ -276,11 +276,11 @@ public static class RecordBatchRowIterator implements Iterator, AutoCloseab new ArrowValueConverterVisitor(); private final Schema schema; private final VectorSchemaRoot vectorSchemaRoot; - private final Factory> fieldValueGetters; + private final Factory>> fieldValueGetters; private Integer currRowIndex; private static class FieldVectorListValueGetterFactory - implements Factory> { + implements Factory>> { private final List fieldVectors; static FieldVectorListValueGetterFactory of(List fieldVectors) { @@ -292,7 +292,8 @@ private FieldVectorListValueGetterFactory(List fieldVectors) { } @Override - public List create(TypeDescriptor typeDescriptor, Schema schema) { + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { return this.fieldVectors.stream() .map( (fieldVector) -> { diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java index e75647a2ccfa..203bcccbf562 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.schemas.SchemaProvider; import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for AVRO generated SpecificRecords and POJOs. 
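A hedged sketch of how such an options Struct is turned back into PipelineOptions on the service side, using the same PipelineOptionsTranslation.fromProto path the request takes; the snake_case option name and the jobName mapping are assumptions based on the URN convention used in the test above.

import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_PREFIX;
import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_SUFFIX;

import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.construction.PipelineOptionsTranslation;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Struct;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Value;

public class OptionsStructExample {
  public static void main(String[] args) {
    Struct optionsStruct =
        Struct.newBuilder()
            .putFields(
                PIPELINE_OPTIONS_URN_PREFIX + "job_name" + PIPELINE_OPTIONS_URN_SUFFIX,
                Value.newBuilder().setStringValue("my-expansion-job").build())
            .build();

    // The expansion service deserializes request.getPipelineOptions() the same way.
    PipelineOptions options = PipelineOptionsTranslation.fromProto(optionsStruct);
    System.out.println(options.getJobName()); // expected: my-expansion-job
  }
}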
@@ -44,8 +45,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return AvroUtils.getGetters(targetTypeDescriptor, schema); } diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java index 1a530a3f6ca5..0a82663c1771 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java @@ -78,8 +78,8 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc // Generate a method call to create and invoke the SpecificRecord's constructor. . MethodCall construct = MethodCall.construct(baseConstructor); - for (int i = 0; i < baseConstructor.getGenericParameterTypes().length; ++i) { - Type baseType = baseConstructor.getGenericParameterTypes()[i]; + for (int i = 0; i < baseConstructor.getParameterTypes().length; ++i) { + Class baseType = baseConstructor.getParameterTypes()[i]; construct = construct.with(readAndConvertParameter(baseType, i), baseType); } @@ -110,7 +110,7 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc } private static StackManipulation readAndConvertParameter( - Type constructorParameterType, int index) { + Class constructorParameterType, int index) { TypeConversionsFactory typeConversionsFactory = new AvroUtils.AvroTypeConversionFactory(); // The types in the AVRO-generated constructor might be the types returned by Beam's Row class, diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 1324d254e44e..bfbab6fe87f6 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -94,11 +94,13 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.EnsuresNonNullIf; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; import org.joda.time.Days; import org.joda.time.Duration; import org.joda.time.Instant; @@ -139,10 +141,7 @@ * * is used. 
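For reference, the reflection change above swaps generic Types for raw Classes; the difference is easy to see in isolation (the Holder class is a made-up example).

import java.lang.reflect.Constructor;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.List;

public class ParameterTypesExample {

  // Made-up class with a generic constructor parameter.
  static class Holder {
    Holder(List<String> values, int count) {}
  }

  public static void main(String[] args) {
    Constructor<?> ctor = Holder.class.getDeclaredConstructors()[0];

    // Raw, erased classes: [interface java.util.List, int]
    Class<?>[] raw = ctor.getParameterTypes();

    // Full generic types: [java.util.List<java.lang.String>, int]
    Type[] generic = ctor.getGenericParameterTypes();

    System.out.println(Arrays.toString(raw));
    System.out.println(Arrays.toString(generic));
  }
}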
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class AvroUtils { private static final ForLoadedType BYTES = new ForLoadedType(byte[].class); private static final ForLoadedType JAVA_INSTANT = new ForLoadedType(java.time.Instant.class); @@ -152,6 +151,38 @@ public class AvroUtils { new ForLoadedType(ReadableInstant.class); private static final ForLoadedType JODA_INSTANT = new ForLoadedType(Instant.class); + // contains workarounds for third-party methods that accept nullable arguments but lack proper + // annotations + private static class NullnessCheckerWorkarounds { + + private static ReflectData newReflectData(Class clazz) { + // getClassLoader returns @Nullable Classloader, but it's ok, as ReflectData constructor + // actually tolerates null classloader argument despite lacking the @Nullable annotation + @SuppressWarnings("nullness") + @NonNull + ClassLoader classLoader = clazz.getClassLoader(); + return new ReflectData(classLoader); + } + + private static void builderSet( + GenericRecordBuilder builder, String fieldName, @Nullable Object value) { + // the value argument can actually be null here, it's not annotated as such in the method + // though, hence this wrapper + builder.set(fieldName, castToNonNull(value)); + } + + private static Object createFixed( + @Nullable Object old, byte[] bytes, org.apache.avro.Schema schema) { + // old is tolerated when null, due to an instanceof check + return GenericData.get().createFixed(castToNonNull(old), bytes, schema); + } + + @SuppressWarnings("nullness") + private static @NonNull T castToNonNull(@Nullable T value) { + return value; + } + } + public static void addLogicalTypeConversions(final GenericData data) { // do not add DecimalConversion by default as schema must have extra 'scale' and 'precision' // properties. avro reflect already handles BigDecimal as string with the 'java-class' property @@ -235,7 +266,9 @@ public static FixedBytesField withSize(int size) { /** Create a {@link FixedBytesField} from a Beam {@link FieldType}. */ public static @Nullable FixedBytesField fromBeamFieldType(FieldType fieldType) { if (fieldType.getTypeName().isLogicalType() - && fieldType.getLogicalType().getIdentifier().equals(FixedBytes.IDENTIFIER)) { + && checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(FixedBytes.IDENTIFIER)) { int length = fieldType.getLogicalType(FixedBytes.class).getLength(); return new FixedBytesField(length); } else { @@ -264,7 +297,7 @@ public FieldType toBeamType() { /** Convert to an AVRO type. 
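The castToNonNull helper above is a common Checker Framework idiom; a minimal standalone version of the pattern, assuming a third-party method whose parameter is missing a @Nullable annotation, looks like this.

import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;

public class CastToNonNullExample {

  // Stand-in for a third-party API that tolerates null but is not annotated as such.
  static void thirdPartySet(String key, Object value) {
    // pretend this handles value == null correctly
  }

  // Confine the unchecked cast to one tiny, well-commented helper.
  @SuppressWarnings("nullness")
  static <T> @NonNull T castToNonNull(@Nullable T value) {
    return value;
  }

  static void set(String key, @Nullable Object value) {
    // Without the helper, the nullness checker would reject passing a @Nullable value here.
    thirdPartySet(key, castToNonNull(value));
  }

  public static void main(String[] args) {
    set("k", null);
  }
}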
*/ public org.apache.avro.Schema toAvroType(String name, String namespace) { - return org.apache.avro.Schema.createFixed(name, null, namespace, size); + return org.apache.avro.Schema.createFixed(name, "", namespace, size); } } @@ -451,8 +484,7 @@ public static Field toBeamField(org.apache.avro.Schema.Field field) { public static org.apache.avro.Schema.Field toAvroField(Field field, String namespace) { org.apache.avro.Schema fieldSchema = getFieldSchema(field.getType(), field.getName(), namespace); - return new org.apache.avro.Schema.Field( - field.getName(), fieldSchema, field.getDescription(), (Object) null); + return new org.apache.avro.Schema.Field(field.getName(), fieldSchema, field.getDescription()); } private AvroUtils() {} @@ -463,7 +495,7 @@ private AvroUtils() {} * @param clazz avro class */ public static Schema toBeamSchema(Class clazz) { - ReflectData data = new ReflectData(clazz.getClassLoader()); + ReflectData data = NullnessCheckerWorkarounds.newReflectData(clazz); return toBeamSchema(data.getSchema(clazz)); } @@ -486,10 +518,17 @@ public static Schema toBeamSchema(org.apache.avro.Schema schema) { return builder.build(); } + @EnsuresNonNullIf( + expression = {"#1"}, + result = false) + private static boolean isNullOrEmpty(@Nullable String str) { + return str == null || str.isEmpty(); + } + /** Converts a Beam Schema into an AVRO schema. */ public static org.apache.avro.Schema toAvroSchema( Schema beamSchema, @Nullable String name, @Nullable String namespace) { - final String schemaName = Strings.isNullOrEmpty(name) ? "topLevelRecord" : name; + final String schemaName = isNullOrEmpty(name) ? "topLevelRecord" : name; final String schemaNamespace = namespace == null ? "" : namespace; String childNamespace = !"".equals(schemaNamespace) ? schemaNamespace + "." 
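The @EnsuresNonNullIf contract introduced above generalizes to any null-or-empty style guard; a small self-contained example of how the checker uses it (the greet method is hypothetical):

import org.checkerframework.checker.nullness.qual.EnsuresNonNullIf;
import org.checkerframework.checker.nullness.qual.Nullable;

public class EnsuresNonNullIfExample {

  // Tells the nullness checker: when this returns false, the first argument is non-null.
  @EnsuresNonNullIf(
      expression = {"#1"},
      result = false)
  static boolean isNullOrEmpty(@Nullable String str) {
    return str == null || str.isEmpty();
  }

  static String greet(@Nullable String name) {
    if (isNullOrEmpty(name)) {
      return "hello, anonymous";
    }
    // The checker now knows "name" is non-null, so no explicit null check is needed here.
    return "hello, " + name.trim();
  }

  public static void main(String[] args) {
    System.out.println(greet(null));
    System.out.println(greet(" Beam "));
  }
}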
+ schemaName : schemaName; @@ -498,7 +537,7 @@ public static org.apache.avro.Schema toAvroSchema( org.apache.avro.Schema.Field recordField = toAvroField(field, childNamespace); fields.add(recordField); } - return org.apache.avro.Schema.createRecord(schemaName, null, schemaNamespace, false, fields); + return org.apache.avro.Schema.createRecord(schemaName, "", schemaNamespace, false, fields); } public static org.apache.avro.Schema toAvroSchema(Schema beamSchema) { @@ -557,7 +596,8 @@ public static GenericRecord toGenericRecord( GenericRecordBuilder builder = new GenericRecordBuilder(avroSchema); for (int i = 0; i < beamSchema.getFieldCount(); ++i) { Field field = beamSchema.getField(i); - builder.set( + NullnessCheckerWorkarounds.builderSet( + builder, field.getName(), genericFromBeamField( field.getType(), avroSchema.getField(field.getName()).schema(), row.getValue(i))); @@ -567,7 +607,7 @@ public static GenericRecord toGenericRecord( @SuppressWarnings("unchecked") public static SerializableFunction getToRowFunction( - Class clazz, org.apache.avro.@Nullable Schema schema) { + Class clazz, org.apache.avro.Schema schema) { if (GenericRecord.class.equals(clazz)) { Schema beamSchema = toBeamSchema(schema); return (SerializableFunction) getGenericRecordToRowFunction(beamSchema); @@ -662,9 +702,9 @@ public static SerializableFunction getGenericRecordToRowFunc } private static class GenericRecordToRowFn implements SerializableFunction { - private final Schema schema; + private final @Nullable Schema schema; - GenericRecordToRowFn(Schema schema) { + GenericRecordToRowFn(@Nullable Schema schema) { this.schema = schema; } @@ -701,7 +741,7 @@ public static SerializableFunction getRowToGenericRecordFunc } private static class RowToGenericRecordFn implements SerializableFunction { - private transient org.apache.avro.Schema avroSchema; + private transient org.apache.avro.@Nullable Schema avroSchema; RowToGenericRecordFn(org.apache.avro.@Nullable Schema avroSchema) { this.avroSchema = avroSchema; @@ -751,7 +791,8 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE public static SchemaCoder schemaCoder(TypeDescriptor type) { @SuppressWarnings("unchecked") Class clazz = (Class) type.getRawType(); - org.apache.avro.Schema avroSchema = new ReflectData(clazz.getClassLoader()).getSchema(clazz); + org.apache.avro.Schema avroSchema = + NullnessCheckerWorkarounds.newReflectData(clazz).getSchema(clazz); Schema beamSchema = toBeamSchema(avroSchema); return SchemaCoder.of( beamSchema, type, getToRowFunction(clazz, avroSchema), getFromRowFunction(clazz)); @@ -790,7 +831,7 @@ public static SchemaCoder schemaCoder(org.apache.avro.Schema sche */ public static SchemaCoder schemaCoder(Class clazz, org.apache.avro.Schema schema) { return SchemaCoder.of( - getSchema(clazz, schema), + checkNotNull(getSchema(clazz, schema)), TypeDescriptor.of(clazz), getToRowFunction(clazz, schema), getFromRowFunction(clazz)); @@ -814,9 +855,6 @@ public List get(TypeDescriptor typeDescriptor) { @Override public List get(TypeDescriptor typeDescriptor, Schema schema) { - Map boundTypes = - ReflectUtils.getAllBoundTypes(typeDescriptor); - Map mapping = getMapping(schema); List methods = ReflectUtils.getMethods(typeDescriptor.getRawType()); List types = Lists.newArrayList(); @@ -824,7 +862,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i, 
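Taken together, the conversion entry points touched above can be exercised with a tiny round trip; the schema and row values here are made up, and the toBeamRowStrict call is assumed to be the usual counterpart of toGenericRecord.

import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.Row;

public class AvroRoundTripExample {
  public static void main(String[] args) {
    Schema beamSchema =
        Schema.builder().addStringField("name").addInt64Field("count").build();

    // Beam Schema -> Avro schema (the toAvroSchema path changed above).
    org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(beamSchema);

    Row row = Row.withSchema(beamSchema).addValues("beam", 42L).build();

    // Row -> GenericRecord and back again.
    GenericRecord record = AvroUtils.toGenericRecord(row, avroSchema);
    Row roundTripped = AvroUtils.toBeamRowStrict(record, beamSchema);

    System.out.println(row.equals(roundTripped)); // expected: true
  }
}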
boundTypes); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -868,8 +906,6 @@ private Map getMapping(Schema schema) { private static final class AvroPojoFieldValueTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { - Map boundTypes = - ReflectUtils.getAllBoundTypes(typeDescriptor); List classFields = ReflectUtils.getFields(typeDescriptor.getRawType()); Map types = Maps.newHashMap(); @@ -877,7 +913,7 @@ public List get(TypeDescriptor typeDescriptor) { java.lang.reflect.Field f = classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { FieldValueTypeInformation typeInformation = - FieldValueTypeInformation.forField(f, i, boundTypes); + FieldValueTypeInformation.forField(typeDescriptor, f, i); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); @@ -901,7 +937,7 @@ public static List getFieldTypes( } /** Get generated getters for an AVRO-generated SpecificRecord or a POJO. */ - public static List getGetters( + public static List> getGetters( TypeDescriptor typeDescriptor, Schema schema) { if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { return JavaBeanUtils.getGetters( @@ -974,7 +1010,7 @@ private static FieldType toFieldType(TypeWithNullability type) { break; case FIXED: - fieldType = FixedBytesField.fromAvroType(type.type).toBeamType(); + fieldType = checkNotNull(FixedBytesField.fromAvroType(type.type)).toBeamType(); break; case STRING: @@ -1072,7 +1108,8 @@ private static org.apache.avro.Schema getFieldSchema( break; case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + Schema.LogicalType logicalType = checkNotNull(fieldType.getLogicalType()); + String identifier = logicalType.getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1083,15 +1120,13 @@ private static org.apache.avro.Schema getFieldSchema( } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("char", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("char", checkNotNull(logicalType.getArgument())); } else if (VariableString.IDENTIFIER.equals(identifier) || "NVARCHAR".equals(identifier) || "VARCHAR".equals(identifier) || "LONGNVARCHAR".equals(identifier) || "LONGVARCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("varchar", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("varchar", checkNotNull(logicalType.getArgument())); } else if (EnumerationType.IDENTIFIER.equals(identifier)) { EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); baseType = @@ -1109,7 +1144,7 @@ private static org.apache.avro.Schema getFieldSchema( baseType = LogicalTypes.timeMillis().addToSchema(org.apache.avro.Schema.create(Type.INT)); } else { throw new RuntimeException( - "Unhandled logical type " + fieldType.getLogicalType().getIdentifier()); + "Unhandled logical type " + checkNotNull(fieldType.getLogicalType()).getIdentifier()); } break; @@ -1117,22 +1152,23 @@ private static org.apache.avro.Schema getFieldSchema( case ITERABLE: baseType = 
org.apache.avro.Schema.createArray( - getFieldSchema(fieldType.getCollectionElementType(), fieldName, namespace)); + getFieldSchema( + checkNotNull(fieldType.getCollectionElementType()), fieldName, namespace)); break; case MAP: - if (fieldType.getMapKeyType().getTypeName().isStringType()) { + if (checkNotNull(fieldType.getMapKeyType()).getTypeName().isStringType()) { // Avro only supports string keys in maps. baseType = org.apache.avro.Schema.createMap( - getFieldSchema(fieldType.getMapValueType(), fieldName, namespace)); + getFieldSchema(checkNotNull(fieldType.getMapValueType()), fieldName, namespace)); } else { throw new IllegalArgumentException("Avro only supports maps with string keys"); } break; case ROW: - baseType = toAvroSchema(fieldType.getRowSchema(), fieldName, namespace); + baseType = toAvroSchema(checkNotNull(fieldType.getRowSchema()), fieldName, namespace); break; default: @@ -1173,7 +1209,9 @@ private static org.apache.avro.Schema getFieldSchema( case DECIMAL: BigDecimal decimal = (BigDecimal) value; LogicalType logicalType = typeWithNullability.type.getLogicalType(); - return new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + @SuppressWarnings("nullness") + ByteBuffer result = new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + return result; case DATETIME: if (typeWithNullability.type.getType() == Type.INT) { @@ -1191,7 +1229,7 @@ private static org.apache.avro.Schema getFieldSchema( return ByteBuffer.wrap((byte[]) value); case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + String identifier = checkNotNull(fieldType.getLogicalType()).getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1199,9 +1237,11 @@ private static org.apache.avro.Schema getFieldSchema( if (byteArray.length != fixedBytesField.getSize()) { throw new IllegalArgumentException("Incorrectly sized byte array."); } - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, (byte[]) value, typeWithNullability.type); } else if (VariableBytes.IDENTIFIER.equals(identifier)) { - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, (byte[]) value, typeWithNullability.type); } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { @@ -1245,26 +1285,27 @@ private static org.apache.avro.Schema getFieldSchema( case ARRAY: case ITERABLE: Iterable iterable = (Iterable) value; - List translatedArray = Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); + List<@Nullable Object> translatedArray = + Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); for (Object arrayElement : iterable) { translatedArray.add( genericFromBeamField( - fieldType.getCollectionElementType(), + checkNotNull(fieldType.getCollectionElementType()), typeWithNullability.type.getElementType(), arrayElement)); } return translatedArray; case MAP: - Map map = Maps.newHashMap(); + Map map = Maps.newHashMap(); Map valueMap = (Map) value; for (Map.Entry entry : valueMap.entrySet()) { - Utf8 key = new Utf8((String) entry.getKey()); + Utf8 key = new Utf8((String) checkNotNull(entry.getKey())); map.put( key, genericFromBeamField( - fieldType.getMapValueType(), + checkNotNull(fieldType.getMapValueType()), 
typeWithNullability.type.getValueType(), entry.getValue())); } @@ -1288,8 +1329,8 @@ private static org.apache.avro.Schema getFieldSchema( * @return value converted for {@link Row} */ @SuppressWarnings("unchecked") - public static @Nullable Object convertAvroFieldStrict( - @Nullable Object value, + public static @PolyNull Object convertAvroFieldStrict( + @PolyNull Object value, @Nonnull org.apache.avro.Schema avroSchema, @Nonnull FieldType fieldType) { if (value == null) { @@ -1389,7 +1430,8 @@ private static Object convertBytesStrict(ByteBuffer bb, FieldType fieldType) { private static Object convertFixedStrict(GenericFixed fixed, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "fixed"); - checkArgument(FixedBytes.IDENTIFIER.equals(fieldType.getLogicalType().getIdentifier())); + checkArgument( + FixedBytes.IDENTIFIER.equals(checkNotNull(fieldType.getLogicalType()).getIdentifier())); return fixed.bytes().clone(); // clone because GenericFixed is mutable } @@ -1440,7 +1482,10 @@ private static Object convertBooleanStrict(Boolean value, FieldType fieldType) { private static Object convertEnumStrict(Object value, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "enum"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(EnumerationType.IDENTIFIER)); + checkArgument( + checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(EnumerationType.IDENTIFIER)); EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); return enumerationType.valueOf(value.toString()); } @@ -1448,7 +1493,8 @@ private static Object convertEnumStrict(Object value, FieldType fieldType) { private static Object convertUnionStrict( Object value, org.apache.avro.Schema unionAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "oneOfType"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(OneOfType.IDENTIFIER)); + checkArgument( + checkNotNull(fieldType.getLogicalType()).getIdentifier().equals(OneOfType.IDENTIFIER)); OneOfType oneOfType = fieldType.getLogicalType(OneOfType.class); int fieldNumber = GenericData.get().resolveUnion(unionAvroSchema, value); FieldType baseFieldType = oneOfType.getOneOfSchema().getField(fieldNumber).getType(); @@ -1465,7 +1511,7 @@ private static Object convertArrayStrict( FieldType elemFieldType = fieldType.getCollectionElementType(); for (Object value : values) { - ret.add(convertAvroFieldStrict(value, elemAvroSchema, elemFieldType)); + ret.add(convertAvroFieldStrict(value, elemAvroSchema, checkNotNull(elemFieldType))); } return ret; @@ -1476,10 +1522,10 @@ private static Object convertMapStrict( org.apache.avro.Schema valueAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.MAP, "map"); - checkNotNull(fieldType.getMapKeyType()); - checkNotNull(fieldType.getMapValueType()); + FieldType mapKeyType = checkNotNull(fieldType.getMapKeyType()); + FieldType mapValueType = checkNotNull(fieldType.getMapValueType()); - if (!fieldType.getMapKeyType().equals(FieldType.STRING)) { + if (!FieldType.STRING.equals(fieldType.getMapKeyType())) { throw new IllegalArgumentException( "Can't convert 'string' map keys to " + fieldType.getMapKeyType()); } @@ -1488,8 +1534,8 @@ private static Object convertMapStrict( for (Map.Entry value : values.entrySet()) { ret.put( - convertStringStrict(value.getKey(), fieldType.getMapKeyType()), - convertAvroFieldStrict(value.getValue(), valueAvroSchema, 
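The @PolyNull change above encodes "returns null only when the input is null"; a tiny hypothetical example of the same annotation:

import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.checker.nullness.qual.PolyNull;

public class PolyNullExample {

  // @PolyNull links the nullness of the return value to the nullness of the argument:
  // a non-null input yields a non-null result, a null input may yield null.
  static @PolyNull String normalize(@PolyNull String value) {
    if (value == null) {
      return null;
    }
    return value.trim().toLowerCase();
  }

  public static void main(String[] args) {
    String a = normalize("  Avro  "); // the checker treats this result as non-null
    @Nullable String b = normalize(null); // and this one as nullable
    System.out.println(a + " " + b);
  }
}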
fieldType.getMapValueType())); + convertStringStrict(value.getKey(), mapKeyType), + convertAvroFieldStrict(value.getValue(), valueAvroSchema, mapValueType)); } return ret; diff --git a/sdks/java/extensions/google-cloud-platform-core/build.gradle b/sdks/java/extensions/google-cloud-platform-core/build.gradle index 6cb8d3248ac1..8d21df50006b 100644 --- a/sdks/java/extensions/google-cloud-platform-core/build.gradle +++ b/sdks/java/extensions/google-cloud-platform-core/build.gradle @@ -66,12 +66,10 @@ task integrationTestKms(type: Test) { group = "Verification" def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' def gcpTempRoot = project.findProperty('gcpTempRootKms') ?: 'gs://temp-storage-for-end-to-end-tests-cmek' - def gcpGrpcTempRoot = project.findProperty('gcpGrpcTempRoot') ?: 'gs://gcs-grpc-team-apache-beam-testing' def dataflowKmsKey = project.findProperty('dataflowKmsKey') ?: "projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test" systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ "--project=${gcpProject}", "--tempRoot=${gcpTempRoot}", - "--grpcTempRoot=${gcpGrpcTempRoot}", "--dataflowKmsKey=${dataflowKmsKey}", ]) diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java index 6f1e0e985c24..6477564f01a1 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java @@ -21,7 +21,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; import com.google.protobuf.ByteString; import java.io.IOException; @@ -35,10 +34,7 @@ import org.apache.beam.sdk.extensions.gcp.util.GcsUtil.CreateOptions; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.io.FileSystems; -import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.ExperimentalOptions; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.testing.UsesKms; @@ -99,8 +95,6 @@ public void testWriteAndReadGcsWithGrpc() throws IOException { "%s/GcsUtilIT-%tF-% writeGcsTextFile(gcsUtil, wrongFilename, testContent)); - // Write a test file in a bucket with gRPC enabled. - GcsGrpcOptions grpcOptions = options.as(GcsGrpcOptions.class); - assertNotNull(grpcOptions.getGrpcTempRoot()); - String tempLocationWithGrpc = grpcOptions.getGrpcTempRoot() + "/temp"; + String tempLocationWithGrpc = options.getTempRoot() + "/temp"; String filename = String.format(outputPattern, tempLocationWithGrpc, new Date()); writeGcsTextFile(gcsUtil, filename, testContent); @@ -132,15 +117,6 @@ public void testWriteAndReadGcsWithGrpc() throws IOException { gcsUtil.remove(Collections.singletonList(filename)); } - public interface GcsGrpcOptions extends PipelineOptions { - /** Get tempRoot in a gRPC-enabled bucket. */ - @Description("TempRoot in a gRPC-enabled bucket") - String getGrpcTempRoot(); - - /** Set the tempRoot in a gRPC-enabled bucket. 
*/ - void setGrpcTempRoot(String grpcTempRoot); - } - void writeGcsTextFile(GcsUtil gcsUtil, String filename, String content) throws IOException { GcsPath gcsPath = GcsPath.fromUri(filename); try (WritableByteChannel channel = diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index fcfc40403b43..9fe6162ec936 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -39,7 +39,6 @@ import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.Arrays; -import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -105,16 +104,14 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) class ProtoByteBuddyUtils { private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); private static final TypeDescriptor BYTE_STRING_TYPE_DESCRIPTOR = @@ -271,7 +268,7 @@ static class ProtoConvertType extends ConvertType { .build(); @Override - public Type convert(TypeDescriptor typeDescriptor) { + public Type convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.equals(BYTE_STRING_TYPE_DESCRIPTOR) || typeDescriptor.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return byte[].class; @@ -298,7 +295,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.equals(BYTE_STRING_TYPE_DESCRIPTOR) || type.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return new Compound( @@ -373,7 +370,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.isSubtypeOf(TypeDescriptor.of(ByteString.class))) { return new Compound( readValue, @@ -460,7 +457,7 @@ public TypeConversion createSetterConversions(StackManipulati // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = + private static final Map>> CACHED_GETTERS = Maps.newConcurrentMap(); /** @@ -468,35 +465,36 @@ public TypeConversion createSetterConversions(StackManipulati * *
<p>
    The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - Class clazz, + public static List> getGetters( + Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { Multimap methods = ReflectUtils.getMethodsMap(clazz); - return CACHED_GETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), - c -> { - List types = - fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); - return types.stream() - .map( - t -> - createGetter( - t, - typeConversionsFactory, - clazz, - methods, - schema.getField(t.getName()), - fieldValueTypeSupplier)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + ClassWithSchema.create(clazz, schema), + c -> { + List types = + fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); + return types.stream() + .map( + t -> + createGetter( + t, + typeConversionsFactory, + clazz, + methods, + schema.getField(t.getName()), + fieldValueTypeSupplier)) + .collect(Collectors.toList()); + }); } - static FieldValueGetter createOneOfGetter( + static FieldValueGetter<@NonNull ProtoT, OneOfType.Value> createOneOfGetter( FieldValueTypeInformation typeInformation, - TreeMap> getterMethodMap, - Class protoClass, + TreeMap> getterMethodMap, + Class protoClass, OneOfType oneOfType, Method getCaseMethod) { Set indices = getterMethodMap.keySet(); @@ -506,7 +504,7 @@ static FieldValueGetter createOneOfGetter( int[] keys = getterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface(BYTE_BUDDY, protoClass, OneOfType.Value.class); builder = builder @@ -515,7 +513,8 @@ static FieldValueGetter createOneOfGetter( .method(ElementMatchers.named("get")) .intercept(new OneOfGetterInstruction(contiguous, keys, getCaseMethod)); - List getters = Lists.newArrayList(getterMethodMap.values()); + List> getters = + Lists.newArrayList(getterMethodMap.values()); builder = builder // Store a field with the list of individual getters. 
The get() instruction will pick @@ -557,12 +556,12 @@ static FieldValueGetter createOneOfGetter( FieldValueSetter createOneOfSetter( String name, TreeMap> setterMethodMap, - Class protoBuilderClass) { + Class protoBuilderClass) { Set indices = setterMethodMap.keySet(); boolean contiguous = isContiguous(indices); int[] keys = setterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, protoBuilderClass, OneOfType.Value.class); builder = @@ -586,7 +585,8 @@ FieldValueSetter createOneOfSetter( .withParameters(List.class) .intercept(new OneOfSetterConstructor()); - List setters = Lists.newArrayList(setterMethodMap.values()); + List> setters = + Lists.newArrayList(setterMethodMap.values()); try { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) @@ -948,10 +948,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { } } - private static FieldValueGetter createGetter( + private static FieldValueGetter<@NonNull ProtoT, ?> createGetter( FieldValueTypeInformation fieldValueTypeInformation, TypeConversionsFactory typeConversionsFactory, - Class clazz, + Class clazz, Multimap methods, Field field, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -965,21 +965,23 @@ private static FieldValueGetter createGetter( field.getName() + "_case", FieldType.logicalType(oneOfType.getCaseEnumType())); // Create a map of case enum value to getter. This must be sorted, so store in a TreeMap. - TreeMap> oneOfGetters = Maps.newTreeMap(); + TreeMap> oneOfGetters = + Maps.newTreeMap(); Map oneOfFieldTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), oneOfType.getOneOfSchema()).stream() .collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f)); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { int protoFieldIndex = getFieldNumber(oneOfField); - FieldValueGetter oneOfFieldGetter = + FieldValueGetter<@NonNull ProtoT, ?> oneOfFieldGetter = createGetter( - oneOfFieldTypes.get(oneOfField.getName()), + Verify.verifyNotNull(oneOfFieldTypes.get(oneOfField.getName())), typeConversionsFactory, clazz, methods, oneOfField, fieldValueTypeSupplier); - oneOfGetters.put(protoFieldIndex, oneOfFieldGetter); + oneOfGetters.put( + protoFieldIndex, (FieldValueGetter<@NonNull ProtoT, OneOfType.Value>) oneOfFieldGetter); } return createOneOfGetter( fieldValueTypeInformation, oneOfGetters, clazz, oneOfType, caseMethod); @@ -988,10 +990,11 @@ private static FieldValueGetter createGetter( } } - private static Class getProtoGeneratedBuilder(Class clazz) { + private static @Nullable Class getProtoGeneratedBuilder( + Class clazz) { String builderClassName = clazz.getName() + "$Builder"; try { - return Class.forName(builderClassName); + return (Class) Class.forName(builderClassName); } catch (ClassNotFoundException e) { return null; } @@ -1019,27 +1022,33 @@ static Method getProtoGetter(Multimap methods, String name, Fiel public static @Nullable SchemaUserTypeCreator getBuilderCreator( - Class protoClass, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { - Class builderClass = getProtoGeneratedBuilder(protoClass); + TypeDescriptor protoTypeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { + Class builderClass = getProtoGeneratedBuilder(protoTypeDescriptor.getRawType()); if (builderClass == null) { return null; } Multimap methods = 
ReflectUtils.getMethodsMap(builderClass); List> setters = schema.getFields().stream() - .map(f -> getProtoFieldValueSetter(f, methods, builderClass)) + .map(f -> getProtoFieldValueSetter(protoTypeDescriptor, f, methods, builderClass)) .collect(Collectors.toList()); - return createBuilderCreator(protoClass, builderClass, setters, schema); + return createBuilderCreator(protoTypeDescriptor.getRawType(), builderClass, setters, schema); } private static FieldValueSetter getProtoFieldValueSetter( - Field field, Multimap methods, Class builderClass) { + TypeDescriptor typeDescriptor, + Field field, + Multimap methods, + Class builderClass) { if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); TreeMap> oneOfSetters = Maps.newTreeMap(); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { - FieldValueSetter setter = getProtoFieldValueSetter(oneOfField, methods, builderClass); + FieldValueSetter setter = + getProtoFieldValueSetter(typeDescriptor, oneOfField, methods, builderClass); oneOfSetters.put(getFieldNumber(oneOfField), setter); } return createOneOfSetter(field.getName(), oneOfSetters, builderClass); @@ -1047,24 +1056,25 @@ FieldValueSetter getProtoFieldValueSetter( Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( FieldValueTypeInformation.forSetter( - method, protoSetterPrefix(field.getType()), Collections.emptyMap()), + typeDescriptor, method, protoSetterPrefix(field.getType())), new ProtoTypeConversionsFactory()); } } static SchemaUserTypeCreator createBuilderCreator( Class protoClass, - Class builderClass, + Class builderClass, List> setters, Schema schema) { try { - DynamicType.Builder builder = - BYTE_BUDDY - .with(new InjectPackageStrategy(builderClass)) - .subclass(Supplier.class) - .method(ElementMatchers.named("get")) - .intercept(new BuilderSupplier(protoClass)); - Supplier supplier = + DynamicType.Builder> builder = + (DynamicType.Builder) + BYTE_BUDDY + .with(new InjectPackageStrategy(builderClass)) + .subclass(Supplier.class) + .method(ElementMatchers.named("get")) + .intercept(new BuilderSupplier(protoClass)); + Supplier supplier = builder .visit( new AsmVisitorWrapper.ForDeclaredMethods() diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index 4b8d51abdea6..b0bb9071524b 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -23,7 +23,6 @@ import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; import java.lang.reflect.Method; -import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.beam.sdk.extensions.protobuf.ProtoByteBuddyUtils.ProtoTypeConversionsFactory; @@ -44,12 +43,9 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "nullness", // 
TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) public class ProtoMessageSchema extends GetterBasedSchemaProviderV2 { private static final class ProtoClassFieldValueTypeSupplier implements FieldValueTypeSupplier { @@ -73,7 +69,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) .withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. @@ -85,7 +81,7 @@ public List get(TypeDescriptor typeDescriptor, Sch // This is a simple field. Add the getter. Method method = getProtoGetter(methods, field.getName(), field.getType()); types.add( - FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) .withName(field.getName())); } } @@ -100,8 +96,8 @@ public List get(TypeDescriptor typeDescriptor, Sch } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return ProtoByteBuddyUtils.getGetters( targetTypeDescriptor.getRawType(), schema, @@ -121,7 +117,7 @@ public SchemaUserTypeCreator schemaTypeCreator( TypeDescriptor targetTypeDescriptor, Schema schema) { SchemaUserTypeCreator creator = ProtoByteBuddyUtils.getBuilderCreator( - targetTypeDescriptor.getRawType(), schema, new ProtoClassFieldValueTypeSupplier()); + targetTypeDescriptor, schema, new ProtoClassFieldValueTypeSupplier()); if (creator == null) { throw new RuntimeException("Cannot create creator for " + targetTypeDescriptor); } @@ -156,7 +152,8 @@ public static SimpleFunction getRowToProtoBytesFn(Class claz private void checkForDynamicType(TypeDescriptor typeDescriptor) { if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { throw new RuntimeException( - "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoDynamicMessageSchema instead."); + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use" + + " ProtoDynamicMessageSchema instead."); } } diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java index 64f600903d87..d5f1745a9a2c 100644 --- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java +++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java @@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -390,8 +389,7 @@ private Schema generateSchemaDirectly( fieldName, StaticSchemaInference.fieldFromType( TypeDescriptor.of(field.getClass()), - JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap())); + JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE)); } counter++; diff --git a/sdks/java/extensions/sql/expansion-service/build.gradle b/sdks/java/extensions/sql/expansion-service/build.gradle index b6963cf7547b..b8d78e4e1bb9 100644 --- 
a/sdks/java/extensions/sql/expansion-service/build.gradle +++ b/sdks/java/extensions/sql/expansion-service/build.gradle @@ -46,3 +46,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } + +shadowJar { + outputs.upToDateWhen { false } +} \ No newline at end of file diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java index 372e77c54c67..d0f6427a262e 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java @@ -36,7 +36,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.MockConsumer; @@ -138,10 +137,6 @@ public synchronized void assign(final Collection assigned) { .collect(Collectors.toList()); super.assign(realPartitions); assignedPartitions.set(ImmutableList.copyOf(realPartitions)); - for (TopicPartition tp : realPartitions) { - updateBeginningOffsets(ImmutableMap.of(tp, 0L)); - updateEndOffsets(ImmutableMap.of(tp, (long) kafkaRecords.get(tp).size())); - } } // Override offsetsForTimes() in order to look up the offsets by timestamp. 
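The KafkaTestTable change that continues below stops re-priming offsets inside assign() and instead registers partition metadata plus beginning/end offsets once, for every known partition. A minimal sketch of that one-time priming pattern, using a hypothetical topic, partition and end offset purely for illustration, could look like this:

import java.util.Collections;
import org.apache.kafka.clients.consumer.MockConsumer;
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;

public class MockConsumerPrimingSketch {
  public static MockConsumer<byte[], byte[]> primedConsumer() {
    // Hypothetical topic/partition used only for this sketch.
    String topic = "test-topic";
    TopicPartition tp = new TopicPartition(topic, 0);

    MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
    // Register partition metadata once, instead of on every assign() call.
    consumer.updatePartitions(
        topic,
        Collections.singletonList(
            new PartitionInfo(topic, 0, Node.noNode(), new Node[0], new Node[0])));
    // Prime beginning and end offsets for all partitions in one shot.
    consumer.updateBeginningOffsets(Collections.singletonMap(tp, 0L));
    consumer.updateEndOffsets(Collections.singletonMap(tp, 100L));
    return consumer;
  }
}

Priming the MockConsumer up front keeps assign() free of per-call side effects, which is the shape the replacement code below follows.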
@Override @@ -163,9 +158,12 @@ public synchronized Map offsetsForTimes( } }; - for (String topic : getTopics()) { - consumer.updatePartitions(topic, partitionInfoMap.get(topic)); - } + partitionInfoMap.forEach(consumer::updatePartitions); + consumer.updateBeginningOffsets( + kafkaRecords.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L))); + consumer.updateEndOffsets( + kafkaRecords.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()))); Runnable recordEnqueueTask = new Runnable() { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index c91d5ba71b89..0d517503b12d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -64,7 +64,6 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; -import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.Coder; @@ -93,6 +92,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; @@ -108,20 +108,19 @@ import org.slf4j.LoggerFactory; /** - * Processes {@link BeamFnApi.ProcessBundleRequest}s and {@link - * BeamFnApi.ProcessBundleSplitRequest}s. + * Processes {@link ProcessBundleRequest}s and {@link BeamFnApi.ProcessBundleSplitRequest}s. * *
<p>
    {@link BeamFnApi.ProcessBundleSplitRequest}s use a {@link BundleProcessorCache cache} to * find/create a {@link BundleProcessor}. The creation of a {@link BundleProcessor} uses the - * associated {@link BeamFnApi.ProcessBundleDescriptor} definition; creating runners for each {@link + * associated {@link ProcessBundleDescriptor} definition; creating runners for each {@link * RunnerApi.FunctionSpec}; wiring them together based upon the {@code input} and {@code output} map * definitions. The {@link BundleProcessor} executes the DAG based graph by starting all runners in * reverse topological order, and finishing all runners in forward topological order. * *
<p>
    {@link BeamFnApi.ProcessBundleSplitRequest}s finds an {@code active} {@link BundleProcessor} - * associated with a currently processing {@link BeamFnApi.ProcessBundleRequest} and uses it to - * perform a split request. See breaking the - * fusion barrier for further details. + * associated with a currently processing {@link ProcessBundleRequest} and uses it to perform a + * split request. See breaking the fusion + * barrier for further details. */ @SuppressWarnings({ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) @@ -153,7 +152,7 @@ public class ProcessBundleHandler { } private final PipelineOptions options; - private final Function fnApiRegistry; + private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; private final FinalizeBundleHandler finalizeBundleHandler; @@ -170,7 +169,7 @@ public class ProcessBundleHandler { public ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -197,7 +196,7 @@ public ProcessBundleHandler( ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -216,7 +215,7 @@ public ProcessBundleHandler( this.runnerCapabilities = runnerCapabilities; this.runnerAcceptsShortIds = runnerCapabilities.contains( - BeamUrns.getUrn(RunnerApi.StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); + BeamUrns.getUrn(StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); this.executionStateSampler = executionStateSampler; this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = @@ -232,7 +231,7 @@ private void createRunnerAndConsumersForPTransformRecursively( String pTransformId, PTransform pTransform, Supplier processBundleInstructionId, - Supplier> cacheTokens, + Supplier> cacheTokens, Supplier> bundleCache, ProcessBundleDescriptor processBundleDescriptor, SetMultimap pCollectionIdsToConsumingPTransforms, @@ -242,7 +241,7 @@ private void createRunnerAndConsumersForPTransformRecursively( PTransformFunctionRegistry finishFunctionRegistry, Consumer addResetFunction, Consumer addTearDownFunction, - BiConsumer> addDataEndpoint, + BiConsumer> addDataEndpoint, Consumer> addTimerEndpoint, Consumer addBundleProgressReporter, BundleSplitListener splitListener, @@ -499,28 +498,29 @@ public BundleFinalizer getBundleFinalizer() { * Processes a bundle, running the start(), process(), and finish() functions. This function is * required to be reentrant. 
*/ - public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest request) throws Exception { - BeamFnApi.ProcessBundleResponse.Builder response = BeamFnApi.ProcessBundleResponse.newBuilder(); - - BundleProcessor bundleProcessor = - bundleProcessorCache.get( - request, - () -> { - try { - return createBundleProcessor( - request.getProcessBundle().getProcessBundleDescriptorId(), - request.getProcessBundle()); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); + @Nullable BundleProcessor bundleProcessor = null; try { + bundleProcessor = + Preconditions.checkNotNull( + bundleProcessorCache.get( + request, + () -> { + try { + return createBundleProcessor( + request.getProcessBundle().getProcessBundleDescriptorId(), + request.getProcessBundle()); + } catch (IOException e) { + throw new RuntimeException(e); + } + })); + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry(); ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); - + ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { stateTracker.start(request.getInstructionId()); try { @@ -596,12 +596,17 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction request.getProcessBundle().getProcessBundleDescriptorId(), bundleProcessor); return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response); } catch (Exception e) { - // Make sure we clean up from the active set of bundle processors. LOG.debug( - "Discard bundleProcessor for {} after exception: {}", + "Error processing bundle {} with bundleProcessor for {} after exception: {}", + request.getInstructionId(), request.getProcessBundle().getProcessBundleDescriptorId(), e.getMessage()); - bundleProcessorCache.discard(bundleProcessor); + if (bundleProcessor != null) { + // Make sure we clean up from the active set of bundle processors. + bundleProcessorCache.discard(bundleProcessor); + } + // Ensure that if more data arrives for the instruction it is discarded. + beamFnDataClient.poisonInstructionId(request.getInstructionId()); throw e; } } @@ -643,7 +648,7 @@ private void embedOutboundElementsIfApplicable( } } - public BeamFnApi.InstructionResponse.Builder progress(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder progress(InstructionRequest request) throws Exception { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleProgress().getInstructionId()); @@ -727,7 +732,7 @@ private Map finalMonitoringData(BundleProcessor bundleProces } /** Splits an active bundle. 
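The processBundle() rewrite above moves bundle-processor acquisition inside the try block and, on any exception, discards the processor (if one was obtained) and poisons the instruction id so that late-arriving data for the bundle is dropped. A self-contained sketch of that failure-handling pattern, using hypothetical stand-in types rather than the real Beam classes, is:

import java.util.function.Supplier;

class PoisonOnFailureSketch {
  // Hypothetical stand-ins for the handler's collaborators.
  interface Processor {
    String respond();
  }

  interface ProcessorCache {
    Processor get(String instructionId, Supplier<Processor> factory);

    void discard(Processor processor);
  }

  interface DataClient {
    void poisonInstructionId(String instructionId);
  }

  private final ProcessorCache cache;
  private final DataClient dataClient;

  PoisonOnFailureSketch(ProcessorCache cache, DataClient dataClient) {
    this.cache = cache;
    this.dataClient = dataClient;
  }

  String processBundle(String instructionId, Supplier<Processor> factory) {
    Processor processor = null;
    try {
      // Acquire (or lazily create) the processor inside the try block so that a
      // failure during creation also reaches the cleanup path below.
      processor = cache.get(instructionId, factory);
      return processor.respond();
    } catch (RuntimeException e) {
      if (processor != null) {
        // Remove it from the active set of processors.
        cache.discard(processor);
      }
      // Ensure that data still arriving for this instruction id is discarded.
      dataClient.poisonInstructionId(instructionId);
      throw e;
    }
  }
}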
*/ - public BeamFnApi.InstructionResponse.Builder trySplit(BeamFnApi.InstructionRequest request) { + public BeamFnApi.InstructionResponse.Builder trySplit(InstructionRequest request) { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleSplit().getInstructionId()); BeamFnApi.ProcessBundleSplitResponse.Builder response = @@ -772,8 +777,8 @@ public void discard() { } private BundleProcessor createBundleProcessor( - String bundleId, BeamFnApi.ProcessBundleRequest processBundleRequest) throws IOException { - BeamFnApi.ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); + String bundleId, ProcessBundleRequest processBundleRequest) throws IOException { + ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); SetMultimap pCollectionIdsToConsumingPTransforms = HashMultimap.create(); BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar = @@ -799,8 +804,7 @@ private BundleProcessor createBundleProcessor( List tearDownFunctions = new ArrayList<>(); // Build a multimap of PCollection ids to PTransform ids which consume said PCollections - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { for (String pCollectionId : entry.getValue().getInputsMap().values()) { pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey()); } @@ -848,8 +852,7 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) { runnerCapabilities); // Create a BeamFnStateClient - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { // Skip anything which isn't a root. // Also force data output transforms to be unconditionally instantiated (see BEAM-10450). @@ -1090,7 +1093,7 @@ public static BundleProcessor create( abstract HandleStateCallsForBundle getBeamFnStateClient(); - abstract List getInboundEndpointApiServiceDescriptors(); + abstract List getInboundEndpointApiServiceDescriptors(); abstract List> getInboundDataEndpoints(); @@ -1117,7 +1120,7 @@ synchronized List getCacheTokens() { synchronized Cache getBundleCache() { if (this.bundleCache == null) { this.bundleCache = - new Caches.ClearableCache<>( + new ClearableCache<>( Caches.subCache(getProcessWideCache(), "Bundle", this.instructionId)); } return this.bundleCache; @@ -1264,7 +1267,7 @@ public void close() throws Exception { } @Override - public CompletableFuture handle(BeamFnApi.StateRequest.Builder requestBuilder) { + public CompletableFuture handle(StateRequest.Builder requestBuilder) { throw new IllegalStateException( String.format( "State API calls are unsupported because the " diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java index 75f3a24301c9..94d59d0fcb62 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java @@ -55,10 +55,19 @@ void registerReceiver( * successfully. * *
<p>
    It is expected that if a bundle fails during processing then the failure will become visible - * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation. + * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation or via + * a call to {@link #poisonInstructionId}. */ void unregisterReceiver(String instructionId, List apiServiceDescriptors); + /** + * Poisons the instruction id, indicating that future data arriving for it should be discarded. + * Unregisters the receiver if was registered. + * + * @param instructionId + */ + void poisonInstructionId(String instructionId); + /** * Creates a {@link BeamFnDataOutboundAggregator} for buffering and sending outbound data and * timers over the data plane. It is important that {@link diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 981b115c58e7..cd1ac26e364d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -82,6 +82,14 @@ public void unregisterReceiver( } } + @Override + public void poisonInstructionId(String instructionId) { + LOG.debug("Poisoning instruction {}", instructionId); + for (BeamFnDataGrpcMultiplexer client : multiplexerCache.values()) { + client.poisonInstructionId(instructionId); + } + } + @Override public BeamFnDataOutboundAggregator createOutboundAggregator( ApiServiceDescriptor apiServiceDescriptor, diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 3b9fccfa2a5e..81a2aa6d1cc6 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -105,7 +105,7 @@ public long getWeight() { // many different state subcaches. return 0; } - }; + } /** A mutable iterable that supports prefetch and is backed by a cache. */ static class CachingStateIterable extends PrefetchableIterables.Default { @@ -138,8 +138,8 @@ public long getWeight() { private static long sumWeight(List> blocks) { try { long sum = 0; - for (int i = 0; i < blocks.size(); ++i) { - sum = Math.addExact(sum, blocks.get(i).getWeight()); + for (Block block : blocks) { + sum = Math.addExact(sum, block.getWeight()); } return sum; } catch (ArithmeticException e) { @@ -437,50 +437,59 @@ public boolean hasNext() { if (currentBlock.getValues().size() > currentCachedBlockValueIndex) { return true; } - if (currentBlock.getNextToken() == null) { + final ByteString nextToken = currentBlock.getNextToken(); + if (nextToken == null) { return false; } - Blocks existing = cache.peek(IterableCacheKey.INSTANCE); - boolean isFirstBlock = ByteString.EMPTY.equals(currentBlock.getNextToken()); + // Release the block while we are loading the next one. 
+ currentBlock = + Block.fromValues(new WeightedList<>(Collections.emptyList(), 0L), ByteString.EMPTY); + + @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); + boolean isFirstBlock = ByteString.EMPTY.equals(nextToken); if (existing == null) { - currentBlock = loadNextBlock(currentBlock.getNextToken()); + currentBlock = loadNextBlock(nextToken); if (isFirstBlock) { cache.put( IterableCacheKey.INSTANCE, new BlocksPrefix<>(Collections.singletonList(currentBlock))); } + } else if (isFirstBlock) { + currentBlock = existing.getBlocks().get(0); } else { - if (isFirstBlock) { - currentBlock = existing.getBlocks().get(0); - } else { - checkState( - existing instanceof BlocksPrefix, - "Unexpected blocks type %s, expected a %s.", - existing.getClass(), - BlocksPrefix.class); - List> blocks = existing.getBlocks(); - int currentBlockIndex = 0; - for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { - if (currentBlock - .getNextToken() - .equals(blocks.get(currentBlockIndex).getNextToken())) { - break; - } + checkState( + existing instanceof BlocksPrefix, + "Unexpected blocks type %s, expected a %s.", + existing.getClass(), + BlocksPrefix.class); + List> blocks = existing.getBlocks(); + int currentBlockIndex = 0; + for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { + if (nextToken.equals(blocks.get(currentBlockIndex).getNextToken())) { + break; } - // Load the next block from cache if it was found. - if (currentBlockIndex + 1 < blocks.size()) { - currentBlock = blocks.get(currentBlockIndex + 1); - } else { - // Otherwise load the block from state API. - currentBlock = loadNextBlock(currentBlock.getNextToken()); - - // Append this block to the existing set of blocks if it is logically the next one. - if (currentBlockIndex == blocks.size() - 1) { - List> newBlocks = new ArrayList<>(currentBlockIndex + 1); - newBlocks.addAll(blocks); - newBlocks.add(currentBlock); - cache.put(IterableCacheKey.INSTANCE, new BlocksPrefix<>(newBlocks)); - } + } + // Take the next block from the cache if it was found. + if (currentBlockIndex + 1 < blocks.size()) { + currentBlock = blocks.get(currentBlockIndex + 1); + } else { + // Otherwise load the block from state API. + // Remove references on the cached values while we are loading the next block. + existing = null; + blocks = null; + currentBlock = loadNextBlock(nextToken); + existing = cache.peek(IterableCacheKey.INSTANCE); + // Append this block to the existing set of blocks if it is logically the next one + // according to the + // tokens. 
+ if (existing != null + && !existing.getBlocks().isEmpty() + && nextToken.equals( + existing.getBlocks().get(existing.getBlocks().size() - 1).getNextToken())) { + List> newBlocks = new ArrayList<>(currentBlockIndex + 1); + newBlocks.addAll(existing.getBlocks()); + newBlocks.add(currentBlock); + cache.put(IterableCacheKey.INSTANCE, new BlocksPrefix<>(newBlocks)); } } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java index 9328dc86c009..acfd3bb70202 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java @@ -92,6 +92,11 @@ public BeamFnDataOutboundAggregator createOutboundAggregator( boolean collectElementsIfNoFlushes) { throw new UnsupportedOperationException("Unexpected call during test."); } + + @Override + public void poisonInstructionId(String instructionId) { + throw new UnsupportedOperationException("Unexpected call during test."); + } }) .beamFnStateClient( new BeamFnStateClient() { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 2d1e323707f7..95b404aa6203 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -1516,6 +1516,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { // Ensure that we unregister during successful processing verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); + verify(beamFnDataClient).poisonInstructionId(eq("instructionId")); verifyNoMoreInteractions(beamFnDataClient); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 3489fe766891..514cf61ded40 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -23,14 +23,17 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -281,6 +284,93 @@ public StreamObserver data( } } + @Test + public void testForInboundConsumerThatIsPoisoned() throws Exception { + CountDownLatch waitForClientToConnect = new CountDownLatch(1); + CountDownLatch receivedAElement = new CountDownLatch(1); + Collection> inboundValuesA = new ConcurrentLinkedQueue<>(); + Collection inboundServerValues = new ConcurrentLinkedQueue<>(); + 
AtomicReference> outboundServerObserver = + new AtomicReference<>(); + CallStreamObserver inboundServerObserver = + TestStreams.withOnNext(inboundServerValues::add).build(); + + Endpoints.ApiServiceDescriptor apiServiceDescriptor = + Endpoints.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID()) + .build(); + Server server = + InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService( + new BeamFnDataGrpc.BeamFnDataImplBase() { + @Override + public StreamObserver data( + StreamObserver outboundObserver) { + outboundServerObserver.set(outboundObserver); + waitForClientToConnect.countDown(); + return inboundServerObserver; + } + }) + .build(); + server.start(); + + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnDataGrpcClient clientFactory = + new BeamFnDataGrpcClient( + PipelineOptionsFactory.create(), + (Endpoints.ApiServiceDescriptor descriptor) -> channel, + OutboundObserverFactory.trivial()); + + BeamFnDataInboundObserver observerA = + BeamFnDataInboundObserver.forConsumers( + Arrays.asList( + DataEndpoint.create( + TRANSFORM_ID_A, + CODER, + (WindowedValue elem) -> { + receivedAElement.countDown(); + inboundValuesA.add(elem); + })), + Collections.emptyList()); + CompletableFuture future = + CompletableFuture.runAsync( + () -> { + try { + observerA.awaitCompletion(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + clientFactory.registerReceiver( + INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + + waitForClientToConnect.await(); + outboundServerObserver.get().onNext(ELEMENTS_B_1); + clientFactory.poisonInstructionId(INSTRUCTION_ID_B); + + outboundServerObserver.get().onNext(ELEMENTS_B_1); + outboundServerObserver.get().onNext(ELEMENTS_A_1); + assertTrue(receivedAElement.await(5, TimeUnit.SECONDS)); + + clientFactory.poisonInstructionId(INSTRUCTION_ID_A); + try { + future.get(); + fail(); // We expect the awaitCompletion to fail due to closing. 
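The new test above pins down the client-side contract of poisoning: data arriving for a poisoned instruction id is silently discarded, while data for other instruction ids is still delivered. A tiny, self-contained model of that bookkeeping (just the pattern, not Beam's multiplexer) might be:

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

class PoisonableDispatcherSketch<T> {
  private final Map<String, Consumer<T>> receivers = new ConcurrentHashMap<>();
  private final Set<String> poisoned = ConcurrentHashMap.newKeySet();

  void register(String instructionId, Consumer<T> receiver) {
    receivers.put(instructionId, receiver);
  }

  void poison(String instructionId) {
    // Mark the id so future data is discarded, and drop any registered receiver.
    poisoned.add(instructionId);
    receivers.remove(instructionId);
  }

  void dispatch(String instructionId, T element) {
    if (poisoned.contains(instructionId)) {
      return; // data for poisoned instructions is silently dropped
    }
    Consumer<T> receiver = receivers.get(instructionId);
    if (receiver != null) {
      receiver.accept(element);
    }
  }
}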
+ } catch (Exception ignored) { + } + + outboundServerObserver.get().onNext(ELEMENTS_A_2); + + assertThat(inboundValuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + } finally { + server.shutdownNow(); + } + } + @Test public void testForOutboundConsumer() throws Exception { CountDownLatch waitForInboundServerValuesCompletion = new CountDownLatch(2); diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java index acdfcfc1ad09..e8b05a8a319e 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java @@ -19,7 +19,6 @@ import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; -import static org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.getter; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets.difference; @@ -46,6 +45,7 @@ import org.apache.beam.sdk.values.RowWithGetters; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; @@ -73,17 +73,20 @@ public class AwsSchemaProvider extends GetterBasedSchemaProviderV2 { return AwsTypes.schemaFor(sdkFields((Class) type.getRawType())); } - @SuppressWarnings("rawtypes") @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { ConverterFactory fromAws = ConverterFactory.fromAws(); Map> sdkFields = sdkFieldsByName((Class) targetTypeDescriptor.getRawType()); - List getters = new ArrayList<>(schema.getFieldCount()); - for (String field : schema.getFieldNames()) { + List> getters = new ArrayList<>(schema.getFieldCount()); + for (@NonNull String field : schema.getFieldNames()) { SdkField sdkField = checkStateNotNull(sdkFields.get(field), "Unknown field"); - getters.add(getter(field, fromAws.create(sdkField::getValueOrDefault, sdkField))); + getters.add( + AwsSchemaUtils.getter( + field, + (SerializableFunction<@NonNull T, Object>) + fromAws.create(sdkField::getValueOrDefault, sdkField))); } return getters; } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java index d36c197d80a4..9e994702fe61 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.checkerframework.checker.nullness.qual.NonNull; import 
org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.utils.builder.SdkBuilder; @@ -78,7 +79,7 @@ static SdkBuilderSetter setter(String name, BiConsumer, Object> return new ValueSetter(name, setter); } - static FieldValueGetter getter( + static FieldValueGetter getter( String name, SerializableFunction getter) { return new ValueGetter<>(name, getter); } @@ -107,7 +108,8 @@ public String name() { } } - private static class ValueGetter implements FieldValueGetter { + private static class ValueGetter + implements FieldValueGetter { private final SerializableFunction getter; private final String name; diff --git a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java index 2bfa694aebc0..970d9483850c 100644 --- a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java +++ b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java @@ -56,7 +56,7 @@ public class DebeziumIOPostgresSqlConnectorIT { @ClassRule public static final PostgreSQLContainer POSTGRES_SQL_CONTAINER = new PostgreSQLContainer<>( - DockerImageName.parse("debezium/example-postgres:latest") + DockerImageName.parse("quay.io/debezium/example-postgres:latest") .asCompatibleSubstituteFor("postgres")) .withPassword("dbz") .withUsername("debezium") diff --git a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java index c75621040913..c4b5d2d1f890 100644 --- a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java +++ b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java @@ -46,7 +46,7 @@ public class DebeziumReadSchemaTransformTest { @ClassRule public static final PostgreSQLContainer POSTGRES_SQL_CONTAINER = new PostgreSQLContainer<>( - DockerImageName.parse("debezium/example-postgres:latest") + DockerImageName.parse("quay.io/debezium/example-postgres:latest") .asCompatibleSubstituteFor("postgres")) .withPassword("dbz") .withUsername("debezium") diff --git a/sdks/java/io/expansion-service/build.gradle b/sdks/java/io/expansion-service/build.gradle index 8b817163ae39..cc8eccf98997 100644 --- a/sdks/java/io/expansion-service/build.gradle +++ b/sdks/java/io/expansion-service/build.gradle @@ -35,6 +35,7 @@ configurations.runtimeClasspath { shadowJar { mergeServiceFiles() + outputs.upToDateWhen { false } } description = "Apache Beam :: SDKs :: Java :: IO :: Expansion Service" diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index 3e322d976c1a..2acce3e94cc2 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ b/sdks/java/io/google-cloud-platform/build.gradle @@ -159,6 +159,7 @@ dependencies { testImplementation project(path: ":sdks:java:extensions:google-cloud-platform-core", configuration: "testRuntimeMigration") testImplementation project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntimeMigration") testImplementation project(path: ":runners:direct-java", configuration: "shadow") + testImplementation project(":sdks:java:managed") testImplementation project(path: ":sdks:java:io:common") 
testImplementation project(path: ":sdks:java:testing:test-utils") testImplementation library.java.commons_math3 diff --git a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle index 1288d91964e1..f6c6f07d0cdf 100644 --- a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle +++ b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle @@ -36,6 +36,9 @@ dependencies { permitUnusedDeclared project(":sdks:java:io:google-cloud-platform") // BEAM-11761 implementation project(":sdks:java:extensions:schemaio-expansion-service") permitUnusedDeclared project(":sdks:java:extensions:schemaio-expansion-service") // BEAM-11761 + implementation project(":sdks:java:managed") + permitUnusedDeclared project(":sdks:java:managed") // BEAM-11761 + runtimeOnly library.java.slf4j_jdk14 } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java index 7a5aa2408d2e..d7ca787feea3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java @@ -31,7 +31,6 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -221,11 +220,18 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { case ITERABLE: @Nullable FieldType elementType = field.getType().getCollectionElementType(); if (elementType == null) { - throw new RuntimeException("Unexpected null element type!"); + throw new RuntimeException("Unexpected null element type on " + field.getName()); } + TypeName containedTypeName = + Preconditions.checkNotNull( + elementType.getTypeName(), + "Null type name found in contained type at " + field.getName()); Preconditions.checkState( - !Preconditions.checkNotNull(elementType.getTypeName()).isCollectionType(), - "Nested arrays not supported by BigQuery."); + !(containedTypeName.isCollectionType() || containedTypeName.isMapType()), + "Nested container types are not supported by BigQuery. 
Field " + + field.getName() + + " contains a type " + + containedTypeName.name()); TableFieldSchema elementFieldSchema = fieldDescriptorFromBeamField(Field.of(field.getName(), elementType)); builder = builder.setType(elementFieldSchema.getType()); @@ -244,7 +250,24 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { builder = builder.setType(type); break; case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + @Nullable FieldType keyType = field.getType().getMapKeyType(); + @Nullable FieldType valueType = field.getType().getMapValueType(); + if (keyType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's key on " + field.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's value on " + field.getName()); + } + + builder = + builder + .setType(TableFieldSchema.Type.STRUCT) + .addFields(fieldDescriptorFromBeamField(Field.of("key", keyType))) + .addFields(fieldDescriptorFromBeamField(Field.of("value", valueType))) + .setMode(TableFieldSchema.Mode.REPEATED); + break; default: @Nullable TableFieldSchema.Type primitiveType = PRIMITIVE_TYPES.get(field.getType().getTypeName()); @@ -289,25 +312,34 @@ private static Object toProtoValue( case ROW: return messageFromBeamRow(fieldDescriptor.getMessageType(), (Row) value, null, -1); case ARRAY: - List list = (List) value; - @Nullable FieldType arrayElementType = beamFieldType.getCollectionElementType(); - if (arrayElementType == null) { - throw new RuntimeException("Unexpected null element type!"); - } - return list.stream() - .map(v -> toProtoValue(fieldDescriptor, arrayElementType, v)) - .collect(Collectors.toList()); case ITERABLE: Iterable iterable = (Iterable) value; @Nullable FieldType iterableElementType = beamFieldType.getCollectionElementType(); if (iterableElementType == null) { - throw new RuntimeException("Unexpected null element type!"); + throw new RuntimeException("Unexpected null element type: " + fieldDescriptor.getName()); } + return StreamSupport.stream(iterable.spliterator(), false) .map(v -> toProtoValue(fieldDescriptor, iterableElementType, v)) .collect(Collectors.toList()); case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + Map map = (Map) value; + @Nullable FieldType keyType = beamFieldType.getMapKeyType(); + @Nullable FieldType valueType = beamFieldType.getMapValueType(); + if (keyType == null) { + throw new RuntimeException("Unexpected null for key type: " + fieldDescriptor.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null for value type: " + fieldDescriptor.getName()); + } + + return map.entrySet().stream() + .map( + (Map.Entry entry) -> + mapEntryToProtoValue( + fieldDescriptor.getMessageType(), keyType, valueType, entry)) + .collect(Collectors.toList()); default: return scalarToProtoValue(beamFieldType, value); } @@ -337,6 +369,28 @@ static Object scalarToProtoValue(FieldType beamFieldType, Object value) { } } + static Object mapEntryToProtoValue( + Descriptor descriptor, + FieldType keyFieldType, + FieldType valueFieldType, + Map.Entry entryValue) { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + FieldDescriptor keyFieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName("key")); + @Nullable Object key = toProtoValue(keyFieldDescriptor, keyFieldType, entryValue.getKey()); + if (key != null) { + builder.setField(keyFieldDescriptor, key); + } + FieldDescriptor 
valueFieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName("value")); + @Nullable + Object value = toProtoValue(valueFieldDescriptor, valueFieldType, entryValue.getValue()); + if (value != null) { + builder.setField(valueFieldDescriptor, value); + } + return builder.build(); + } + static ByteString serializeBigDecimalToNumeric(BigDecimal o) { return serializeBigDecimal(o, NUMERIC_SCALE, MAX_NUMERIC_VALUE, MIN_NUMERIC_VALUE, "Numeric"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java deleted file mode 100644 index f634b5ec6f60..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import com.google.auto.value.AutoValue; -import org.apache.beam.sdk.schemas.AutoValueSchema; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; - -/** - * Configuration for writing to BigQuery. - * - *
<p>
    This class is meant to be used with {@link BigQueryFileLoadsWriteSchemaTransformProvider}. - * - *
<p>
    Internal only: This class is actively being worked on, and it will likely change. We - * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam - * repository. - */ -@DefaultSchema(AutoValueSchema.class) -@AutoValue -public abstract class BigQueryFileLoadsWriteSchemaTransformConfiguration { - - /** Instantiates a {@link BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder}. */ - public static Builder builder() { - return new AutoValue_BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder(); - } - - /** - * Writes to the given table specification. See {@link BigQueryIO.Write#to(String)}} for the - * expected format. - */ - public abstract String getTableSpec(); - - /** Specifies whether the table should be created if it does not exist. */ - public abstract String getCreateDisposition(); - - /** Specifies what to do with existing data in the table, in case the table already exists. */ - public abstract String getWriteDisposition(); - - @AutoValue.Builder - public abstract static class Builder { - - /** - * Writes to the given table specification. See {@link BigQueryIO.Write#to(String)}} for the - * expected format. - */ - public abstract Builder setTableSpec(String value); - - /** Specifies whether the table should be created if it does not exist. */ - public abstract Builder setCreateDisposition(String value); - - /** Specifies what to do with existing data in the table, in case the table already exists. */ - public abstract Builder setWriteDisposition(String value); - - /** Builds the {@link BigQueryFileLoadsWriteSchemaTransformConfiguration} configuration. */ - public abstract BigQueryFileLoadsWriteSchemaTransformConfiguration build(); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java deleted file mode 100644 index 3212e2a30348..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import com.google.api.services.bigquery.model.Table; -import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.auto.service.AutoService; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.io.InvalidConfigurationException; -import org.apache.beam.sdk.schemas.transforms.SchemaTransform; -import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; -import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; - -/** - * An implementation of {@link TypedSchemaTransformProvider} for BigQuery write jobs configured - * using {@link BigQueryFileLoadsWriteSchemaTransformConfiguration}. - * - *
<p>
    Internal only: This class is actively being worked on, and it will likely change. We - * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam - * repository. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -@Internal -@AutoService(SchemaTransformProvider.class) -public class BigQueryFileLoadsWriteSchemaTransformProvider - extends TypedSchemaTransformProvider { - - private static final String IDENTIFIER = - "beam:schematransform:org.apache.beam:bigquery_fileloads_write:v1"; - static final String INPUT_TAG = "INPUT"; - - /** Returns the expected class of the configuration. */ - @Override - protected Class configurationClass() { - return BigQueryFileLoadsWriteSchemaTransformConfiguration.class; - } - - /** Returns the expected {@link SchemaTransform} of the configuration. */ - @Override - protected SchemaTransform from(BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - return new BigQueryWriteSchemaTransform(configuration); - } - - /** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */ - @Override - public String identifier() { - return IDENTIFIER; - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since a - * single is expected, this returns a list with a single name. - */ - @Override - public List inputCollectionNames() { - return Collections.singletonList(INPUT_TAG); - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. Since - * no output is expected, this returns an empty list. - */ - @Override - public List outputCollectionNames() { - return Collections.emptyList(); - } - - /** - * A {@link SchemaTransform} that performs {@link BigQueryIO.Write}s based on a {@link - * BigQueryFileLoadsWriteSchemaTransformConfiguration}. - */ - protected static class BigQueryWriteSchemaTransform extends SchemaTransform { - /** An instance of {@link BigQueryServices} used for testing. 
*/ - private BigQueryServices testBigQueryServices = null; - - private final BigQueryFileLoadsWriteSchemaTransformConfiguration configuration; - - BigQueryWriteSchemaTransform(BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - this.configuration = configuration; - } - - @Override - public void validate(PipelineOptions options) { - if (!configuration.getCreateDisposition().equals(CreateDisposition.CREATE_NEVER.name())) { - return; - } - - BigQueryOptions bigQueryOptions = options.as(BigQueryOptions.class); - - BigQueryServices bigQueryServices = new BigQueryServicesImpl(); - if (testBigQueryServices != null) { - bigQueryServices = testBigQueryServices; - } - - DatasetService datasetService = bigQueryServices.getDatasetService(bigQueryOptions); - TableReference tableReference = BigQueryUtils.toTableReference(configuration.getTableSpec()); - - try { - Table table = datasetService.getTable(tableReference); - if (table == null) { - throw new NullPointerException(); - } - - if (table.getSchema() == null) { - throw new InvalidConfigurationException( - String.format("could not fetch schema for table: %s", configuration.getTableSpec())); - } - - } catch (NullPointerException | InterruptedException | IOException ex) { - throw new InvalidConfigurationException( - String.format( - "could not fetch table %s, error: %s", - configuration.getTableSpec(), ex.getMessage())); - } - } - - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - validate(input); - PCollection rowPCollection = input.get(INPUT_TAG); - Schema schema = rowPCollection.getSchema(); - BigQueryIO.Write write = toWrite(schema); - if (testBigQueryServices != null) { - write = write.withTestServices(testBigQueryServices); - } - - PCollection tableRowPCollection = - rowPCollection.apply( - MapElements.into(TypeDescriptor.of(TableRow.class)).via(BigQueryUtils::toTableRow)); - tableRowPCollection.apply(write); - return PCollectionRowTuple.empty(input.getPipeline()); - } - - /** Instantiates a {@link BigQueryIO.Write} from a {@link Schema}. */ - BigQueryIO.Write toWrite(Schema schema) { - TableSchema tableSchema = BigQueryUtils.toTableSchema(schema); - CreateDisposition createDisposition = - CreateDisposition.valueOf(configuration.getCreateDisposition()); - WriteDisposition writeDisposition = - WriteDisposition.valueOf(configuration.getWriteDisposition()); - - return BigQueryIO.writeTableRows() - .to(configuration.getTableSpec()) - .withCreateDisposition(createDisposition) - .withWriteDisposition(writeDisposition) - .withSchema(tableSchema); - } - - /** Setter for testing using {@link BigQueryServices}. */ - @VisibleForTesting - void setTestBigQueryServices(BigQueryServices testBigQueryServices) { - this.testBigQueryServices = testBigQueryServices; - } - - /** Validate a {@link PCollectionRowTuple} input. 
*/ - void validate(PCollectionRowTuple input) { - if (!input.has(INPUT_TAG)) { - throw new IllegalArgumentException( - String.format( - "%s %s is missing expected tag: %s", - getClass().getSimpleName(), input.getClass().getSimpleName(), INPUT_TAG)); - } - - PCollection rowInput = input.get(INPUT_TAG); - Schema sourceSchema = rowInput.getSchema(); - - if (sourceSchema == null) { - throw new IllegalArgumentException( - String.format("%s is null for input of tag: %s", Schema.class, INPUT_TAG)); - } - - if (!configuration.getCreateDisposition().equals(CreateDisposition.CREATE_NEVER.name())) { - return; - } - - BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); - - BigQueryServices bigQueryServices = new BigQueryServicesImpl(); - if (testBigQueryServices != null) { - bigQueryServices = testBigQueryServices; - } - - DatasetService datasetService = bigQueryServices.getDatasetService(bigQueryOptions); - TableReference tableReference = BigQueryUtils.toTableReference(configuration.getTableSpec()); - - try { - Table table = datasetService.getTable(tableReference); - if (table == null) { - throw new NullPointerException(); - } - - TableSchema tableSchema = table.getSchema(); - if (tableSchema == null) { - throw new NullPointerException(); - } - - Schema destinationSchema = BigQueryUtils.fromTableSchema(tableSchema); - if (destinationSchema == null) { - throw new NullPointerException(); - } - - validateMatching(sourceSchema, destinationSchema); - - } catch (NullPointerException | InterruptedException | IOException e) { - throw new InvalidConfigurationException( - String.format( - "could not validate input for create disposition: %s and table: %s, error: %s", - configuration.getCreateDisposition(), - configuration.getTableSpec(), - e.getMessage())); - } - } - - void validateMatching(Schema sourceSchema, Schema destinationSchema) { - if (!sourceSchema.equals(destinationSchema)) { - throw new IllegalArgumentException( - String.format( - "source and destination schema mismatch for table: %s", - configuration.getTableSpec())); - } - } - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java index 9d84abbbbf1a..8c6893ef5798 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java @@ -297,19 +297,15 @@ public void finishBundle(FinishBundleContext c) throws Exception { } for (Map.Entry> entry : writers.entrySet()) { - try { - DestinationT destination = entry.getKey(); - BigQueryRowWriter writer = entry.getValue(); - BigQueryRowWriter.Result result = writer.getResult(); - BoundedWindow window = writerWindows.get(destination); - Preconditions.checkStateNotNull(window); - c.output( - new Result<>(result.resourceId.toString(), result.byteSize, destination), - window.maxTimestamp(), - window); - } catch (Exception e) { - exceptionList.add(e); - } + DestinationT destination = entry.getKey(); + BigQueryRowWriter writer = entry.getValue(); + BigQueryRowWriter.Result result = writer.getResult(); + BoundedWindow window = writerWindows.get(destination); + Preconditions.checkStateNotNull(window); + c.output( + new Result<>(result.resourceId.toString(), result.byteSize, destination), + window.maxTimestamp(), + window); } 
writers.clear(); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java index e374d459af44..288b94ce081b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java @@ -76,10 +76,8 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -110,17 +108,13 @@ static class ResultCoder extends AtomicCoder { static final ResultCoder INSTANCE = new ResultCoder(); @Override - public void encode(Result value, @UnknownKeyFor @NonNull @Initialized OutputStream outStream) - throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull - @Initialized IOException { + public void encode(Result value, OutputStream outStream) throws CoderException, IOException { StringUtf8Coder.of().encode(value.getTableName(), outStream); BooleanCoder.of().encode(value.isFirstPane(), outStream); } @Override - public Result decode(@UnknownKeyFor @NonNull @Initialized InputStream inStream) - throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull - @Initialized IOException { + public Result decode(InputStream inStream) throws CoderException, IOException { return new AutoValue_WriteTables_Result( StringUtf8Coder.of().decode(inStream), BooleanCoder.of().decode(inStream)); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java index 8b8e8179ce7d..15b1b01d7f6c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -26,6 +27,7 @@ import java.util.Collections; import java.util.List; import javax.annotation.Nullable; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; @@ -33,7 +35,9 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import 
org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransformConfiguration; import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; @@ -62,7 +66,7 @@ public class BigQueryDirectReadSchemaTransformProvider extends TypedSchemaTransformProvider { - private static final String OUTPUT_TAG = "OUTPUT_ROWS"; + public static final String OUTPUT_TAG = "output"; @Override protected Class configurationClass() { @@ -76,7 +80,7 @@ protected SchemaTransform from(BigQueryDirectReadSchemaTransformConfiguration co @Override public String identifier() { - return "beam:schematransform:org.apache.beam:bigquery_storage_read:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ); } @Override @@ -139,6 +143,10 @@ public static Builder builder() { @Nullable public abstract List getSelectedFields(); + @SchemaFieldDescription("Use this Cloud KMS key to encrypt your data") + @Nullable + public abstract String getKmsKey(); + @Nullable /** Builder for the {@link BigQueryDirectReadSchemaTransformConfiguration}. */ @AutoValue.Builder @@ -151,6 +159,8 @@ public abstract static class Builder { public abstract Builder setSelectedFields(List selectedFields); + public abstract Builder setKmsKey(String kmsKey); + /** Builds a {@link BigQueryDirectReadSchemaTransformConfiguration} instance. */ public abstract BigQueryDirectReadSchemaTransformConfiguration build(); } @@ -161,7 +171,7 @@ public abstract static class Builder { * BigQueryDirectReadSchemaTransformConfiguration} and instantiated by {@link * BigQueryDirectReadSchemaTransformProvider}. 
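A hedged construction sketch for the direct-read configuration with the new kms_key field added above (the table spec, selected fields, and key name are illustrative, and setTableSpec(...) is assumed from the existing builder, which this hunk does not show); as shown just below in createDirectReadTransform(), the key is only forwarded to TypedRead.withKmsKey(...) when it is non-empty:

    // requires java.util.Arrays plus the provider's configuration class on the classpath
    BigQueryDirectReadSchemaTransformConfiguration readConfig =
        BigQueryDirectReadSchemaTransformConfiguration.builder()
            .setTableSpec("my-project:my_dataset.my_table")              // hypothetical table
            .setSelectedFields(Arrays.asList("name", "count"))           // optional projection
            .setKmsKey("projects/p/locations/l/keyRings/r/cryptoKeys/k") // hypothetical CMEK key name
            .build();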
*/ - protected static class BigQueryDirectReadSchemaTransform extends SchemaTransform { + public static class BigQueryDirectReadSchemaTransform extends SchemaTransform { private BigQueryServices testBigQueryServices = null; private final BigQueryDirectReadSchemaTransformConfiguration configuration; @@ -172,6 +182,20 @@ protected static class BigQueryDirectReadSchemaTransform extends SchemaTransform this.configuration = configuration; } + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(BigQueryDirectReadSchemaTransformConfiguration.class) + .apply(configuration) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } + @VisibleForTesting public void setBigQueryServices(BigQueryServices testBigQueryServices) { this.testBigQueryServices = testBigQueryServices; @@ -211,6 +235,9 @@ BigQueryIO.TypedRead createDirectReadTransform() { } else { read = read.fromQuery(configuration.getQuery()); } + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + read = read.withKmsKey(configuration.getKmsKey()); + } if (this.testBigQueryServices != null) { read = read.withTestServices(testBigQueryServices); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java new file mode 100644 index 000000000000..092cf42a29a4 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import com.google.auto.service.AutoService; +import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; + +/** + * An implementation of {@link TypedSchemaTransformProvider} for BigQuery write jobs configured + * using {@link org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration}. + * + *
<p>
    Internal only: This class is actively being worked on, and it will likely change. We + * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam + * repository. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +@Internal +@AutoService(SchemaTransformProvider.class) +public class BigQueryFileLoadsSchemaTransformProvider + extends TypedSchemaTransformProvider { + + static final String INPUT_TAG = "input"; + + @Override + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { + return new BigQueryFileLoadsSchemaTransform(configuration); + } + + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:bigquery_fileloads:v1"; + } + + @Override + public List inputCollectionNames() { + return Collections.singletonList(INPUT_TAG); + } + + @Override + public List outputCollectionNames() { + return Collections.emptyList(); + } + + public static class BigQueryFileLoadsSchemaTransform extends SchemaTransform { + /** An instance of {@link BigQueryServices} used for testing. */ + private BigQueryServices testBigQueryServices = null; + + private final BigQueryWriteConfiguration configuration; + + BigQueryFileLoadsSchemaTransform(BigQueryWriteConfiguration configuration) { + configuration.validate(); + this.configuration = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection rowPCollection = input.getSinglePCollection(); + BigQueryIO.Write write = toWrite(input.getPipeline().getOptions()); + rowPCollection.apply(write); + + return PCollectionRowTuple.empty(input.getPipeline()); + } + + BigQueryIO.Write toWrite(PipelineOptions options) { + BigQueryIO.Write write = + BigQueryIO.write() + .to(configuration.getTable()) + .withMethod(BigQueryIO.Write.Method.FILE_LOADS) + .withFormatFunction(BigQueryUtils.toTableRow()) + // TODO(https://github.com/apache/beam/issues/33074) BatchLoad's + // createTempFilePrefixView() doesn't pick up the pipeline option + .withCustomGcsTempLocation( + ValueProvider.StaticValueProvider.of(options.getTempLocation())) + .withWriteDisposition(WriteDisposition.WRITE_APPEND) + .useBeamSchema(); + + if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { + CreateDisposition createDisposition = + CreateDisposition.valueOf(configuration.getCreateDisposition().toUpperCase()); + write = write.withCreateDisposition(createDisposition); + } + if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { + WriteDisposition writeDisposition = + WriteDisposition.valueOf(configuration.getWriteDisposition().toUpperCase()); + write = write.withWriteDisposition(writeDisposition); + } + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + write = write.withKmsKey(configuration.getKmsKey()); + } + if (testBigQueryServices != null) { + write = write.withTestServices(testBigQueryServices); + } + + return write; + } + + /** Setter for testing using {@link BigQueryServices}. 
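Because toWrite() above pins the load job's temporary files to options.getTempLocation() (the TODO referencing https://github.com/apache/beam/issues/33074), pipelines applying this file-loads transform are expected to run with a GCS temp location; a minimal sketch, with an illustrative bucket name:

    PipelineOptions options = PipelineOptionsFactory.create();
    options.setTempLocation("gs://my-bucket/tmp"); // hypothetical bucket; read by toWrite() above
    Pipeline pipeline = Pipeline.create(options);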
*/ + @VisibleForTesting + void setTestBigQueryServices(BigQueryServices testBigQueryServices) { + this.testBigQueryServices = testBigQueryServices; + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java new file mode 100644 index 000000000000..555df0d0a2b8 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +public class BigQuerySchemaTransformTranslation { + public static class BigQueryStorageReadSchemaTransformTranslator + extends SchemaTransformTranslation.SchemaTransformPayloadTranslator< + BigQueryDirectReadSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new BigQueryDirectReadSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(BigQueryDirectReadSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + public static class BigQueryWriteSchemaTransformTranslator + extends SchemaTransformTranslation.SchemaTransformPayloadTranslator< + BigQueryWriteSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new BigQueryWriteSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(BigQueryWriteSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class ReadWriteRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class, + ? 
extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + ., PTransformTranslation.TransformPayloadTranslator>builder() + .put( + BigQueryDirectReadSchemaTransform.class, + new BigQueryStorageReadSchemaTransformTranslator()) + .put(BigQueryWriteSchemaTransform.class, new BigQueryWriteSchemaTransformTranslator()) + .build(); + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index c1c06fc592f4..c45433aaf0e7 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -17,20 +17,16 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.google.api.services.bigquery.model.TableConstraints; import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.service.AutoService; -import com.google.auto.value.AutoValue; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; -import javax.annotation.Nullable; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method; @@ -42,15 +38,11 @@ import org.apache.beam.sdk.io.gcp.bigquery.RowMutationInformation; import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; import org.apache.beam.sdk.io.gcp.bigquery.WriteResult; -import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; @@ -65,12 +57,11 @@ import org.apache.beam.sdk.values.ValueInSingleWindow; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; /** * An implementation of {@link TypedSchemaTransformProvider} for BigQuery Storage Write API jobs - * 
configured via {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. + * configured via {@link BigQueryWriteConfiguration}. * *
<p>
    Internal only: This class is actively being worked on, and it will likely change. We * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam @@ -81,7 +72,7 @@ }) @AutoService(SchemaTransformProvider.class) public class BigQueryStorageWriteApiSchemaTransformProvider - extends TypedSchemaTransformProvider { + extends TypedSchemaTransformProvider { private static final Integer DEFAULT_TRIGGER_FREQUENCY_SECS = 5; private static final Duration DEFAULT_TRIGGERING_FREQUENCY = Duration.standardSeconds(DEFAULT_TRIGGER_FREQUENCY_SECS); @@ -89,7 +80,6 @@ public class BigQueryStorageWriteApiSchemaTransformProvider private static final String FAILED_ROWS_TAG = "FailedRows"; private static final String FAILED_ROWS_WITH_ERRORS_TAG = "FailedRowsWithErrors"; // magic string that tells us to write to dynamic destinations - protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; protected static final String ROW_PROPERTY_MUTATION_INFO = "row_mutation_info"; protected static final String ROW_PROPERTY_MUTATION_TYPE = "mutation_type"; protected static final String ROW_PROPERTY_MUTATION_SQN = "change_sequence_number"; @@ -100,14 +90,13 @@ public class BigQueryStorageWriteApiSchemaTransformProvider .build(); @Override - protected SchemaTransform from( - BigQueryStorageWriteApiSchemaTransformConfiguration configuration) { + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { return new BigQueryStorageWriteApiSchemaTransform(configuration); } @Override public String identifier() { - return String.format("beam:schematransform:org.apache.beam:bigquery_storage_write:v2"); + return "beam:schematransform:org.apache.beam:bigquery_storage_write:v2"; } @Override @@ -130,201 +119,17 @@ public List outputCollectionNames() { return Arrays.asList(FAILED_ROWS_TAG, FAILED_ROWS_WITH_ERRORS_TAG, "errors"); } - /** Configuration for writing to BigQuery with Storage Write API. 
*/ - @DefaultSchema(AutoValueSchema.class) - @AutoValue - public abstract static class BigQueryStorageWriteApiSchemaTransformConfiguration { - - static final Map CREATE_DISPOSITIONS = - ImmutableMap.builder() - .put(CreateDisposition.CREATE_IF_NEEDED.name(), CreateDisposition.CREATE_IF_NEEDED) - .put(CreateDisposition.CREATE_NEVER.name(), CreateDisposition.CREATE_NEVER) - .build(); - - static final Map WRITE_DISPOSITIONS = - ImmutableMap.builder() - .put(WriteDisposition.WRITE_TRUNCATE.name(), WriteDisposition.WRITE_TRUNCATE) - .put(WriteDisposition.WRITE_EMPTY.name(), WriteDisposition.WRITE_EMPTY) - .put(WriteDisposition.WRITE_APPEND.name(), WriteDisposition.WRITE_APPEND) - .build(); - - @AutoValue - public abstract static class ErrorHandling { - @SchemaFieldDescription("The name of the output PCollection containing failed writes.") - public abstract String getOutput(); - - public static Builder builder() { - return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration_ErrorHandling - .Builder(); - } - - @AutoValue.Builder - public abstract static class Builder { - public abstract Builder setOutput(String output); - - public abstract ErrorHandling build(); - } - } - - public void validate() { - String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: "; - - // validate output table spec - checkArgument( - !Strings.isNullOrEmpty(this.getTable()), - invalidConfigMessage + "Table spec for a BigQuery Write must be specified."); - - // if we have an input table spec, validate it - if (!this.getTable().equals(DYNAMIC_DESTINATIONS)) { - checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable())); - } - - // validate create and write dispositions - if (!Strings.isNullOrEmpty(this.getCreateDisposition())) { - checkNotNull( - CREATE_DISPOSITIONS.get(this.getCreateDisposition().toUpperCase()), - invalidConfigMessage - + "Invalid create disposition (%s) was specified. Available dispositions are: %s", - this.getCreateDisposition(), - CREATE_DISPOSITIONS.keySet()); - } - if (!Strings.isNullOrEmpty(this.getWriteDisposition())) { - checkNotNull( - WRITE_DISPOSITIONS.get(this.getWriteDisposition().toUpperCase()), - invalidConfigMessage - + "Invalid write disposition (%s) was specified. Available dispositions are: %s", - this.getWriteDisposition(), - WRITE_DISPOSITIONS.keySet()); - } - - if (this.getErrorHandling() != null) { - checkArgument( - !Strings.isNullOrEmpty(this.getErrorHandling().getOutput()), - invalidConfigMessage + "Output must not be empty if error handling specified."); - } - - if (this.getAutoSharding() != null - && this.getAutoSharding() - && this.getNumStreams() != null) { - checkArgument( - this.getNumStreams() == 0, - invalidConfigMessage - + "Cannot set a fixed number of streams when auto-sharding is enabled. Please pick only one of the two options."); - } - } - - /** - * Instantiates a {@link BigQueryStorageWriteApiSchemaTransformConfiguration.Builder} instance. - */ - public static Builder builder() { - return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration - .Builder(); - } - - @SchemaFieldDescription( - "The bigquery table to write to. Format: [${PROJECT}:]${DATASET}.${TABLE}") - public abstract String getTable(); - - @SchemaFieldDescription( - "Optional field that specifies whether the job is allowed to create new tables. 
" - + "The following values are supported: CREATE_IF_NEEDED (the job may create the table), CREATE_NEVER (" - + "the job must fail if the table does not exist already).") - @Nullable - public abstract String getCreateDisposition(); - - @SchemaFieldDescription( - "Specifies the action that occurs if the destination table already exists. " - + "The following values are supported: " - + "WRITE_TRUNCATE (overwrites the table data), " - + "WRITE_APPEND (append the data to the table), " - + "WRITE_EMPTY (job must fail if the table is not empty).") - @Nullable - public abstract String getWriteDisposition(); - - @SchemaFieldDescription( - "Determines how often to 'commit' progress into BigQuery. Default is every 5 seconds.") - @Nullable - public abstract Long getTriggeringFrequencySeconds(); - - @SchemaFieldDescription( - "This option enables lower latency for insertions to BigQuery but may ocassionally " - + "duplicate data elements.") - @Nullable - public abstract Boolean getUseAtLeastOnceSemantics(); - - @SchemaFieldDescription( - "This option enables using a dynamically determined number of Storage Write API streams to write to " - + "BigQuery. Only applicable to unbounded data.") - @Nullable - public abstract Boolean getAutoSharding(); - - @SchemaFieldDescription( - "Specifies the number of write streams that the Storage API sink will use. " - + "This parameter is only applicable when writing unbounded data.") - @Nullable - public abstract Integer getNumStreams(); - - @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") - @Nullable - public abstract ErrorHandling getErrorHandling(); - - @SchemaFieldDescription( - "This option enables the use of BigQuery CDC functionality. The expected PCollection" - + " should contain Beam Rows with a schema wrapping the record to be inserted and" - + " adding the CDC info similar to: {row_mutation_info: {mutation_type:\"...\", " - + "change_sequence_number:\"...\"}, record: {...}}") - @Nullable - public abstract Boolean getUseCdcWrites(); - - @SchemaFieldDescription( - "If CREATE_IF_NEEDED disposition is set, BigQuery table(s) will be created with this" - + " columns as primary key. Required when CDC writes are enabled with CREATE_IF_NEEDED.") - @Nullable - public abstract List getPrimaryKey(); - - /** Builder for {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. */ - @AutoValue.Builder - public abstract static class Builder { - - public abstract Builder setTable(String table); - - public abstract Builder setCreateDisposition(String createDisposition); - - public abstract Builder setWriteDisposition(String writeDisposition); - - public abstract Builder setTriggeringFrequencySeconds(Long seconds); - - public abstract Builder setUseAtLeastOnceSemantics(Boolean use); - - public abstract Builder setAutoSharding(Boolean autoSharding); - - public abstract Builder setNumStreams(Integer numStreams); - - public abstract Builder setErrorHandling(ErrorHandling errorHandling); - - public abstract Builder setUseCdcWrites(Boolean cdcWrites); - - public abstract Builder setPrimaryKey(List pkColumns); - - /** Builds a {@link BigQueryStorageWriteApiSchemaTransformConfiguration} instance. 
*/ - public abstract BigQueryStorageWriteApiSchemaTransformProvider - .BigQueryStorageWriteApiSchemaTransformConfiguration - build(); - } - } - /** * A {@link SchemaTransform} for BigQuery Storage Write API, configured with {@link - * BigQueryStorageWriteApiSchemaTransformConfiguration} and instantiated by {@link + * BigQueryWriteConfiguration} and instantiated by {@link * BigQueryStorageWriteApiSchemaTransformProvider}. */ - protected static class BigQueryStorageWriteApiSchemaTransform extends SchemaTransform { + public static class BigQueryStorageWriteApiSchemaTransform extends SchemaTransform { private BigQueryServices testBigQueryServices = null; - private final BigQueryStorageWriteApiSchemaTransformConfiguration configuration; + private final BigQueryWriteConfiguration configuration; - BigQueryStorageWriteApiSchemaTransform( - BigQueryStorageWriteApiSchemaTransformConfiguration configuration) { + BigQueryStorageWriteApiSchemaTransform(BigQueryWriteConfiguration configuration) { configuration.validate(); this.configuration = configuration; } @@ -420,8 +225,7 @@ public TableConstraints getTableConstraints(String destination) { @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { // Check that the input exists - checkArgument(input.has(INPUT_ROWS_TAG), "Missing expected input tag: %s", INPUT_ROWS_TAG); - PCollection inputRows = input.get(INPUT_ROWS_TAG); + PCollection inputRows = input.getSinglePCollection(); BigQueryIO.Write write = createStorageWriteApiTransform(inputRows.getSchema()); @@ -540,18 +344,18 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = - BigQueryStorageWriteApiSchemaTransformConfiguration.CREATE_DISPOSITIONS.get( - configuration.getCreateDisposition().toUpperCase()); + CreateDisposition.valueOf(configuration.getCreateDisposition().toUpperCase()); write = write.withCreateDisposition(createDisposition); } if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { WriteDisposition writeDisposition = - BigQueryStorageWriteApiSchemaTransformConfiguration.WRITE_DISPOSITIONS.get( - configuration.getWriteDisposition().toUpperCase()); + WriteDisposition.valueOf(configuration.getWriteDisposition().toUpperCase()); write = write.withWriteDisposition(writeDisposition); } - + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + write = write.withKmsKey(configuration.getKmsKey()); + } if (this.testBigQueryServices != null) { write = write.withTestServices(testBigQueryServices); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java new file mode 100644 index 000000000000..4296da7e0cd5 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; + +/** + * Configuration for writing to BigQuery with SchemaTransforms. Used by {@link + * BigQueryStorageWriteApiSchemaTransformProvider} and {@link + * BigQueryFileLoadsSchemaTransformProvider}. + */ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class BigQueryWriteConfiguration { + protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; + + @AutoValue + public abstract static class ErrorHandling { + @SchemaFieldDescription("The name of the output PCollection containing failed writes.") + public abstract String getOutput(); + + public static Builder builder() { + return new AutoValue_BigQueryWriteConfiguration_ErrorHandling.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setOutput(String output); + + public abstract ErrorHandling build(); + } + } + + public void validate() { + String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: "; + + // validate output table spec + checkArgument( + !Strings.isNullOrEmpty(this.getTable()), + invalidConfigMessage + "Table spec for a BigQuery Write must be specified."); + + // if we have an input table spec, validate it + if (!this.getTable().equals(DYNAMIC_DESTINATIONS)) { + checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable())); + } + + // validate create and write dispositions + String createDisposition = getCreateDisposition(); + if (createDisposition != null && !createDisposition.isEmpty()) { + List createDispositions = + Arrays.stream(BigQueryIO.Write.CreateDisposition.values()) + .map(c -> c.name()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + createDispositions.contains(createDisposition.toUpperCase()), + "Invalid create disposition (%s) was specified. 
Available dispositions are: %s", + createDisposition, + createDispositions); + } + String writeDisposition = getWriteDisposition(); + if (writeDisposition != null && !writeDisposition.isEmpty()) { + List writeDispostions = + Arrays.stream(BigQueryIO.Write.WriteDisposition.values()) + .map(w -> w.name()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + writeDispostions.contains(writeDisposition.toUpperCase()), + "Invalid write disposition (%s) was specified. Available dispositions are: %s", + writeDisposition, + writeDispostions); + } + + ErrorHandling errorHandling = getErrorHandling(); + if (errorHandling != null) { + checkArgument( + !Strings.isNullOrEmpty(errorHandling.getOutput()), + invalidConfigMessage + "Output must not be empty if error handling specified."); + } + + Boolean autoSharding = getAutoSharding(); + Integer numStreams = getNumStreams(); + if (autoSharding != null && autoSharding && numStreams != null) { + checkArgument( + numStreams == 0, + invalidConfigMessage + + "Cannot set a fixed number of streams when auto-sharding is enabled. Please pick only one of the two options."); + } + } + + /** Instantiates a {@link BigQueryWriteConfiguration.Builder} instance. */ + public static Builder builder() { + return new AutoValue_BigQueryWriteConfiguration.Builder(); + } + + @SchemaFieldDescription( + "The bigquery table to write to. Format: [${PROJECT}:]${DATASET}.${TABLE}") + public abstract String getTable(); + + @SchemaFieldDescription( + "Optional field that specifies whether the job is allowed to create new tables. " + + "The following values are supported: CREATE_IF_NEEDED (the job may create the table), CREATE_NEVER (" + + "the job must fail if the table does not exist already).") + @Nullable + public abstract String getCreateDisposition(); + + @SchemaFieldDescription( + "Specifies the action that occurs if the destination table already exists. " + + "The following values are supported: " + + "WRITE_TRUNCATE (overwrites the table data), " + + "WRITE_APPEND (append the data to the table), " + + "WRITE_EMPTY (job must fail if the table is not empty).") + @Nullable + public abstract String getWriteDisposition(); + + @SchemaFieldDescription( + "Determines how often to 'commit' progress into BigQuery. Default is every 5 seconds.") + @Nullable + public abstract Long getTriggeringFrequencySeconds(); + + @SchemaFieldDescription( + "This option enables lower latency for insertions to BigQuery but may ocassionally " + + "duplicate data elements.") + @Nullable + public abstract Boolean getUseAtLeastOnceSemantics(); + + @SchemaFieldDescription( + "This option enables using a dynamically determined number of Storage Write API streams to write to " + + "BigQuery. Only applicable to unbounded data.") + @Nullable + public abstract Boolean getAutoSharding(); + + @SchemaFieldDescription( + "Specifies the number of write streams that the Storage API sink will use. " + + "This parameter is only applicable when writing unbounded data.") + @Nullable + public abstract Integer getNumStreams(); + + @SchemaFieldDescription("Use this Cloud KMS key to encrypt your data") + @Nullable + public abstract String getKmsKey(); + + @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") + @Nullable + public abstract ErrorHandling getErrorHandling(); + + @SchemaFieldDescription( + "This option enables the use of BigQuery CDC functionality. 
The expected PCollection" + + " should contain Beam Rows with a schema wrapping the record to be inserted and" + + " adding the CDC info similar to: {row_mutation_info: {mutation_type:\"...\", " + + "change_sequence_number:\"...\"}, record: {...}}") + @Nullable + public abstract Boolean getUseCdcWrites(); + + @SchemaFieldDescription( + "If CREATE_IF_NEEDED disposition is set, BigQuery table(s) will be created with this" + + " columns as primary key. Required when CDC writes are enabled with CREATE_IF_NEEDED.") + @Nullable + public abstract List getPrimaryKey(); + + /** Builder for {@link BigQueryWriteConfiguration}. */ + @AutoValue.Builder + public abstract static class Builder { + + public abstract Builder setTable(String table); + + public abstract Builder setCreateDisposition(String createDisposition); + + public abstract Builder setWriteDisposition(String writeDisposition); + + public abstract Builder setTriggeringFrequencySeconds(Long seconds); + + public abstract Builder setUseAtLeastOnceSemantics(Boolean use); + + public abstract Builder setAutoSharding(Boolean autoSharding); + + public abstract Builder setNumStreams(Integer numStreams); + + public abstract Builder setKmsKey(String kmsKey); + + public abstract Builder setErrorHandling(ErrorHandling errorHandling); + + public abstract Builder setUseCdcWrites(Boolean cdcWrites); + + public abstract Builder setPrimaryKey(List pkColumns); + + /** Builds a {@link BigQueryWriteConfiguration} instance. */ + public abstract BigQueryWriteConfiguration build(); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java new file mode 100644 index 000000000000..abab169d6932 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; + +import com.google.auto.service.AutoService; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; + +/** + * A BigQuery Write SchemaTransformProvider that routes to either {@link + * BigQueryFileLoadsSchemaTransformProvider} or {@link + * BigQueryStorageWriteApiSchemaTransformProvider}. + * + *
<p>
    Internal only. Used by the Managed Transform layer. + */ +@Internal +@AutoService(SchemaTransformProvider.class) +public class BigQueryWriteSchemaTransformProvider + extends TypedSchemaTransformProvider { + @Override + public String identifier() { + return getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE); + } + + @Override + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { + return new BigQueryWriteSchemaTransform(configuration); + } + + public static class BigQueryWriteSchemaTransform extends SchemaTransform { + private final BigQueryWriteConfiguration configuration; + + BigQueryWriteSchemaTransform(BigQueryWriteConfiguration configuration) { + configuration.validate(); + this.configuration = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + if (input.getSinglePCollection().isBounded().equals(PCollection.IsBounded.BOUNDED)) { + return input.apply(new BigQueryFileLoadsSchemaTransformProvider().from(configuration)); + } else { // UNBOUNDED + return input.apply( + new BigQueryStorageWriteApiSchemaTransformProvider().from(configuration)); + } + } + + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(BigQueryWriteConfiguration.class) + .apply(configuration) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java index dca12db0c211..2b187039d6cb 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java @@ -200,11 +200,10 @@ interface RpcWriteAttempt extends RpcAttempt { * provided {@code instant}. * * @param instant The intended start time of the next rpc - * @param The type which will be sent in the request * @param The {@link Element} type which the returned buffer will contain * @return a new {@link FlushBuffer} which queued messages can be staged to before final flush */ - > FlushBuffer newFlushBuffer(Instant instant); + > FlushBuffer newFlushBuffer(Instant instant); /** Record the start time of sending the rpc. 
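A hedged usage sketch of the routing in BigQueryWriteSchemaTransform.expand() above (the schema, values, and the "input" tag are illustrative; from(ConfigT) is protected, so this assumes same-package or test access, or the Row-based from(Row) entry point): a bounded input takes the file-loads branch, while an unbounded input, for example from Pub/Sub, would take the Storage Write API branch.

    BigQueryWriteConfiguration writeConfig =
        BigQueryWriteConfiguration.builder()
            .setTable("my-project:my_dataset.my_table") // hypothetical table spec
            .build();

    Schema schema = Schema.builder().addStringField("name").addInt64Field("count").build();
    Pipeline pipeline = Pipeline.create();
    PCollection<Row> boundedRows =
        pipeline.apply(
            Create.of(Row.withSchema(schema).addValues("a", 1L).build()).withRowSchema(schema));

    // Bounded -> BigQueryFileLoadsSchemaTransformProvider; an unbounded PCollection<Row> here
    // would be routed to BigQueryStorageWriteApiSchemaTransformProvider instead.
    PCollectionRowTuple.of("input", boundedRows)
        .apply(new BigQueryWriteSchemaTransformProvider().from(writeConfig));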
*/ void recordRequestStart(Instant start, int numWrites); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java index c600ae4224b4..1c83e45acb95 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java @@ -386,7 +386,7 @@ public boolean awaitSafeToProceed(Instant instant) throws InterruptedException { } @Override - public > FlushBufferImpl newFlushBuffer( + public > FlushBufferImpl newFlushBuffer( Instant instantSinceEpoch) { state.checkActive(); int availableWriteCountBudget = writeRampUp.getAvailableWriteCountBudget(instantSinceEpoch); @@ -935,7 +935,7 @@ private static O11y create( } } - static class FlushBufferImpl> implements FlushBuffer { + static class FlushBufferImpl> implements FlushBuffer { final int nextBatchMaxCount; final long nextBatchMaxBytes; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java index 2964a29dbb6b..bd01604643e1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubClient.java @@ -507,6 +507,9 @@ public abstract void modifyAckDeadline( /** Return a list of topics for {@code project}. */ public abstract List listTopics(ProjectPath project) throws IOException; + /** Return {@literal true} if {@code topic} exists. */ + public abstract boolean isTopicExists(TopicPath topic) throws IOException; + /** Create {@code subscription} to {@code topic}. 
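A hedged sketch of calling the new existence check directly (the project and topic names are illustrative, and PubsubGrpcClient.FACTORY is assumed as the client factory, as in existing Beam code); the PubsubIO validation added further below performs essentially this check:

    PubsubOptions pubsubOptions = PipelineOptionsFactory.create().as(PubsubOptions.class);
    try (PubsubClient client = PubsubGrpcClient.FACTORY.newClient(null, null, pubsubOptions)) {
      boolean exists =
          client.isTopicExists(PubsubClient.topicPathFromName("my-project", "my-topic"));
      if (!exists) {
        throw new IllegalArgumentException("Pubsub topic does not exist.");
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    }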
*/ public abstract void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java index 93fdd5524007..0cfb06688108 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClient.java @@ -54,6 +54,7 @@ import io.grpc.Channel; import io.grpc.ClientInterceptors; import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; import io.grpc.auth.ClientAuthInterceptor; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NegotiationType; @@ -372,6 +373,21 @@ public List listTopics(ProjectPath project) throws IOException { return topics; } + @Override + public boolean isTopicExists(TopicPath topic) throws IOException { + GetTopicRequest request = GetTopicRequest.newBuilder().setTopic(topic.getPath()).build(); + try { + publisherStub().getTopic(request); + return true; + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode() == io.grpc.Status.Code.NOT_FOUND) { + return false; + } + + throw e; + } + } + @Override public void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java index f59a68c40551..c6c8b3e71815 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java @@ -50,6 +50,7 @@ import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.SubscriptionPath; import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.TopicPath; import org.apache.beam.sdk.metrics.Lineage; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -860,6 +861,8 @@ public abstract static class Read extends PTransform> abstract ErrorHandler getBadRecordErrorHandler(); + abstract boolean getValidate(); + abstract Builder toBuilder(); static Builder newBuilder(SerializableFunction parseFn) { @@ -871,6 +874,7 @@ static Builder newBuilder(SerializableFunction parseFn) builder.setNeedsOrderingKey(false); builder.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER); builder.setBadRecordErrorHandler(new DefaultErrorHandler<>()); + builder.setValidate(false); return builder; } @@ -918,6 +922,8 @@ abstract static class Builder { abstract Builder setBadRecordErrorHandler( ErrorHandler badRecordErrorHandler); + abstract Builder setValidate(boolean validation); + abstract Read build(); } @@ -1097,6 +1103,11 @@ public Read withErrorHandler(ErrorHandler badRecordErrorHandler .build(); } + /** Enable validation of the PubSub Read. */ + public Read withValidation() { + return toBuilder().setValidate(true).build(); + } + @VisibleForTesting /** * Set's the internal Clock. 
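A hedged usage sketch of the opt-in validation (the topic name is illustrative): with .withValidation(), the validate(PipelineOptions) override added in the next hunk checks topic existence at pipeline construction and fails fast with IllegalArgumentException if the topic is missing; the same flag is added to Write later in this patch.

    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());
    p.apply(
        "ReadMessages",
        PubsubIO.readStrings()
            .fromTopic("projects/my-project/topics/my-topic") // hypothetical topic
            .withValidation());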
@@ -1262,6 +1273,35 @@ public T apply(PubsubMessage input) { return read.setCoder(getCoder()); } + @Override + public void validate(PipelineOptions options) { + if (!getValidate()) { + return; + } + + PubsubOptions psOptions = options.as(PubsubOptions.class); + + // Validate the existence of the topic. + if (getTopicProvider() != null) { + PubsubTopic topic = getTopicProvider().get(); + boolean topicExists = true; + try (PubsubClient pubsubClient = + getPubsubClientFactory() + .newClient(getTimestampAttribute(), getIdAttribute(), psOptions)) { + topicExists = + pubsubClient.isTopicExists( + PubsubClient.topicPathFromName(topic.project, topic.topic)); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (!topicExists) { + throw new IllegalArgumentException( + String.format("Pubsub topic '%s' does not exist.", topic)); + } + } + } + @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); @@ -1341,6 +1381,8 @@ public abstract static class Write extends PTransform, PDone> abstract ErrorHandler getBadRecordErrorHandler(); + abstract boolean getValidate(); + abstract Builder toBuilder(); static Builder newBuilder( @@ -1350,6 +1392,7 @@ static Builder newBuilder( builder.setFormatFn(formatFn); builder.setBadRecordRouter(BadRecordRouter.THROWING_ROUTER); builder.setBadRecordErrorHandler(new DefaultErrorHandler<>()); + builder.setValidate(false); return builder; } @@ -1386,6 +1429,8 @@ abstract Builder setFormatFn( abstract Builder setBadRecordErrorHandler( ErrorHandler badRecordErrorHandler); + abstract Builder setValidate(boolean validation); + abstract Write build(); } @@ -1396,11 +1441,14 @@ abstract Builder setBadRecordErrorHandler( * {@code topic} string. */ public Write to(String topic) { + ValueProvider topicProvider = StaticValueProvider.of(topic); + validateTopic(topicProvider); return to(StaticValueProvider.of(topic)); } /** Like {@code topic()} but with a {@link ValueProvider}. */ public Write to(ValueProvider topic) { + validateTopic(topic); return toBuilder() .setTopicProvider(NestedValueProvider.of(topic, PubsubTopic::fromPath)) .setTopicFunction(null) @@ -1408,6 +1456,13 @@ public Write to(ValueProvider topic) { .build(); } + /** Handles validation of {@code topic}. */ + private static void validateTopic(ValueProvider topic) { + if (topic.isAccessible()) { + PubsubTopic.fromPath(topic.get()); + } + } + /** * Provides a function to dynamically specify the target topic per message. Not compatible with * any of the other to methods. If {@link #to} is called again specifying a topic, then this @@ -1497,6 +1552,11 @@ public Write withErrorHandler(ErrorHandler badRecordErrorHandle .build(); } + /** Enable validation of the PubSub Write. */ + public Write withValidation() { + return toBuilder().setValidate(true).build(); + } + @Override public PDone expand(PCollection input) { if (getTopicProvider() == null && !getDynamicDestinations()) { @@ -1566,6 +1626,35 @@ public PDone expand(PCollection input) { throw new RuntimeException(); // cases are exhaustive. } + @Override + public void validate(PipelineOptions options) { + if (!getValidate()) { + return; + } + + PubsubOptions psOptions = options.as(PubsubOptions.class); + + // Validate the existence of the topic. 
+ if (getTopicProvider() != null) { + PubsubTopic topic = getTopicProvider().get(); + boolean topicExists = true; + try (PubsubClient pubsubClient = + getPubsubClientFactory() + .newClient(getTimestampAttribute(), getIdAttribute(), psOptions)) { + topicExists = + pubsubClient.isTopicExists( + PubsubClient.topicPathFromName(topic.project, topic.topic)); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (!topicExists) { + throw new IllegalArgumentException( + String.format("Pubsub topic '%s' does not exist.", topic)); + } + } + } + @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java index 386febcf005b..0a838da66f69 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClient.java @@ -19,6 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; import com.google.api.client.http.HttpRequestInitializer; import com.google.api.services.pubsub.Pubsub; import com.google.api.services.pubsub.Pubsub.Projects.Subscriptions; @@ -310,6 +311,19 @@ public List listTopics(ProjectPath project) throws IOException { return topics; } + @Override + public boolean isTopicExists(TopicPath topic) throws IOException { + try { + pubsub.projects().topics().get(topic.getPath()).execute(); + return true; + } catch (GoogleJsonResponseException e) { + if (e.getStatusCode() == 404) { + return false; + } + throw e; + } + } + @Override public void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java index c1f6b2b31754..8a628817fe27 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java @@ -43,10 +43,7 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -313,19 +310,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsub_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - 
inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java index a8109d05ec38..3d5a879fce15 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java @@ -605,6 +605,12 @@ public List listTopics(ProjectPath project) throws IOException { throw new UnsupportedOperationException(); } + @Override + public boolean isTopicExists(TopicPath topic) throws IOException { + // Always return true for testing purposes. + return true; + } + @Override public void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java index 6187f6f79d3e..2abd6f5fa95d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java @@ -44,9 +44,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -248,19 +245,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsub_write:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java index 8afe730f32ce..9e83619f7b8d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java @@ -63,10 +63,7 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,8 +83,7 @@ public class PubsubLiteReadSchemaTransformProvider public static final TupleTag ERROR_TAG = new TupleTag() {}; @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return PubsubLiteReadSchemaTransformConfiguration.class; } @@ -192,8 +188,7 @@ public void finish(FinishBundleContext c) { } @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - PubsubLiteReadSchemaTransformConfiguration configuration) { + public SchemaTransform from(PubsubLiteReadSchemaTransformConfiguration configuration) { if (!VALID_DATA_FORMATS.contains(configuration.getFormat())) { throw new IllegalArgumentException( String.format( @@ -399,19 +394,17 @@ public Uuid apply(SequencedMessage input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsublite_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java index 8ba8176035da..ebca921c57e1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java @@ -60,10 +60,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -81,8 +78,7 @@ public class PubsubLiteWriteSchemaTransformProvider LoggerFactory.getLogger(PubsubLiteWriteSchemaTransformProvider.class); @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return PubsubLiteWriteSchemaTransformConfiguration.class; } @@ -172,8 +168,7 @@ public void finish() { } @Override - 
public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - PubsubLiteWriteSchemaTransformConfiguration configuration) { + public SchemaTransform from(PubsubLiteWriteSchemaTransformConfiguration configuration) { if (!SUPPORTED_FORMATS.contains(configuration.getFormat())) { throw new IllegalArgumentException( @@ -317,19 +312,17 @@ public byte[] apply(Row input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsublite_write:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 435bbba9ae8e..a6cf7ebb12a5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -25,7 +25,6 @@ import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.DEFAULT_RPC_PRIORITY; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.MAX_INCLUSIVE_END_AT; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.THROUGHPUT_WINDOW_SECONDS; -import static org.apache.beam.sdk.io.gcp.spanner.changestreams.NameGenerator.generatePartitionMetadataTableName; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -61,6 +60,7 @@ import java.util.HashMap; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -77,6 +77,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.action.ActionFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.CleanUpReadChangeStreamDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.DetectNewPartitionsDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.InitializeDoFn; @@ -290,8 +291,8 @@ * grouped into batches. The default maximum size of the batch is set to 1MB or 5000 mutated cells, * or 500 rows (whichever is reached first). To override this use {@link * Write#withBatchSizeBytes(long) withBatchSizeBytes()}, {@link Write#withMaxNumMutations(long) - * withMaxNumMutations()} or {@link Write#withMaxNumMutations(long) withMaxNumRows()}. Setting - * either to a small value or zero disables batching. + * withMaxNumMutations()} or {@link Write#withMaxNumRows(long) withMaxNumRows()}. Setting either to + * a small value or zero disables batching. * *
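The Javadoc fix in the hunk above points withMaxNumRows() at the correct method. As a reference for the three batching knobs it mentions, here is a hedged sketch (not part of the diff; instance, database, table and column names are made up):

    // Hedged sketch: tuning SpannerIO write batching. Setting a knob to a small value or zero
    // disables batching, as the corrected Javadoc states.
    import com.google.cloud.spanner.Mutation;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.gcp.spanner.SpannerIO;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Create;

    public class SpannerBatchTuningSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        p.apply(Create.of(Mutation.newInsertBuilder("my-table").set("id").to(1L).build()))
            .apply(
                SpannerIO.write()
                    .withInstanceId("my-instance-id")
                    .withDatabaseId("my-database")
                    .withBatchSizeBytes(512 * 1024) // cap a batch at 512 KB
                    .withMaxNumMutations(2000)      // cap mutated cells per batch
                    .withMaxNumRows(200));          // cap rows per batch

        p.run();
      }
    }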

    Note that the maximum @@ -1772,9 +1773,13 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta + fullPartitionMetadataDatabaseId + " has dialect " + metadataDatabaseDialect); - final String partitionMetadataTableName = - MoreObjects.firstNonNull( - getMetadataTable(), generatePartitionMetadataTableName(partitionMetadataDatabaseId)); + PartitionMetadataTableNames partitionMetadataTableNames = + Optional.ofNullable(getMetadataTable()) + .map( + table -> + PartitionMetadataTableNames.fromExistingTable( + partitionMetadataDatabaseId, table)) + .orElse(PartitionMetadataTableNames.generateRandom(partitionMetadataDatabaseId)); final String changeStreamName = getChangeStreamName(); final Timestamp startTimestamp = getInclusiveStartAt(); // Uses (Timestamp.MAX - 1ns) at max for end timestamp, because we add 1ns to transform the @@ -1791,7 +1796,7 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta changeStreamSpannerConfig, changeStreamName, partitionMetadataSpannerConfig, - partitionMetadataTableName, + partitionMetadataTableNames, rpcPriority, input.getPipeline().getOptions().getJobName(), changeStreamDatabaseDialect, @@ -1807,7 +1812,9 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta final PostProcessingMetricsDoFn postProcessingMetricsDoFn = new PostProcessingMetricsDoFn(metrics); - LOG.info("Partition metadata table that will be used is " + partitionMetadataTableName); + LOG.info( + "Partition metadata table that will be used is " + + partitionMetadataTableNames.getTableName()); final PCollection impulseOut = input.apply(Impulse.create()); final PCollection partitionsOut = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java index 9820bb39d09d..76440b1ebf1a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java @@ -40,10 +40,8 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +/** A provider for reading from Cloud Spanner using a Schema Transform Provider. */ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) @@ -57,43 +55,81 @@ * *

 * The transformation leverages the {@link SpannerIO} to perform the read operation and maps the
 * results to Beam rows, preserving the schema.
- *
- * <p>Example usage in a YAML pipeline using query:
- *
- * <pre>{@code
- * pipeline:
- *   transforms:
- *     - type: ReadFromSpanner
- *       name: ReadShipments
- *       # Columns: shipment_id, customer_id, shipment_date, shipment_cost, customer_name, customer_email
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         query: 'SELECT * FROM shipments'
- * }</pre>
- *
- * <p>Example usage in a YAML pipeline using a table and columns:
- *
- * <pre>{@code
- * pipeline:
- *   transforms:
- *     - type: ReadFromSpanner
- *       name: ReadShipments
- *       # Columns: shipment_id, customer_id, shipment_date, shipment_cost, customer_name, customer_email
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         table: 'shipments'
- *         columns: ['customer_id', 'customer_name']
- * }</pre>
    */ @AutoService(SchemaTransformProvider.class) public class SpannerReadSchemaTransformProvider extends TypedSchemaTransformProvider< SpannerReadSchemaTransformProvider.SpannerReadSchemaTransformConfiguration> { + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:spanner_read:v1"; + } + + @Override + public String description() { + return "Performs a Bulk read from Google Cloud Spanner using a specified SQL query or " + + "by directly accessing a single table and its columns.\n" + + "\n" + + "Both Query and Read APIs are supported. See more information about " + + "
    reading from Cloud Spanner.\n" + + "\n" + + "Example configuration for performing a read using a SQL query: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " query: 'SELECT * FROM table'\n" + + "\n" + + "It is also possible to read a table by specifying a table name and a list of columns. For " + + "example, the following configuration will perform a read on an entire table: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " columns: ['col1', 'col2']\n" + + "\n" + + "Additionally, to read using a " + + "Secondary Index, specify the index name: ::" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " index: 'my-index'\n" + + " columns: ['col1', 'col2']\n" + + "\n" + + "### Advanced Usage\n" + + "\n" + + "Reads by default use the " + + "PartitionQuery API which enforces some limitations on the type of queries that can be used so that " + + "the data can be read in parallel. If the query is not supported by the PartitionQuery API, then you " + + "can specify a non-partitioned read by setting batching to false.\n" + + "\n" + + "For example: ::" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " batching: false\n" + + " ...\n" + + "\n" + + "Note: See " + + "SpannerIO for more advanced information."; + } + static class SpannerSchemaTransformRead extends SchemaTransform implements Serializable { private final SpannerReadSchemaTransformConfiguration configuration; @@ -116,6 +152,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } else { read = read.withTable(configuration.getTableId()).withColumns(configuration.getColumns()); } + if (!Strings.isNullOrEmpty(configuration.getIndex())) { + read = read.withIndex(configuration.getIndex()); + } + if (Boolean.FALSE.equals(configuration.getBatching())) { + read = read.withBatching(false); + } PCollection spannerRows = input.getPipeline().apply(read); Schema schema = spannerRows.getSchema(); PCollection rows = @@ -128,19 +170,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:spanner_read:v1"; - } - - @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("output"); } @@ -162,6 +197,10 @@ public abstract static class Builder { public abstract Builder setColumns(List columns); + public abstract Builder setIndex(String index); + + public abstract Builder setBatching(Boolean batching); + public abstract SpannerReadSchemaTransformConfiguration build(); } @@ -198,16 +237,16 @@ public static Builder builder() { .Builder(); } - @SchemaFieldDescription("Specifies the GCP project ID.") - @Nullable - public abstract String getProjectId(); - 
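The new index and batching options are forwarded to SpannerIO.Read in expand() above (read.withIndex(...) and read.withBatching(false)). A rough Java equivalent of such a configuration is sketched below (not part of the diff; all identifiers are illustrative):

    // Hedged sketch: the SpannerIO.Read calls that the `index` and `batching` options map onto.
    import com.google.cloud.spanner.Struct;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.gcp.spanner.SpannerIO;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.values.PCollection;

    public class SpannerIndexReadSketch {
      public static void main(String[] args) {
        Pipeline pipeline = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        PCollection<Struct> rows =
            pipeline.apply(
                SpannerIO.read()
                    .withInstanceId("my-instance-id")
                    .withDatabaseId("my-database")
                    .withTable("my-table")
                    .withColumns("col1", "col2")
                    .withIndex("my-index")    // read through a secondary index
                    .withBatching(false));    // non-partitioned read, skips the PartitionQuery API

        pipeline.run();
      }
    }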
@SchemaFieldDescription("Specifies the Cloud Spanner instance.") public abstract String getInstanceId(); @SchemaFieldDescription("Specifies the Cloud Spanner database.") public abstract String getDatabaseId(); + @SchemaFieldDescription("Specifies the GCP project ID.") + @Nullable + public abstract String getProjectId(); + @SchemaFieldDescription("Specifies the Cloud Spanner table.") @Nullable public abstract String getTableId(); @@ -216,20 +255,29 @@ public static Builder builder() { @Nullable public abstract String getQuery(); - @SchemaFieldDescription("Specifies the columns to read from the table.") + @SchemaFieldDescription( + "Specifies the columns to read from the table. This parameter is required when table is specified.") @Nullable public abstract List getColumns(); + + @SchemaFieldDescription( + "Specifies the Index to read from. This parameter can only be specified when using table.") + @Nullable + public abstract String getIndex(); + + @SchemaFieldDescription( + "Set to false to disable batching. Useful when using a query that is not compatible with the PartitionQuery API. Defaults to true.") + @Nullable + public abstract Boolean getBatching(); } @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return SpannerReadSchemaTransformConfiguration.class; } @Override - protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - SpannerReadSchemaTransformConfiguration configuration) { + protected SchemaTransform from(SpannerReadSchemaTransformConfiguration configuration) { return new SpannerSchemaTransformRead(configuration); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java index f50755d18155..8601da09ea09 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java @@ -51,9 +51,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -69,43 +67,6 @@ *

 * The transformation uses the {@link SpannerIO} to perform the write operation and provides
 * options to handle failed mutations, either by throwing an error, or passing the failed mutation
 * further in the pipeline for dealing with accordingly.
- *
- * <p>Example usage in a YAML pipeline without error handling:
- *
- * <pre>{@code
- * pipeline:
- *   transforms:
- *     - type: WriteToSpanner
- *       name: WriteShipments
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         table_id: 'shipments'
- *
- * }</pre>
- *
- * <p>Example usage in a YAML pipeline using error handling:
- *
- * <pre>{@code
- * pipeline:
- *   transforms:
- *     - type: WriteToSpanner
- *       name: WriteShipments
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         table_id: 'shipments'
- *         error_handling:
- *           output: 'errors'
- *
- *     - type: WriteToJson
- *       input: WriteSpanner.my_error_output
- *       config:
- *          path: errors.json
- *
- * }</pre>
    */ @AutoService(SchemaTransformProvider.class) public class SpannerWriteSchemaTransformProvider @@ -113,14 +74,37 @@ public class SpannerWriteSchemaTransformProvider SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration> { @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + public String identifier() { + return "beam:schematransform:org.apache.beam:spanner_write:v1"; + } + + @Override + public String description() { + return "Performs a bulk write to a Google Cloud Spanner table.\n" + + "\n" + + "Example configuration for performing a write to a single table: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " project_id: 'my-project-id'\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + "\n" + + "Note: See " + + "SpannerIO for more advanced information."; + } + + @Override + protected Class configurationClass() { return SpannerWriteSchemaTransformConfiguration.class; } @Override - protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - SpannerWriteSchemaTransformConfiguration configuration) { + protected SchemaTransform from(SpannerWriteSchemaTransformConfiguration configuration) { return new SpannerSchemaTransformWrite(configuration); } @@ -230,19 +214,12 @@ public PCollectionRowTuple expand(@NonNull PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:spanner_write:v1"; - } - - @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("post-write", "errors"); } @@ -250,10 +227,6 @@ public PCollectionRowTuple expand(@NonNull PCollectionRowTuple input) { @DefaultSchema(AutoValueSchema.class) public abstract static class SpannerWriteSchemaTransformConfiguration implements Serializable { - @SchemaFieldDescription("Specifies the GCP project.") - @Nullable - public abstract String getProjectId(); - @SchemaFieldDescription("Specifies the Cloud Spanner instance.") public abstract String getInstanceId(); @@ -263,7 +236,11 @@ public abstract static class SpannerWriteSchemaTransformConfiguration implements @SchemaFieldDescription("Specifies the Cloud Spanner table.") public abstract String getTableId(); - @SchemaFieldDescription("Specifies how to handle errors.") + @SchemaFieldDescription("Specifies the GCP project.") + @Nullable + public abstract String getProjectId(); + + @SchemaFieldDescription("Whether and how to handle write errors.") @Nullable public abstract ErrorHandling getErrorHandling(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java deleted file mode 100644 index 322e85cb07a2..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license 
agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; - -import java.util.UUID; - -/** - * This class generates a unique name for the partition metadata table, which is created when the - * Connector is initialized. - */ -public class NameGenerator { - - private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; - private static final int MAX_TABLE_NAME_LENGTH = 63; - - /** - * Generates an unique name for the partition metadata table in the form of {@code - * "Metadata__"}. - * - * @param databaseId The database id where the table will be created - * @return the unique generated name of the partition metadata table - */ - public static String generatePartitionMetadataTableName(String databaseId) { - // There are 11 characters in the name format. - // Maximum Spanner database ID length is 30 characters. - // UUID always generates a String with 36 characters. - // Since the Postgres table name length is 63, we may need to truncate the table name depending - // on the database length. - String fullString = - String.format(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, UUID.randomUUID()) - .replaceAll("-", "_"); - if (fullString.length() < MAX_TABLE_NAME_LENGTH) { - return fullString; - } - return fullString.substring(0, MAX_TABLE_NAME_LENGTH); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java index f3562e4cd917..e7bc064b1f33 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java @@ -66,10 +66,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.gson.Gson; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -80,8 +77,7 @@ public class SpannerChangestreamsReadSchemaTransformProvider extends TypedSchemaTransformProvider< SpannerChangestreamsReadSchemaTransformProvider.SpannerChangestreamsReadConfiguration> { @Override - protected @UnknownKeyFor @NonNull 
@Initialized Class - configurationClass() { + protected Class configurationClass() { return SpannerChangestreamsReadConfiguration.class; } @@ -94,7 +90,7 @@ public class SpannerChangestreamsReadSchemaTransformProvider Schema.builder().addStringField("error").addNullableStringField("row").build(); @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + public SchemaTransform from( SpannerChangestreamsReadSchemaTransformProvider.SpannerChangestreamsReadConfiguration configuration) { return new SchemaTransform() { @@ -142,19 +138,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:spanner_cdc_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java index b9718fdb675e..787abad02e02 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java @@ -44,7 +44,7 @@ public class DaoFactory implements Serializable { private final SpannerConfig metadataSpannerConfig; private final String changeStreamName; - private final String partitionMetadataTableName; + private final PartitionMetadataTableNames partitionMetadataTableNames; private final RpcPriority rpcPriority; private final String jobName; private final Dialect spannerChangeStreamDatabaseDialect; @@ -56,7 +56,7 @@ public class DaoFactory implements Serializable { * @param changeStreamSpannerConfig the configuration for the change streams DAO * @param changeStreamName the name of the change stream for the change streams DAO * @param metadataSpannerConfig the metadata tables configuration - * @param partitionMetadataTableName the name of the created partition metadata table + * @param partitionMetadataTableNames the names of the partition metadata ddl objects * @param rpcPriority the priority of the requests made by the DAO queries * @param jobName the name of the running job */ @@ -64,7 +64,7 @@ public DaoFactory( SpannerConfig changeStreamSpannerConfig, String changeStreamName, SpannerConfig metadataSpannerConfig, - String partitionMetadataTableName, + PartitionMetadataTableNames partitionMetadataTableNames, RpcPriority rpcPriority, String jobName, Dialect spannerChangeStreamDatabaseDialect, @@ -78,7 +78,7 @@ public DaoFactory( this.changeStreamSpannerConfig = changeStreamSpannerConfig; this.changeStreamName = changeStreamName; this.metadataSpannerConfig = metadataSpannerConfig; - this.partitionMetadataTableName = partitionMetadataTableName; + this.partitionMetadataTableNames = partitionMetadataTableNames; this.rpcPriority = rpcPriority; this.jobName = jobName; this.spannerChangeStreamDatabaseDialect = spannerChangeStreamDatabaseDialect; @@ 
-102,7 +102,7 @@ public synchronized PartitionMetadataAdminDao getPartitionMetadataAdminDao() { databaseAdminClient, metadataSpannerConfig.getInstanceId().get(), metadataSpannerConfig.getDatabaseId().get(), - partitionMetadataTableName, + partitionMetadataTableNames, this.metadataDatabaseDialect); } return partitionMetadataAdminDao; @@ -120,7 +120,7 @@ public synchronized PartitionMetadataDao getPartitionMetadataDao() { if (partitionMetadataDaoInstance == null) { partitionMetadataDaoInstance = new PartitionMetadataDao( - this.partitionMetadataTableName, + this.partitionMetadataTableNames.getTableName(), spannerAccessor.getDatabaseClient(), this.metadataDatabaseDialect); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java index 368cab7022b3..3e6045d8858b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java @@ -79,19 +79,13 @@ public class PartitionMetadataAdminDao { */ public static final String COLUMN_FINISHED_AT = "FinishedAt"; - /** Metadata table index for queries over the watermark column. */ - public static final String WATERMARK_INDEX = "WatermarkIndex"; - - /** Metadata table index for queries over the created at / start timestamp columns. */ - public static final String CREATED_AT_START_TIMESTAMP_INDEX = "CreatedAtStartTimestampIndex"; - private static final int TIMEOUT_MINUTES = 10; private static final int TTL_AFTER_PARTITION_FINISHED_DAYS = 1; private final DatabaseAdminClient databaseAdminClient; private final String instanceId; private final String databaseId; - private final String tableName; + private final PartitionMetadataTableNames names; private final Dialect dialect; /** @@ -101,18 +95,18 @@ public class PartitionMetadataAdminDao { * table * @param instanceId the instance where the metadata table will reside * @param databaseId the database where the metadata table will reside - * @param tableName the name of the metadata table + * @param names the names of the metadata table ddl objects */ PartitionMetadataAdminDao( DatabaseAdminClient databaseAdminClient, String instanceId, String databaseId, - String tableName, + PartitionMetadataTableNames names, Dialect dialect) { this.databaseAdminClient = databaseAdminClient; this.instanceId = instanceId; this.databaseId = databaseId; - this.tableName = tableName; + this.names = names; this.dialect = dialect; } @@ -128,8 +122,8 @@ public void createPartitionMetadataTable() { if (this.isPostgres()) { // Literals need be added around literals to preserve casing. 
ddl.add( - "CREATE TABLE \"" - + tableName + "CREATE TABLE IF NOT EXISTS \"" + + names.getTableName() + "\"(\"" + COLUMN_PARTITION_TOKEN + "\" text NOT NULL,\"" @@ -163,20 +157,20 @@ public void createPartitionMetadataTable() { + COLUMN_FINISHED_AT + "\""); ddl.add( - "CREATE INDEX \"" - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getWatermarkIndexName() + "\" on \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_WATERMARK + "\") INCLUDE (\"" + COLUMN_STATE + "\")"); ddl.add( - "CREATE INDEX \"" - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getCreatedAtIndexName() + "\" ON \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_CREATED_AT + "\",\"" @@ -184,8 +178,8 @@ public void createPartitionMetadataTable() { + "\")"); } else { ddl.add( - "CREATE TABLE " - + tableName + "CREATE TABLE IF NOT EXISTS " + + names.getTableName() + " (" + COLUMN_PARTITION_TOKEN + " STRING(MAX) NOT NULL," @@ -218,20 +212,20 @@ public void createPartitionMetadataTable() { + TTL_AFTER_PARTITION_FINISHED_DAYS + " DAY))"); ddl.add( - "CREATE INDEX " - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getWatermarkIndexName() + " on " - + tableName + + names.getTableName() + " (" + COLUMN_WATERMARK + ") STORING (" + COLUMN_STATE + ")"); ddl.add( - "CREATE INDEX " - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getCreatedAtIndexName() + " ON " - + tableName + + names.getTableName() + " (" + COLUMN_CREATED_AT + "," @@ -261,16 +255,14 @@ public void createPartitionMetadataTable() { * Drops the metadata table. This operation should complete in {@link * PartitionMetadataAdminDao#TIMEOUT_MINUTES} minutes. */ - public void deletePartitionMetadataTable() { + public void deletePartitionMetadataTable(List indexes) { List ddl = new ArrayList<>(); if (this.isPostgres()) { - ddl.add("DROP INDEX \"" + CREATED_AT_START_TIMESTAMP_INDEX + "\""); - ddl.add("DROP INDEX \"" + WATERMARK_INDEX + "\""); - ddl.add("DROP TABLE \"" + tableName + "\""); + indexes.forEach(index -> ddl.add("DROP INDEX \"" + index + "\"")); + ddl.add("DROP TABLE \"" + names.getTableName() + "\""); } else { - ddl.add("DROP INDEX " + CREATED_AT_START_TIMESTAMP_INDEX); - ddl.add("DROP INDEX " + WATERMARK_INDEX); - ddl.add("DROP TABLE " + tableName); + indexes.forEach(index -> ddl.add("DROP INDEX " + index)); + ddl.add("DROP TABLE " + names.getTableName()); } OperationFuture op = databaseAdminClient.updateDatabaseDdl(instanceId, databaseId, ddl, null); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java index 7867932cd1ad..654fd946663c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java @@ -96,6 +96,41 @@ public boolean tableExists() { } } + /** + * Finds all indexes for the metadata table. + * + * @return a list of index names for the metadata table. 
+ */ + public List findAllTableIndexes() { + String indexesStmt; + if (this.isPostgres()) { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = 'public'" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } else { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = ''" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } + + List result = new ArrayList<>(); + try (ResultSet queryResultSet = + databaseClient + .singleUseReadOnlyTransaction() + .executeQuery(Statement.of(indexesStmt), Options.tag("query=findAllTableIndexes"))) { + while (queryResultSet.next()) { + result.add(queryResultSet.getString("index_name")); + } + } + return result; + } + /** * Fetches the partition metadata row data for the given partition token. * diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java new file mode 100644 index 000000000000..07d7b80676de --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import java.io.Serializable; +import java.util.Objects; +import java.util.UUID; +import javax.annotation.Nullable; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** + * Configuration for a partition metadata table. It encapsulates the name of the metadata table and + * indexes. + */ +public class PartitionMetadataTableNames implements Serializable { + + private static final long serialVersionUID = 8848098877671834584L; + + /** PostgreSQL max table and index length is 63 bytes. */ + @VisibleForTesting static final int MAX_NAME_LENGTH = 63; + + private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; + private static final String WATERMARK_INDEX_NAME_FORMAT = "WatermarkIdx_%s_%s"; + private static final String CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT = "CreatedAtIdx_%s_%s"; + + /** + * Generates a unique name for the partition metadata table and its indexes. The table name will + * be in the form of {@code "Metadata__"}. The watermark index will be in the + * form of {@code "WatermarkIdx__}. The createdAt / start timestamp index will + * be in the form of {@code "CreatedAtIdx__}. 
+ * + * @param databaseId The database id where the table will be created + * @return the unique generated names of the partition metadata ddl + */ + public static PartitionMetadataTableNames generateRandom(String databaseId) { + UUID uuid = UUID.randomUUID(); + + String table = generateName(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, uuid); + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + /** + * Encapsulates a selected table name. Index names are generated, but will only be used if the + * given table does not exist. The watermark index will be in the form of {@code + * "WatermarkIdx__}. The createdAt / start timestamp index will be in the form + * of {@code "CreatedAtIdx__}. + * + * @param databaseId The database id for the table + * @param table The table name to be used + * @return an instance with the table name and generated index names + */ + public static PartitionMetadataTableNames fromExistingTable(String databaseId, String table) { + UUID uuid = UUID.randomUUID(); + + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + private static String generateName(String template, String databaseId, UUID uuid) { + String name = String.format(template, databaseId, uuid).replaceAll("-", "_"); + if (name.length() > MAX_NAME_LENGTH) { + return name.substring(0, MAX_NAME_LENGTH); + } + return name; + } + + private final String tableName; + private final String watermarkIndexName; + private final String createdAtIndexName; + + public PartitionMetadataTableNames( + String tableName, String watermarkIndexName, String createdAtIndexName) { + this.tableName = tableName; + this.watermarkIndexName = watermarkIndexName; + this.createdAtIndexName = createdAtIndexName; + } + + public String getTableName() { + return tableName; + } + + public String getWatermarkIndexName() { + return watermarkIndexName; + } + + public String getCreatedAtIndexName() { + return createdAtIndexName; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PartitionMetadataTableNames)) { + return false; + } + PartitionMetadataTableNames that = (PartitionMetadataTableNames) o; + return Objects.equals(tableName, that.tableName) + && Objects.equals(watermarkIndexName, that.watermarkIndexName) + && Objects.equals(createdAtIndexName, that.createdAtIndexName); + } + + @Override + public int hashCode() { + return Objects.hash(tableName, watermarkIndexName, createdAtIndexName); + } + + @Override + public String toString() { + return "PartitionMetadataTableNames{" + + "tableName='" + + tableName + + '\'' + + ", watermarkIndexName='" + + watermarkIndexName + + '\'' + + ", createdAtIndexName='" + + createdAtIndexName + + '\'' + + '}'; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java index a048c885a001..f8aa497292bf 100644 --- 
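A quick sketch of how the two factory methods on the PartitionMetadataTableNames class added above behave (not part of the diff; database and table names are made up, the class name of the sketch is mine):

    // Hedged sketch: generating vs. reusing partition metadata table names.
    import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames;

    public class MetadataTableNamesSketch {
      public static void main(String[] args) {
        // Fresh UUID-based names; hyphens become underscores and names are truncated to 63 chars.
        PartitionMetadataTableNames generated =
            PartitionMetadataTableNames.generateRandom("my-database");
        System.out.println(generated.getTableName());          // e.g. Metadata_my_database_<uuid>
        System.out.println(generated.getWatermarkIndexName());  // e.g. WatermarkIdx_my_database_<uuid>
        System.out.println(generated.getCreatedAtIndexName());  // e.g. CreatedAtIdx_my_database_<uuid>

        // Keep a user-supplied table name; index names are still generated and only matter
        // if the table has to be created.
        PartitionMetadataTableNames existing =
            PartitionMetadataTableNames.fromExistingTable("my-database", "my_metadata_table");
        System.out.println(existing.getTableName());            // my_metadata_table
      }
    }

Accordingly, the clean-up DoFn in the next hunk no longer drops hard-coded index names: it first asks the metadata DAO for the indexes that actually exist via findAllTableIndexes() and passes that list to deletePartitionMetadataTable().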
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn; import java.io.Serializable; +import java.util.List; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; import org.apache.beam.sdk.transforms.DoFn; @@ -33,6 +34,7 @@ public CleanUpReadChangeStreamDoFn(DaoFactory daoFactory) { @ProcessElement public void processElement(OutputReceiver receiver) { - daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(); + List indexes = daoFactory.getPartitionMetadataDao().findAllTableIndexes(); + daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(indexes); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java index 387ffd603b14..ca93f34bf1ba 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java @@ -64,6 +64,7 @@ public InitializeDoFn( public void processElement(OutputReceiver receiver) { PartitionMetadataDao partitionMetadataDao = daoFactory.getPartitionMetadataDao(); if (!partitionMetadataDao.tableExists()) { + // Creates partition metadata table and associated indexes daoFactory.getPartitionMetadataAdminDao().createPartitionMetadataTable(); createFakeParentPartition(); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java index 4013f0018553..d8c580a0cd18 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import com.google.protobuf.ByteString; import com.google.protobuf.DescriptorProtos.DescriptorProto; @@ -36,8 +37,11 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -284,12 +288,14 @@ public class BeamRowToStorageApiProtoTest { .addField("nested", FieldType.row(BASE_SCHEMA).withNullable(true)) .addField("nestedArray", FieldType.array(FieldType.row(BASE_SCHEMA))) .addField("nestedIterable", FieldType.iterable(FieldType.row(BASE_SCHEMA))) + .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.row(BASE_SCHEMA))) .build(); private static final Row NESTED_ROW = Row.withSchema(NESTED_SCHEMA) .withFieldValue("nested", BASE_ROW) 
.withFieldValue("nestedArray", ImmutableList.of(BASE_ROW, BASE_ROW)) .withFieldValue("nestedIterable", ImmutableList.of(BASE_ROW, BASE_ROW)) + .withFieldValue("nestedMap", ImmutableMap.of("key1", BASE_ROW, "key2", BASE_ROW)) .build(); @Test @@ -347,12 +353,12 @@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); - assertEquals(3, types.size()); + assertEquals(4, types.size()); Map nestedTypes = descriptor.getNestedTypeList().stream() .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); - assertEquals(3, nestedTypes.size()); + assertEquals(4, nestedTypes.size()); assertEquals(Type.TYPE_MESSAGE, types.get("nested")); assertEquals(Label.LABEL_OPTIONAL, typeLabels.get("nested")); String nestedTypeName1 = typeNames.get("nested"); @@ -379,6 +385,87 @@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); assertEquals(expectedBaseTypes, nestedTypes3); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmap")); + String nestedTypeName4 = typeNames.get("nestedmap"); + // expects 2 fields in the nested map, key and value + assertEquals(2, nestedTypes.get(nestedTypeName4).getFieldList().size()); + Supplier> stream = + () -> nestedTypes.get(nestedTypeName4).getFieldList().stream(); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("key"))); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("value"))); + + Map nestedTypes4 = + nestedTypes.get(nestedTypeName4).getNestedTypeList().stream() + .flatMap(vdesc -> vdesc.getFieldList().stream()) + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes4); + } + + @Test + public void testParticularMapsFromSchemas() { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.array(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + DescriptorProto descriptor = + TableRowToStorageApiProto.descriptorSchemaFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema((nestedMapSchemaVariations)), + true, + false); + + Map types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map typeNames = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName)); + Map typeLabels = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); + + Map nestedTypes = + descriptor.getNestedTypeList().stream() + .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); + assertEquals(2, nestedTypes.size()); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmultimap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmultimap")); + String nestedMultiMapName = typeNames.get("nestedmultimap"); + // expects 2 fields for the nested array of maps, key and value + assertEquals(2, nestedTypes.get(nestedMultiMapName).getFieldList().size()); + Supplier> stream = + () -> nestedTypes.get(nestedMultiMapName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 
1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); + assertTrue( + stream + .get() + .filter(fdp -> fdp.getName().equals("value")) + .filter(fdp -> fdp.getLabel().equals(Label.LABEL_REPEATED)) + .count() + == 1); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmapnullable")); + // even though the field is marked as optional in the row, we should still see repeated in the proto + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmapnullable")); + String nestedMapNullableName = typeNames.get("nestedmapnullable"); + // expects 2 fields in the nullable map, key and value + assertEquals(2, nestedTypes.get(nestedMapNullableName).getFieldList().size()); + stream = () -> nestedTypes.get(nestedMapNullableName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); } private void assertBaseRecord(DynamicMessage msg) { @@ -395,7 +482,7 @@ public void testMessageFromTableRow() throws Exception { BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(NESTED_SCHEMA), true, false); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, null, -1); - assertEquals(3, msg.getAllFields().size()); + assertEquals(4, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() @@ -404,6 +491,63 @@ public void testMessageFromTableRow() throws Exception { assertBaseRecord(nestedMsg); } + @Test + public void testMessageFromTableRowForArraysAndMaps() throws Exception { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField("nestedArrayNullable", FieldType.array(FieldType.STRING).withNullable(true)) + .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.STRING)) + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.iterable(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + Row nestedRow = + Row.withSchema(nestedMapSchemaVariations) + .withFieldValue("nestedArrayNullable", null) + .withFieldValue("nestedMap", ImmutableMap.of("key1", "value1")) + .withFieldValue( + "nestedMultiMap", + ImmutableMap.of("multikey1", ImmutableList.of("multivalue1", "multivalue2"))) + .withFieldValue("nestedMapNullable", null) + .build(); + + Descriptor descriptor = + TableRowToStorageApiProto.getDescriptorFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(nestedMapSchemaVariations), + true, + false); + DynamicMessage msg = + BeamRowToStorageApiProto.messageFromBeamRow(descriptor, nestedRow, null, -1); + + Map fieldDescriptors = + descriptor.getFields().stream() + .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity())); + + DynamicMessage nestedMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmap"), 0); + String value = + (String) + nestedMapEntryMsg.getField( + fieldDescriptors.get("nestedmap").getMessageType().findFieldByName("value")); + assertEquals("value1", value); + + DynamicMessage nestedMultiMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmultimap"), 0); + List values = + (List) + nestedMultiMapEntryMsg.getField( + fieldDescriptors.get("nestedmultimap").getMessageType().findFieldByName("value")); + assertTrue(values.size() == 2); + assertEquals("multivalue1", values.get(0)); + +
assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedarraynullable")) == 0); + assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedmapnullable")) == 0); + } + @Test public void testCdcFields() throws Exception { Descriptor descriptor = @@ -413,7 +557,7 @@ public void testCdcFields() throws Exception { assertNotNull(descriptor.findFieldByName(StorageApiCDC.CHANGE_SQN_COLUMN)); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, "UPDATE", 42); - assertEquals(5, msg.getAllFields().size()); + assertEquals(6, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java deleted file mode 100644 index dd8bb9fc8664..000000000000 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryFileLoadsWriteSchemaTransformProvider.INPUT_TAG; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; - -import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import java.io.IOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; -import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; -import org.apache.beam.sdk.io.gcp.testing.FakeJobService; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.Field; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.io.InvalidConfigurationException; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.display.DisplayData.Identifier; -import org.apache.beam.sdk.transforms.display.DisplayData.Item; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Test for {@link BigQueryFileLoadsWriteSchemaTransformProvider}. 
*/ -@RunWith(JUnit4.class) -public class BigQueryFileLoadsWriteSchemaTransformProviderTest { - - private static final String PROJECT = "fakeproject"; - private static final String DATASET = "fakedataset"; - private static final String TABLE_ID = "faketable"; - - private static final TableReference TABLE_REFERENCE = - new TableReference().setProjectId(PROJECT).setDatasetId(DATASET).setTableId(TABLE_ID); - - private static final Schema SCHEMA = - Schema.of(Field.of("name", FieldType.STRING), Field.of("number", FieldType.INT64)); - - private static final TableSchema TABLE_SCHEMA = BigQueryUtils.toTableSchema(SCHEMA); - - private static final List ROWS = - Arrays.asList( - Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 1L).build(), - Row.withSchema(SCHEMA).withFieldValue("name", "b").withFieldValue("number", 2L).build(), - Row.withSchema(SCHEMA).withFieldValue("name", "c").withFieldValue("number", 3L).build()); - - private static final BigQueryOptions OPTIONS = - TestPipeline.testingPipelineOptions().as(BigQueryOptions.class); - private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); - private final FakeJobService fakeJobService = new FakeJobService(); - private final TemporaryFolder temporaryFolder = new TemporaryFolder(); - private final FakeBigQueryServices fakeBigQueryServices = - new FakeBigQueryServices() - .withJobService(fakeJobService) - .withDatasetService(fakeDatasetService); - - @Before - public void setUp() throws IOException, InterruptedException { - FakeDatasetService.setUp(); - fakeDatasetService.createDataset(PROJECT, DATASET, "", "", null); - temporaryFolder.create(); - OPTIONS.setProject(PROJECT); - OPTIONS.setTempLocation(temporaryFolder.getRoot().getAbsolutePath()); - } - - @After - public void tearDown() { - temporaryFolder.delete(); - } - - @Rule public transient TestPipeline p = TestPipeline.fromOptions(OPTIONS); - - @Test - public void testLoad() throws IOException, InterruptedException { - BigQueryFileLoadsWriteSchemaTransformProvider provider = - new BigQueryFileLoadsWriteSchemaTransformProvider(); - BigQueryFileLoadsWriteSchemaTransformConfiguration configuration = - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .build(); - BigQueryWriteSchemaTransform schemaTransform = - (BigQueryWriteSchemaTransform) provider.from(configuration); - schemaTransform.setTestBigQueryServices(fakeBigQueryServices); - String tag = provider.inputCollectionNames().get(0); - PCollectionRowTuple input = - PCollectionRowTuple.of(tag, p.apply(Create.of(ROWS).withRowSchema(SCHEMA))); - input.apply(schemaTransform); - - p.run(); - - assertNotNull(fakeDatasetService.getTable(TABLE_REFERENCE)); - assertEquals(ROWS.size(), fakeDatasetService.getAllRows(PROJECT, DATASET, TABLE_ID).size()); - } - - @Test - public void testValidatePipelineOptions() { - List< - Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, - Class>> - cases = - Arrays.asList( - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec("project.doesnot.exist") - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()), - InvalidConfigurationException.class), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - 
.setTableSpec(String.format("%s.%s.%s", PROJECT, DATASET, "doesnotexist")) - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_EMPTY.name()), - InvalidConfigurationException.class), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec("project.doesnot.exist") - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()), - null)); - for (Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, Class> - caze : cases) { - BigQueryWriteSchemaTransform transform = transformFrom(caze.getLeft().build()); - if (caze.getRight() != null) { - assertThrows(caze.getRight(), () -> transform.validate(p.getOptions())); - } else { - transform.validate(p.getOptions()); - } - } - } - - @Test - public void testToWrite() { - List< - Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, - BigQueryIO.Write>> - cases = - Arrays.asList( - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_EMPTY.name()), - BigQueryIO.writeTableRows() - .to(TABLE_REFERENCE) - .withCreateDisposition(CreateDisposition.CREATE_NEVER) - .withWriteDisposition(WriteDisposition.WRITE_EMPTY) - .withSchema(TABLE_SCHEMA)), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()), - BigQueryIO.writeTableRows() - .to(TABLE_REFERENCE) - .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) - .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE) - .withSchema(TABLE_SCHEMA))); - for (Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, BigQueryIO.Write> - caze : cases) { - BigQueryWriteSchemaTransform transform = transformFrom(caze.getLeft().build()); - Map gotDisplayData = DisplayData.from(transform.toWrite(SCHEMA)).asMap(); - Map wantDisplayData = DisplayData.from(caze.getRight()).asMap(); - Set keys = new HashSet<>(); - keys.addAll(gotDisplayData.keySet()); - keys.addAll(wantDisplayData.keySet()); - for (Identifier key : keys) { - Item got = null; - Item want = null; - if (gotDisplayData.containsKey(key)) { - got = gotDisplayData.get(key); - } - if (wantDisplayData.containsKey(key)) { - want = wantDisplayData.get(key); - } - assertEquals(want, got); - } - } - } - - @Test - public void validatePCollectionRowTupleInput() { - PCollectionRowTuple empty = PCollectionRowTuple.empty(p); - PCollectionRowTuple valid = - PCollectionRowTuple.of( - INPUT_TAG, p.apply("CreateRowsWithValidSchema", Create.of(ROWS)).setRowSchema(SCHEMA)); - - PCollectionRowTuple invalid = - PCollectionRowTuple.of( - INPUT_TAG, - p.apply( - "CreateRowsWithInvalidSchema", - Create.of( - Row.nullRow( - Schema.builder().addNullableField("name", FieldType.STRING).build())))); - - BigQueryWriteSchemaTransform transform = - transformFrom( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()) - .build()); - - assertThrows(IllegalArgumentException.class, () -> transform.validate(empty)); - - 
assertThrows(IllegalStateException.class, () -> transform.validate(invalid)); - - transform.validate(valid); - - p.run(); - } - - private BigQueryWriteSchemaTransform transformFrom( - BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - BigQueryFileLoadsWriteSchemaTransformProvider provider = - new BigQueryFileLoadsWriteSchemaTransformProvider(); - BigQueryWriteSchemaTransform transform = - (BigQueryWriteSchemaTransform) provider.from(configuration); - - transform.setTestBigQueryServices(fakeBigQueryServices); - - return transform; - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java index e26348b7b478..8b65e58a4601 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java @@ -698,6 +698,18 @@ public void testToTableSchema_map() { assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); } + @Test + public void testToTableSchema_map_array() { + TableSchema schema = toTableSchema(MAP_ARRAY_TYPE); + + assertThat(schema.getFields().size(), equalTo(1)); + TableFieldSchema field = schema.getFields().get(0); + assertThat(field.getName(), equalTo("map")); + assertThat(field.getType(), equalTo(StandardSQLTypeName.STRUCT.toString())); + assertThat(field.getMode(), equalTo(Mode.REPEATED.toString())); + assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); + } + @Test public void testToTableRow_flat() { TableRow row = toTableRow().apply(FLAT_ROW); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java new file mode 100644 index 000000000000..897d95da3b13 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import com.google.api.services.bigquery.model.TableReference; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions; +import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider.BigQueryFileLoadsSchemaTransform; +import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; +import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; +import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test for {@link BigQueryFileLoadsSchemaTransformProvider}. 
*/ +@RunWith(JUnit4.class) +public class BigQueryFileLoadsSchemaTransformProviderTest { + + private static final String PROJECT = "fakeproject"; + private static final String DATASET = "fakedataset"; + private static final String TABLE_ID = "faketable"; + + private static final TableReference TABLE_REFERENCE = + new TableReference().setProjectId(PROJECT).setDatasetId(DATASET).setTableId(TABLE_ID); + + private static final Schema SCHEMA = + Schema.of(Field.of("name", FieldType.STRING), Field.of("number", FieldType.INT64)); + + private static final List ROWS = + Arrays.asList( + Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 1L).build(), + Row.withSchema(SCHEMA).withFieldValue("name", "b").withFieldValue("number", 2L).build(), + Row.withSchema(SCHEMA).withFieldValue("name", "c").withFieldValue("number", 3L).build()); + + private static final BigQueryOptions OPTIONS = + TestPipeline.testingPipelineOptions().as(BigQueryOptions.class); + private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); + private final FakeJobService fakeJobService = new FakeJobService(); + private final TemporaryFolder temporaryFolder = new TemporaryFolder(); + private final FakeBigQueryServices fakeBigQueryServices = + new FakeBigQueryServices() + .withJobService(fakeJobService) + .withDatasetService(fakeDatasetService); + + @Before + public void setUp() throws IOException, InterruptedException { + FakeDatasetService.setUp(); + fakeDatasetService.createDataset(PROJECT, DATASET, "", "", null); + temporaryFolder.create(); + OPTIONS.setProject(PROJECT); + OPTIONS.setTempLocation(temporaryFolder.getRoot().getAbsolutePath()); + } + + @After + public void tearDown() { + temporaryFolder.delete(); + } + + @Rule public transient TestPipeline p = TestPipeline.fromOptions(OPTIONS); + + @Test + public void testLoad() throws IOException, InterruptedException { + BigQueryFileLoadsSchemaTransformProvider provider = + new BigQueryFileLoadsSchemaTransformProvider(); + BigQueryWriteConfiguration configuration = + BigQueryWriteConfiguration.builder() + .setTable(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) + .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()) + .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) + .build(); + BigQueryFileLoadsSchemaTransform schemaTransform = + (BigQueryFileLoadsSchemaTransform) provider.from(configuration); + schemaTransform.setTestBigQueryServices(fakeBigQueryServices); + String tag = provider.inputCollectionNames().get(0); + PCollectionRowTuple input = + PCollectionRowTuple.of(tag, p.apply(Create.of(ROWS).withRowSchema(SCHEMA))); + input.apply(schemaTransform); + + p.run(); + + assertNotNull(fakeDatasetService.getTable(TABLE_REFERENCE)); + assertEquals(ROWS.size(), fakeDatasetService.getAllRows(PROJECT, DATASET, TABLE_ID).size()); + } + + @Test + public void testManagedChoosesFileLoadsForBoundedWrites() { + PCollection batchInput = p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + batchInput.apply( + Managed.write(Managed.BIGQUERY) + .withConfig(ImmutableMap.of("table", "project.dataset.table"))); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> + tr.getUniqueName() + .contains(BigQueryFileLoadsSchemaTransform.class.getSimpleName())) + .collect(Collectors.toList()); + assertThat(writeTransformProto.size(), greaterThan(0)); + p.enableAbandonedNodeEnforcement(false); + } +} diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java new file mode 100644 index 000000000000..63727107a651 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PeriodicImpulse; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** This class tests the execution of {@link Managed} BigQueryIO. 
*/ +@RunWith(JUnit4.class) +public class BigQueryManagedIT { + @Rule public TestName testName = new TestName(); + @Rule public transient TestPipeline writePipeline = TestPipeline.create(); + @Rule public transient TestPipeline readPipeline = TestPipeline.create(); + + private static final Schema SCHEMA = + Schema.of( + Schema.Field.of("str", Schema.FieldType.STRING), + Schema.Field.of("number", Schema.FieldType.INT64)); + + private static final List ROWS = + LongStream.range(0, 20) + .mapToObj( + i -> + Row.withSchema(SCHEMA) + .withFieldValue("str", Long.toString(i)) + .withFieldValue("number", i) + .build()) + .collect(Collectors.toList()); + + private static final BigqueryClient BQ_CLIENT = new BigqueryClient("BigQueryManagedIT"); + + private static final String PROJECT = + TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + private static final String BIG_QUERY_DATASET_ID = "bigquery_managed_" + System.nanoTime(); + + @BeforeClass + public static void setUpTestEnvironment() throws IOException, InterruptedException { + // Create one BQ dataset for all test cases. + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null); + } + + @AfterClass + public static void cleanup() { + BQ_CLIENT.deleteDataset(PROJECT, BIG_QUERY_DATASET_ID); + } + + @Test + public void testBatchFileLoadsWriteRead() { + String table = + String.format("%s:%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); + Map config = ImmutableMap.of("table", table); + + // file loads requires a GCS temp location + String tempLocation = writePipeline.getOptions().as(TestPipelineOptions.class).getTempRoot(); + writePipeline.getOptions().setTempLocation(tempLocation); + + // batch write + PCollectionRowTuple.of("input", getInput(writePipeline, false)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); + writePipeline.run().waitUntilFinish(); + + // read and validate + PCollection outputRows = + readPipeline + .apply(Managed.read(Managed.BIGQUERY).withConfig(config)) + .getSinglePCollection(); + PAssert.that(outputRows).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } + + @Test + public void testStreamingStorageWriteRead() { + String table = + String.format("%s:%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); + Map config = ImmutableMap.of("table", table); + + // streaming write + PCollectionRowTuple.of("input", getInput(writePipeline, true)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); + writePipeline.run().waitUntilFinish(); + + // read and validate + PCollection outputRows = + readPipeline + .apply(Managed.read(Managed.BIGQUERY).withConfig(config)) + .getSinglePCollection(); + PAssert.that(outputRows).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } + + public PCollection getInput(Pipeline p, boolean isStreaming) { + if (isStreaming) { + return p.apply( + PeriodicImpulse.create() + .startAt(new Instant(0)) + .stopAt(new Instant(19)) + .withInterval(Duration.millis(1))) + .apply( + MapElements.into(TypeDescriptors.rows()) + .via( + i -> + Row.withSchema(SCHEMA) + .withFieldValue("str", Long.toString(i.getMillis())) + .withFieldValue("number", i.getMillis()) + .build())) + .setRowSchema(SCHEMA); + } + return p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java new file mode 100644 index 000000000000..822c607aa3c9 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryStorageReadSchemaTransformTranslator; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryWriteSchemaTransformTranslator; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.BeamUrns; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BigQuerySchemaTransformTranslationTest { + static final BigQueryWriteSchemaTransformProvider WRITE_PROVIDER = + new BigQueryWriteSchemaTransformProvider(); + static final BigQueryDirectReadSchemaTransformProvider READ_PROVIDER = + new BigQueryDirectReadSchemaTransformProvider(); + static final Row WRITE_CONFIG_ROW = + Row.withSchema(WRITE_PROVIDER.configurationSchema()) + .withFieldValue("table", "project:dataset.table") + .withFieldValue("create_disposition", "create_never") + .withFieldValue("write_disposition", "write_append") + 
.withFieldValue("triggering_frequency_seconds", 5L) + .withFieldValue("use_at_least_once_semantics", false) + .withFieldValue("auto_sharding", false) + .withFieldValue("num_streams", 5) + .withFieldValue("error_handling", null) + .build(); + static final Row READ_CONFIG_ROW = + Row.withSchema(READ_PROVIDER.configurationSchema()) + .withFieldValue("query", null) + .withFieldValue("table_spec", "apache-beam-testing.samples.weather_stations") + .withFieldValue("row_restriction", "col < 5") + .withFieldValue("selected_fields", Arrays.asList("col1", "col2", "col3")) + .build(); + + @Test + public void testRecreateWriteTransformFromRow() { + BigQueryWriteSchemaTransform writeTransform = + (BigQueryWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW); + + BigQueryWriteSchemaTransformTranslator translator = + new BigQueryWriteSchemaTransformTranslator(); + Row translatedRow = translator.toConfigRow(writeTransform); + + BigQueryWriteSchemaTransform writeTransformFromRow = + translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG_ROW, writeTransformFromRow.getConfigurationRow()); + } + + @Test + public void testWriteTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + Schema inputSchema = Schema.builder().addByteArrayField("b").build(); + PCollection input = + p.apply( + Create.of( + Collections.singletonList( + Row.withSchema(inputSchema).addValue(new byte[] {1, 2, 3}).build()))) + .setRowSchema(inputSchema); + + BigQueryWriteSchemaTransform writeTransform = + (BigQueryWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW); + PCollectionRowTuple.of("input", input).apply(writeTransform); + + // Then translate the pipeline to a proto and extract KafkaWriteSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(WRITE_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, writeTransformProto.size()); + RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + + assertEquals(WRITE_CONFIG_ROW, rowFromSpec); + + // Use the information in the proto to recreate the KafkaWriteSchemaTransform + BigQueryWriteSchemaTransformTranslator translator = + new BigQueryWriteSchemaTransformTranslator(); + BigQueryWriteSchemaTransform writeTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG_ROW, writeTransformFromSpec.getConfigurationRow()); + } + + @Test + public void testReCreateReadTransformFromRow() { + BigQueryDirectReadSchemaTransform readTransform = + (BigQueryDirectReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG_ROW); + + 
BigQueryStorageReadSchemaTransformTranslator translator = + new BigQueryStorageReadSchemaTransformTranslator(); + Row row = translator.toConfigRow(readTransform); + + BigQueryDirectReadSchemaTransform readTransformFromRow = + translator.fromConfigRow(row, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG_ROW, readTransformFromRow.getConfigurationRow()); + } + + @Test + public void testReadTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + + BigQueryDirectReadSchemaTransform readTransform = + (BigQueryDirectReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG_ROW); + + PCollectionRowTuple.empty(p).apply(readTransform); + + // Then translate the pipeline to a proto and extract the BigQueryDirectReadSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List readTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(READ_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, readTransformProto.size()); + RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + assertEquals(READ_CONFIG_ROW, rowFromSpec); + + // Use the information in the proto to recreate the BigQueryDirectReadSchemaTransform + BigQueryStorageReadSchemaTransformTranslator translator = + new BigQueryStorageReadSchemaTransformTranslator(); + BigQueryDirectReadSchemaTransform readTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG_ROW, readTransformFromSpec.getConfigurationRow()); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java index 87ba2961461a..7b59552bbbe4 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; @@ -32,13 +34,14 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform; -import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.metrics.MetricNameFilter; import org.apache.beam.sdk.metrics.MetricQueryResults; import org.apache.beam.sdk.metrics.MetricResult; @@ -50,13 +53,16 @@ import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -108,15 +114,14 @@ public void setUp() throws Exception { @Test public void testInvalidConfig() { - List invalidConfigs = + List invalidConfigs = Arrays.asList( - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() - .setTable("not_a_valid_table_spec"), - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration.builder().setTable("not_a_valid_table_spec"), + BigQueryWriteConfiguration.builder() .setTable("project:dataset.table") .setCreateDisposition("INVALID_DISPOSITION")); - for (BigQueryStorageWriteApiSchemaTransformConfiguration.Builder config : invalidConfigs) { + for (BigQueryWriteConfiguration.Builder config : invalidConfigs) { assertThrows( Exception.class, () -> { @@ -125,13 +130,11 @@ public void testInvalidConfig() { } } - public PCollectionRowTuple runWithConfig( - BigQueryStorageWriteApiSchemaTransformConfiguration config) { + public PCollectionRowTuple runWithConfig(BigQueryWriteConfiguration config) { return runWithConfig(config, ROWS); } - public PCollectionRowTuple runWithConfig( - BigQueryStorageWriteApiSchemaTransformConfiguration config, List inputRows) { + public PCollectionRowTuple runWithConfig(BigQueryWriteConfiguration config, List inputRows) { BigQueryStorageWriteApiSchemaTransformProvider provider = new BigQueryStorageWriteApiSchemaTransformProvider(); @@ -176,8 +179,8 @@ public boolean rowEquals(Row expectedRow, TableRow actualRow) { @Test public void testSimpleWrite() throws Exception { String tableSpec = "project:dataset.simple_write"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(tableSpec).build(); runWithConfig(config, ROWS); p.run().waitUntilFinish(); @@ -189,9 +192,9 @@ public void testSimpleWrite() throws Exception { @Test public void testWriteToDynamicDestinations() throws 
Exception { - String dynamic = BigQueryStorageWriteApiSchemaTransformProvider.DYNAMIC_DESTINATIONS; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(dynamic).build(); + String dynamic = BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(dynamic).build(); String baseTableSpec = "project:dataset.dynamic_write_"; @@ -273,8 +276,8 @@ public void testCDCWrites() throws Exception { String tableSpec = "project:dataset.cdc_write"; List primaryKeyColumns = ImmutableList.of("name"); - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setUseAtLeastOnceSemantics(true) .setTable(tableSpec) .setUseCdcWrites(true) @@ -304,9 +307,9 @@ public void testCDCWrites() throws Exception { @Test public void testCDCWriteToDynamicDestinations() throws Exception { List primaryKeyColumns = ImmutableList.of("name"); - String dynamic = BigQueryStorageWriteApiSchemaTransformProvider.DYNAMIC_DESTINATIONS; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + String dynamic = BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setUseAtLeastOnceSemantics(true) .setTable(dynamic) .setUseCdcWrites(true) @@ -338,8 +341,8 @@ public void testCDCWriteToDynamicDestinations() throws Exception { @Test public void testInputElementCount() throws Exception { String tableSpec = "project:dataset.input_count"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(tableSpec).build(); runWithConfig(config); PipelineResult result = p.run(); @@ -368,13 +371,11 @@ public void testInputElementCount() throws Exception { @Test public void testFailedRows() throws Exception { String tableSpec = "project:dataset.write_with_fail"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setTable(tableSpec) .setErrorHandling( - BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() - .setOutput("FailedRows") - .build()) + BigQueryWriteConfiguration.ErrorHandling.builder().setOutput("FailedRows").build()) .build(); String failValue = "fail_me"; @@ -420,13 +421,11 @@ public void testFailedRows() throws Exception { @Test public void testErrorCount() throws Exception { String tableSpec = "project:dataset.error_count"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setTable(tableSpec) .setErrorHandling( - BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() - .setOutput("FailedRows") - .build()) + BigQueryWriteConfiguration.ErrorHandling.builder().setOutput("FailedRows").build()) .build(); Function shouldFailRow = @@ -456,4 +455,24 @@ public void testErrorCount() throws Exception { assertEquals(expectedCount, count.getAttempted()); } } + + @Test + public void 
testManagedChoosesStorageApiForUnboundedWrites() { + PCollection batchInput = + p.apply(TestStream.create(SCHEMA).addElements(ROWS.get(0)).advanceWatermarkToInfinity()); + batchInput.apply( + Managed.write(Managed.BIGQUERY) + .withConfig(ImmutableMap.of("table", "project.dataset.table"))); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> + tr.getUniqueName() + .contains(BigQueryStorageWriteApiSchemaTransform.class.getSimpleName())) + .collect(Collectors.toList()); + assertThat(writeTransformProto.size(), greaterThan(0)); + p.enableAbandonedNodeEnforcement(false); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java index d4fcf6153e47..73328afb397b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java @@ -137,7 +137,7 @@ public final void attemptsExhaustedForRetryableError() throws Exception { FlushBuffer> flushBuffer = spy(newFlushBuffer(rpcQosOptions)); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); when(flushBuffer.offer(element1)).thenReturn(true); when(flushBuffer.iterator()).thenReturn(newArrayList(element1).iterator()); when(flushBuffer.getBufferedElementsCount()).thenReturn(1); @@ -224,7 +224,7 @@ public final void endToEnd_success() throws Exception { FlushBuffer> flushBuffer = spy(newFlushBuffer(options)); when(processContext.element()).thenReturn(write); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(BatchWriteRequest.class); when(callable.call(requestCaptor.capture())).thenReturn(response); @@ -267,7 +267,7 @@ public final void endToEnd_exhaustingAttemptsResultsInException() throws Excepti FlushBuffer> flushBuffer = spy(newFlushBuffer(rpcQosOptions)); when(processContext.element()).thenReturn(write); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); when(flushBuffer.isFull()).thenReturn(true); when(flushBuffer.offer(element1)).thenReturn(true); when(flushBuffer.iterator()).thenReturn(newArrayList(element1).iterator()); @@ -324,14 +324,14 @@ public final void endToEnd_awaitSafeToProceed_falseIsTerminalForAttempt() throws when(attempt2.awaitSafeToProceed(any())) .thenReturn(true) .thenThrow(new IllegalStateException("too many attempt2#awaitSafeToProceed")); - when(attempt2.>newFlushBuffer(any())) + when(attempt2.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); // finish bundle attempt RpcQos.RpcWriteAttempt finishBundleAttempt = mock(RpcWriteAttempt.class); when(finishBundleAttempt.awaitSafeToProceed(any())) .thenReturn(true, true) .thenThrow(new IllegalStateException("too many finishBundleAttempt#awaitSafeToProceed")); 
- when(finishBundleAttempt.>newFlushBuffer(any())) + when(finishBundleAttempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); when(rpcQos.newWriteAttempt(any())).thenReturn(attempt, attempt2, finishBundleAttempt); when(callable.call(requestCaptor.capture())).thenReturn(response); @@ -519,20 +519,15 @@ public final void endToEnd_maxBatchSizeRespected() throws Exception { when(attempt.awaitSafeToProceed(any())).thenReturn(true); when(attempt2.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(enqueue0)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue1)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue2)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue3)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue4)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(enqueue0)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue1)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue2)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue3)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue4)).thenReturn(flushBuffer); when(callable.call(expectedGroup1Request)).thenReturn(group1Response); - when(attempt2.>newFlushBuffer(enqueue5)) - .thenReturn(newFlushBuffer(options)); - when(attempt2.>newFlushBuffer(finalFlush)).thenReturn(flushBuffer2); + when(attempt2.>newFlushBuffer(enqueue5)).thenReturn(newFlushBuffer(options)); + when(attempt2.>newFlushBuffer(finalFlush)).thenReturn(flushBuffer2); when(callable.call(expectedGroup2Request)).thenReturn(group2Response); runFunction( @@ -603,7 +598,7 @@ public final void endToEnd_partialSuccessReturnsWritesToQueue() throws Exception when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); when(attempt.isCodeRetryable(Code.INVALID_ARGUMENT)).thenReturn(true); when(attempt.isCodeRetryable(Code.FAILED_PRECONDITION)).thenReturn(true); @@ -673,9 +668,9 @@ public final void writesRemainInQueueWhenFlushIsNotReadyAndThenFlushesInFinishBu .thenThrow(new IllegalStateException("too many attempt calls")); when(attempt.awaitSafeToProceed(any())).thenReturn(true); when(attempt2.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); - when(attempt2.>newFlushBuffer(any())) + when(attempt2.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); FnT fn = getFn(clock, ff, options, CounterFactory.DEFAULT, DistributionFactory.DEFAULT); @@ -723,7 +718,7 @@ public final void queuedWritesMaintainPriorityIfNotFlushed() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); FnT fn = getFn(clock, ff, options, CounterFactory.DEFAULT, DistributionFactory.DEFAULT); @@ -779,7 +774,7 @@ protected final void processElementsAndFinishBundle(FnT fn, int processElementCo } } - protected FlushBufferImpl> newFlushBuffer(RpcQosOptions options) { + protected FlushBufferImpl> newFlushBuffer(RpcQosOptions 
options) { return new FlushBufferImpl<>(options.getBatchMaxCount(), options.getBatchMaxBytes()); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java index d59b9354bd8b..2948be7658a9 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java @@ -177,7 +177,7 @@ public void nonRetryableWriteIsOutput() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenReturn(newFlushBuffer(options)) .thenReturn(newFlushBuffer(options)) .thenThrow(new IllegalStateException("too many attempt#newFlushBuffer calls")); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java index 9acc3707e3ba..70c4ce5046a5 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java @@ -190,7 +190,7 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenReturn(newFlushBuffer(options)) .thenReturn(newFlushBuffer(options)) .thenThrow(new IllegalStateException("too many attempt#newFlushBuffer calls")); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java index bbf3e135e43f..7e24888ace43 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java @@ -236,7 +236,7 @@ private void safeToProceedAndWithBudgetAndWrite( assertTrue( msg(description, t, "awaitSafeToProceed was false, expected true"), attempt.awaitSafeToProceed(t)); - FlushBufferImpl> buffer = attempt.newFlushBuffer(t); + FlushBufferImpl> buffer = attempt.newFlushBuffer(t); assertEquals( msg(description, t, "unexpected batchMaxCount"), expectedBatchMaxCount, diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java index 2f3724d6bae7..9dff65bf2f63 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java @@ -455,7 +455,7 @@ 
public void offerOfElementWhichWouldCrossMaxBytesReturnFalse() { @Test public void flushBuffer_doesNotErrorWhenMaxIsOne() { - FlushBufferImpl> buffer = new FlushBufferImpl<>(1, 1000); + FlushBufferImpl> buffer = new FlushBufferImpl<>(1, 1000); assertTrue(buffer.offer(new FixedSerializationSize<>("a", 1))); assertFalse(buffer.offer(new FixedSerializationSize<>("b", 1))); assertEquals(1, buffer.getBufferedElementsCount()); @@ -463,7 +463,7 @@ public void flushBuffer_doesNotErrorWhenMaxIsOne() { @Test public void flushBuffer_doesNotErrorWhenMaxIsZero() { - FlushBufferImpl> buffer = new FlushBufferImpl<>(0, 1000); + FlushBufferImpl> buffer = new FlushBufferImpl<>(0, 1000); assertFalse(buffer.offer(new FixedSerializationSize<>("a", 1))); assertEquals(0, buffer.getBufferedElementsCount()); assertFalse(buffer.isFull()); @@ -703,7 +703,7 @@ private void doTest_initialBatchSizeRelativeToWorkerCount( .build(); RpcQosImpl qos = new RpcQosImpl(options, random, sleeper, counterFactory, distributionFactory); RpcWriteAttemptImpl attempt = qos.newWriteAttempt(RPC_ATTEMPT_CONTEXT); - FlushBufferImpl> buffer = attempt.newFlushBuffer(Instant.EPOCH); + FlushBufferImpl> buffer = attempt.newFlushBuffer(Instant.EPOCH); assertEquals(expectedBatchMaxCount, buffer.nextBatchMaxCount); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java index 3724e169c612..6c4625f2e077 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubGrpcClientTest.java @@ -40,6 +40,7 @@ import com.google.pubsub.v1.Topic; import io.grpc.ManagedChannel; import io.grpc.Server; +import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; @@ -432,4 +433,43 @@ public void getSchema(GetSchemaRequest request, StreamObserver responseO server.shutdownNow(); } } + + @Test + public void isTopicExists() throws IOException { + initializeClient(null, null); + TopicPath topicDoesNotExist = + PubsubClient.topicPathFromPath("projects/testProject/topics/dontexist"); + TopicPath topicExists = PubsubClient.topicPathFromPath("projects/testProject/topics/exist"); + + PublisherImplBase publisherImplBase = + new PublisherImplBase() { + @Override + public void getTopic(GetTopicRequest request, StreamObserver responseObserver) { + String topicPath = request.getTopic(); + if (topicPath.equals(topicDoesNotExist.getPath())) { + responseObserver.onError( + new StatusRuntimeException(Status.fromCode(Status.Code.NOT_FOUND))); + } + if (topicPath.equals(topicExists.getPath())) { + responseObserver.onNext( + Topic.newBuilder() + .setName(topicPath) + .setSchemaSettings( + SchemaSettings.newBuilder().setSchema(SCHEMA.getPath()).build()) + .build()); + responseObserver.onCompleted(); + } + } + }; + Server server = + InProcessServerBuilder.forName(channelName).addService(publisherImplBase).build().start(); + try { + assertEquals(false, client.isTopicExists(topicDoesNotExist)); + + assertEquals(true, client.isTopicExists(topicExists)); + + } finally { + server.shutdownNow(); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java index d4effbae40a4..bec157ae83cc 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIOTest.java @@ -83,6 +83,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.commons.lang3.RandomStringUtils; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -97,6 +98,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.model.Statement; +import org.mockito.Mockito; /** Tests for PubsubIO Read and Write transforms. */ @RunWith(JUnit4.class) @@ -928,4 +930,172 @@ public void testBigMessageBounded() throws IOException { pipeline.run(); } } + + @Test + public void testReadValidate() throws IOException { + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath existingTopic = PubsubClient.topicPathFromName("test-project", "testTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(existingTopic)).thenReturn(true); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + Read read = + Read.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/testTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + .setPubsubClientFactory(mockFactory) + .setCoder(PubsubMessagePayloadOnlyCoder.of()) + .setValidate(true) + .build(); + + read.validate(options); + } + + @Test + public void testReadValidateTopicIsNotExists() throws Exception { + thrown.expect(IllegalArgumentException.class); + + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + Read read = + Read.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/nonExistingTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + .setPubsubClientFactory(mockFactory) + .setCoder(PubsubMessagePayloadOnlyCoder.of()) + .setValidate(true) + .build(); + + read.validate(options); + } + + @Test + public void testReadWithoutValidation() throws IOException { + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + 
PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + Read read = + PubsubIO.readMessages().fromTopic("projects/test-project/topics/nonExistingTopic"); + + read.validate(options); + } + + @Test + public void testWriteTopicValidationSuccess() throws Exception { + PubsubIO.writeStrings().to("projects/my-project/topics/abc"); + PubsubIO.writeStrings().to("projects/my-project/topics/ABC"); + PubsubIO.writeStrings().to("projects/my-project/topics/AbC-DeF"); + PubsubIO.writeStrings().to("projects/my-project/topics/AbC-1234"); + PubsubIO.writeStrings().to("projects/my-project/topics/AbC-1234-_.~%+-_.~%+-_.~%+-abc"); + PubsubIO.writeStrings() + .to( + new StringBuilder() + .append("projects/my-project/topics/A-really-long-one-") + .append(RandomStringUtils.randomAlphanumeric(100)) + .toString()); + } + + @Test + public void testWriteTopicValidationBadCharacter() throws Exception { + thrown.expect(IllegalArgumentException.class); + PubsubIO.writeStrings().to("projects/my-project/topics/abc-*-abc"); + } + + @Test + public void testWriteValidationTooLong() throws Exception { + thrown.expect(IllegalArgumentException.class); + PubsubIO.writeStrings() + .to( + new StringBuilder() + .append("projects/my-project/topics/A-really-long-one-") + .append(RandomStringUtils.randomAlphanumeric(1000)) + .toString()); + } + + @Test + public void testWriteValidate() throws IOException { + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath existingTopic = PubsubClient.topicPathFromName("test-project", "testTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(existingTopic)).thenReturn(true); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + PubsubIO.Write write = + PubsubIO.Write.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/testTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + .setDynamicDestinations(false) + .setPubsubClientFactory(mockFactory) + .setValidate(true) + .build(); + + write.validate(options); + } + + @Test + public void testWriteValidateTopicIsNotExists() throws Exception { + thrown.expect(IllegalArgumentException.class); + + PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + PubsubIO.Write write = + PubsubIO.Write.newBuilder() + .setTopicProvider( + StaticValueProvider.of( + PubsubIO.PubsubTopic.fromPath("projects/test-project/topics/nonExistingTopic"))) + .setTimestampAttribute("myTimestamp") + .setIdAttribute("myId") + .setDynamicDestinations(false) + .setPubsubClientFactory(mockFactory) + .setValidate(true) + .build(); + + write.validate(options); + } + + @Test + public void testWithoutValidation() throws IOException { + 
PubsubOptions options = TestPipeline.testingPipelineOptions().as(PubsubOptions.class); + TopicPath nonExistingTopic = PubsubClient.topicPathFromName("test-project", "nonExistingTopic"); + PubsubClient mockClient = Mockito.mock(PubsubClient.class); + Mockito.when(mockClient.isTopicExists(nonExistingTopic)).thenReturn(false); + PubsubClient.PubsubClientFactory mockFactory = + Mockito.mock(PubsubClient.PubsubClientFactory.class); + Mockito.when(mockFactory.newClient("myTimestamp", "myId", options)).thenReturn(mockClient); + + PubsubIO.Write write = + PubsubIO.writeMessages().to("projects/test-project/topics/nonExistingTopic"); + + write.validate(options); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java index 634ad42c937a..49681c86257e 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubJsonClientTest.java @@ -23,6 +23,10 @@ import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.when; +import com.google.api.client.googleapis.json.GoogleJsonError; +import com.google.api.client.googleapis.json.GoogleJsonResponseException; +import com.google.api.client.http.HttpHeaders; +import com.google.api.client.http.HttpResponseException; import com.google.api.services.pubsub.Pubsub; import com.google.api.services.pubsub.Pubsub.Projects.Subscriptions; import com.google.api.services.pubsub.Pubsub.Projects.Topics; @@ -425,4 +429,24 @@ public void getProtoSchema() throws IOException { IllegalArgumentException.class, () -> client.getSchema(SCHEMA)); } + + @Test + public void isTopicExists() throws Exception { + TopicPath topicExists = + PubsubClient.topicPathFromPath("projects/testProject/topics/topicExists"); + TopicPath topicDoesNotExist = + PubsubClient.topicPathFromPath("projects/testProject/topics/topicDoesNotExist"); + HttpResponseException.Builder builder = + new HttpResponseException.Builder(404, "topic is not found", new HttpHeaders()); + GoogleJsonError error = new GoogleJsonError(); + when(mockPubsub.projects().topics().get(topicExists.getPath()).execute()) + .thenReturn(new Topic().setName(topicExists.getName())); + when(mockPubsub.projects().topics().get(topicDoesNotExist.getPath()).execute()) + .thenThrow(new GoogleJsonResponseException(builder, error)); + + client = new PubsubJsonClient(null, null, mockPubsub); + + assertEquals(true, client.isTopicExists(topicExists)); + assertEquals(false, client.isTopicExists(topicDoesNotExist)); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java index 3752c2fb3afc..02b9d111583b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java @@ -33,7 +33,9 @@ import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; 
+import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Iterator; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -58,6 +60,8 @@ public class PartitionMetadataAdminDaoTest { private static final String DATABASE_ID = "SPANNER_DATABASE"; private static final String TABLE_NAME = "SPANNER_TABLE"; + private static final String WATERMARK_INDEX_NAME = "WATERMARK_INDEX"; + private static final String CREATED_AT_INDEX_NAME = "CREATED_AT_INDEX"; private static final int TIMEOUT_MINUTES = 10; @@ -68,12 +72,14 @@ public class PartitionMetadataAdminDaoTest { @Before public void setUp() { databaseAdminClient = mock(DatabaseAdminClient.class); + PartitionMetadataTableNames names = + new PartitionMetadataTableNames(TABLE_NAME, WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME); partitionMetadataAdminDao = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, TABLE_NAME, Dialect.GOOGLE_STANDARD_SQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.GOOGLE_STANDARD_SQL); partitionMetadataAdminDaoPostgres = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, TABLE_NAME, Dialect.POSTGRESQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.POSTGRESQL); op = (OperationFuture) mock(OperationFuture.class); statements = ArgumentCaptor.forClass(Iterable.class); when(databaseAdminClient.updateDatabaseDdl( @@ -89,9 +95,9 @@ public void testCreatePartitionMetadataTable() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE")); - assertTrue(it.next().contains("CREATE INDEX")); - assertTrue(it.next().contains("CREATE INDEX")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); } @Test @@ -102,9 +108,9 @@ public void testCreatePartitionMetadataTablePostgres() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); } @Test @@ -133,7 +139,8 @@ public void testCreatePartitionMetadataTableWithInterruptedException() throws Ex @Test public void testDeletePartitionMetadataTable() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -143,10 +150,22 @@ public void testDeletePartitionMetadataTable() throws Exception { assertTrue(it.next().contains("DROP TABLE")); } + @Test + public void 
testDeletePartitionMetadataTableWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDao.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE")); + } + @Test public void testDeletePartitionMetadataTablePostgres() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -156,11 +175,23 @@ public void testDeletePartitionMetadataTablePostgres() throws Exception { assertTrue(it.next().contains("DROP TABLE \"")); } + @Test + public void testDeletePartitionMetadataTablePostgresWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE \"")); + } + @Test public void testDeletePartitionMetadataTableWithTimeoutException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new TimeoutException(TIMED_OUT)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertTrue(e.getMessage().contains(TIMED_OUT)); @@ -171,7 +202,8 @@ public void testDeletePartitionMetadataTableWithTimeoutException() throws Except public void testDeletePartitionMetadataTableWithInterruptedException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new InterruptedException(INTERRUPTED)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertEquals(ErrorCode.CANCELLED, e.getErrorCode()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java new file mode 100644 index 000000000000..2aae5b26a2cb --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import static org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames.MAX_NAME_LENGTH; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class PartitionMetadataTableNamesTest { + @Test + public void testGeneratePartitionMetadataNamesRemovesHyphens() { + String databaseId = "my-database-id-12345"; + + PartitionMetadataTableNames names1 = PartitionMetadataTableNames.generateRandom(databaseId); + assertFalse(names1.getTableName().contains("-")); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = PartitionMetadataTableNames.generateRandom(databaseId); + assertNotEquals(names1.getTableName(), names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } + + @Test + public void testGeneratePartitionMetadataNamesIsShorterThan64Characters() { + PartitionMetadataTableNames names = + PartitionMetadataTableNames.generateRandom( + "my-database-id-larger-than-maximum-length-1234567890-1234567890-1234567890"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + + names = PartitionMetadataTableNames.generateRandom("d"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + } + + @Test + public void testPartitionMetadataNamesFromExistingTable() { + PartitionMetadataTableNames names1 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names1.getTableName()); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } +} diff --git a/sdks/java/io/hadoop-common/build.gradle b/sdks/java/io/hadoop-common/build.gradle index 466aa8fb6730..b0303d29ff98 100644 --- a/sdks/java/io/hadoop-common/build.gradle +++ b/sdks/java/io/hadoop-common/build.gradle @@ -25,10 +25,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop Common" ext.summary = "Library to add shared Hadoop classes 
among Beam IOs." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hadoop-file-system/build.gradle b/sdks/java/io/hadoop-file-system/build.gradle index 3fc872bb5d02..fafa8b5c7e34 100644 --- a/sdks/java/io/hadoop-file-system/build.gradle +++ b/sdks/java/io/hadoop-file-system/build.gradle @@ -26,10 +26,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop File System" ext.summary = "Library to read and write Hadoop/HDFS file formats from Beam." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hadoop-format/build.gradle b/sdks/java/io/hadoop-format/build.gradle index dbb9f8fdd73d..4664005a1fc8 100644 --- a/sdks/java/io/hadoop-format/build.gradle +++ b/sdks/java/io/hadoop-format/build.gradle @@ -30,10 +30,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop Format" ext.summary = "IO to read data from sources and to write data to sinks that implement Hadoop MapReduce Format." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle index c4f1b76ec390..364c10fa738b 100644 --- a/sdks/java/io/hcatalog/build.gradle +++ b/sdks/java/io/hcatalog/build.gradle @@ -30,9 +30,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: HCatalog" ext.summary = "IO to read and write for HCatalog source." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", + "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/iceberg/build.gradle b/sdks/java/io/iceberg/build.gradle index e10c6f38e20f..6754b0aecf50 100644 --- a/sdks/java/io/iceberg/build.gradle +++ b/sdks/java/io/iceberg/build.gradle @@ -29,10 +29,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Iceberg" ext.summary = "Integration with Iceberg data warehouses." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", - "2102": "2.10.2", - "324": "3.2.4", + "2102": "2.10.2", + "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java index defe4f2a603d..d9768114e7c6 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java @@ -17,6 +17,12 @@ */ package org.apache.beam.sdk.io.iceberg; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.metrics.Counter; @@ -29,14 +35,21 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.iceberg.AppendFiles; import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Snapshot; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,9 +58,11 @@ class AppendFilesToTables extends PTransform, PCollection>> { private static final Logger LOG = LoggerFactory.getLogger(AppendFilesToTables.class); private final IcebergCatalogConfig catalogConfig; + private final String manifestFilePrefix; - AppendFilesToTables(IcebergCatalogConfig catalogConfig) { + AppendFilesToTables(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) { this.catalogConfig = catalogConfig; + this.manifestFilePrefix = manifestFilePrefix; } @Override @@ -67,7 +82,7 @@ public String apply(FileWriteResult input) { .apply("Group metadata updates by table", GroupByKey.create()) .apply( "Append metadata updates to tables", - ParDo.of(new AppendFilesToTablesDoFn(catalogConfig))) + ParDo.of(new AppendFilesToTablesDoFn(catalogConfig, manifestFilePrefix))) .setCoder(KvCoder.of(StringUtf8Coder.of(), SnapshotInfo.CODER)); } @@ -75,19 +90,19 @@ private static class AppendFilesToTablesDoFn extends DoFn>, KV> { private final Counter snapshotsCreated = Metrics.counter(AppendFilesToTables.class, "snapshotsCreated"); - private final Counter dataFilesCommitted = - Metrics.counter(AppendFilesToTables.class, "dataFilesCommitted"); private final Distribution committedDataFileByteSize = Metrics.distribution(RecordWriter.class, "committedDataFileByteSize"); private final Distribution committedDataFileRecordCount = Metrics.distribution(RecordWriter.class, "committedDataFileRecordCount"); private final IcebergCatalogConfig catalogConfig; + private final String manifestFilePrefix; private transient @MonotonicNonNull Catalog catalog; - 
private AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig) { + private AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) { this.catalogConfig = catalogConfig; + this.manifestFilePrefix = manifestFilePrefix; } private Catalog getCatalog() { @@ -97,11 +112,22 @@ private Catalog getCatalog() { return catalog; } + private boolean containsMultiplePartitionSpecs(Iterable fileWriteResults) { + int id = fileWriteResults.iterator().next().getSerializableDataFile().getPartitionSpecId(); + for (FileWriteResult result : fileWriteResults) { + if (id != result.getSerializableDataFile().getPartitionSpecId()) { + return true; + } + } + return false; + } + @ProcessElement public void processElement( @Element KV> element, OutputReceiver> out, - BoundedWindow window) { + BoundedWindow window) + throws IOException { String tableStringIdentifier = element.getKey(); Iterable fileWriteResults = element.getValue(); if (!fileWriteResults.iterator().hasNext()) { @@ -109,24 +135,81 @@ public void processElement( } Table table = getCatalog().loadTable(TableIdentifier.parse(element.getKey())); + + // vast majority of the time, we will simply append data files. + // in the rare case we get a batch that contains multiple partition specs, we will group + // data into manifest files and append. + // note: either way, we must use a single commit operation for atomicity. + if (containsMultiplePartitionSpecs(fileWriteResults)) { + appendManifestFiles(table, fileWriteResults); + } else { + appendDataFiles(table, fileWriteResults); + } + + Snapshot snapshot = table.currentSnapshot(); + LOG.info("Created new snapshot for table '{}': {}", tableStringIdentifier, snapshot); + snapshotsCreated.inc(); + out.outputWithTimestamp( + KV.of(element.getKey(), SnapshotInfo.fromSnapshot(snapshot)), window.maxTimestamp()); + } + + // This works only when all files are using the same partition spec. + private void appendDataFiles(Table table, Iterable fileWriteResults) { AppendFiles update = table.newAppend(); - long numFiles = 0; for (FileWriteResult result : fileWriteResults) { - DataFile dataFile = result.getDataFile(table.spec()); + DataFile dataFile = result.getDataFile(table.specs()); update.appendFile(dataFile); committedDataFileByteSize.update(dataFile.fileSizeInBytes()); committedDataFileRecordCount.update(dataFile.recordCount()); - numFiles++; } - // this commit will create a ManifestFile. we don't need to manually create one. update.commit(); - dataFilesCommitted.inc(numFiles); + } - Snapshot snapshot = table.currentSnapshot(); - LOG.info("Created new snapshot for table '{}': {}", tableStringIdentifier, snapshot); - snapshotsCreated.inc(); - out.outputWithTimestamp( - KV.of(element.getKey(), SnapshotInfo.fromSnapshot(snapshot)), window.maxTimestamp()); + // When a user updates their table partition spec during runtime, we can end up with + // a batch of files where some are written with the old spec and some are written with the new + // spec. + // A table commit is limited to a single partition spec. + // To handle this, we create a manifest file for each partition spec, and group data files + // accordingly. + // Afterward, we append all manifests using a single commit operation. 
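A minimal, standalone sketch of the grouping step described in the comment above: data files are bucketed by partition spec id so that each bucket can be written into its own manifest before a single atomic commit. SketchDataFile and the file paths below are hypothetical stand-ins; the actual implementation, which uses Iceberg's ManifestFiles API, follows in the next hunk.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Standalone sketch: group "data files" by partition spec id so each group
    // can go into its own manifest, then append all manifests in one commit.
    public class GroupBySpecIdSketch {

      // Hypothetical stand-in for Iceberg's DataFile.
      static final class SketchDataFile {
        final String path;
        final int specId;

        SketchDataFile(String path, int specId) {
          this.path = path;
          this.specId = specId;
        }
      }

      static Map<Integer, List<SketchDataFile>> groupBySpecId(List<SketchDataFile> files) {
        Map<Integer, List<SketchDataFile>> bySpec = new HashMap<>();
        for (SketchDataFile file : files) {
          bySpec.computeIfAbsent(file.specId, id -> new ArrayList<>()).add(file);
        }
        return bySpec;
      }

      public static void main(String[] args) {
        List<SketchDataFile> batch =
            List.of(
                new SketchDataFile("data/old-spec-0.parquet", 0),
                new SketchDataFile("data/old-spec-1.parquet", 0),
                new SketchDataFile("data/new-spec-0.parquet", 1));

        // Two groups -> two manifests -> one commit appending both manifests.
        groupBySpecId(batch)
            .forEach(
                (specId, files) ->
                    System.out.println("spec " + specId + " -> " + files.size() + " file(s)"));
      }
    }

Appending every manifest through one commit operation is what keeps the table update atomic even when the batch spans more than one partition spec.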
+ private void appendManifestFiles(Table table, Iterable fileWriteResults) + throws IOException { + String uuid = UUID.randomUUID().toString(); + Map specs = table.specs(); + + Map> dataFilesBySpec = new HashMap<>(); + for (FileWriteResult result : fileWriteResults) { + DataFile dataFile = result.getDataFile(specs); + dataFilesBySpec.computeIfAbsent(dataFile.specId(), i -> new ArrayList<>()).add(dataFile); + } + + AppendFiles update = table.newAppend(); + for (Map.Entry> entry : dataFilesBySpec.entrySet()) { + int specId = entry.getKey(); + List files = entry.getValue(); + PartitionSpec spec = Preconditions.checkStateNotNull(specs.get(specId)); + ManifestWriter writer = + createManifestWriter(table.location(), uuid, spec, table.io()); + for (DataFile file : files) { + writer.add(file); + committedDataFileByteSize.update(file.fileSizeInBytes()); + committedDataFileRecordCount.update(file.recordCount()); + } + writer.close(); + update.appendManifest(writer.toManifestFile()); + } + update.commit(); + } + + private ManifestWriter createManifestWriter( + String tableLocation, String uuid, PartitionSpec spec, FileIO io) { + String location = + FileFormat.AVRO.addExtension( + String.format( + "%s/metadata/%s-%s-%s.manifest", + tableLocation, manifestFilePrefix, uuid, spec.specId())); + OutputFile outputFile = io.newOutputFile(location); + return ManifestFiles.write(spec, outputFile); } } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java index c4090d9e7e53..bf00bf8519fc 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.iceberg; import com.google.auto.value.AutoValue; +import java.util.Map; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; @@ -46,9 +47,9 @@ public TableIdentifier getTableIdentifier() { } @SchemaIgnore - public DataFile getDataFile(PartitionSpec spec) { + public DataFile getDataFile(Map specs) { if (cachedDataFile == null) { - cachedDataFile = getSerializableDataFile().createDataFile(spec); + cachedDataFile = getSerializableDataFile().createDataFile(specs); } return cachedDataFile; } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java index 9a3262e19845..7941c13b0dfe 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java @@ -136,4 +136,8 @@ public long bytesWritten() { public DataFile getDataFile() { return icebergDataWriter.toDataFile(); } + + public String path() { + return absoluteFilename; + } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java index 396db7c20f36..255fce9ece4e 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java @@ -25,7 +25,6 @@ import java.util.List; import 
java.util.Map; import java.util.UUID; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.util.Preconditions; @@ -91,6 +90,7 @@ class DestinationState { final Cache writers; private final List dataFiles = Lists.newArrayList(); @VisibleForTesting final Map writerCounts = Maps.newHashMap(); + private final List exceptions = Lists.newArrayList(); DestinationState(IcebergDestination icebergDestination, Table table) { this.icebergDestination = icebergDestination; @@ -113,11 +113,14 @@ class DestinationState { try { recordWriter.close(); } catch (IOException e) { - throw new RuntimeException( - String.format( - "Encountered an error when closing data writer for table '%s', partition %s", - icebergDestination.getTableIdentifier(), pk), - e); + RuntimeException rethrow = + new RuntimeException( + String.format( + "Encountered an error when closing data writer for table '%s', path: %s", + icebergDestination.getTableIdentifier(), recordWriter.path()), + e); + exceptions.add(rethrow); + throw rethrow; } openWriters--; dataFiles.add(SerializableDataFile.from(recordWriter.getDataFile(), pk)); @@ -195,7 +198,9 @@ private RecordWriter createWriter(PartitionKey partitionKey) { private final Map, List> totalSerializableDataFiles = Maps.newHashMap(); - private static final Cache TABLE_CACHE = + + @VisibleForTesting + static final Cache TABLE_CACHE = CacheBuilder.newBuilder().expireAfterAccess(10, TimeUnit.MINUTES).build(); private boolean isClosed = false; @@ -221,22 +226,28 @@ private RecordWriter createWriter(PartitionKey partitionKey) { private Table getOrCreateTable(TableIdentifier identifier, Schema dataSchema) { @Nullable Table table = TABLE_CACHE.getIfPresent(identifier); if (table == null) { - try { - table = catalog.loadTable(identifier); - } catch (NoSuchTableException e) { + synchronized (TABLE_CACHE) { try { - org.apache.iceberg.Schema tableSchema = - IcebergUtils.beamSchemaToIcebergSchema(dataSchema); - // TODO(ahmedabu98): support creating a table with a specified partition spec - table = catalog.createTable(identifier, tableSchema); - LOG.info("Created Iceberg table '{}' with schema: {}", identifier, tableSchema); - } catch (AlreadyExistsException alreadyExistsException) { - // handle race condition where workers are concurrently creating the same table. - // if running into already exists exception, we perform one last load table = catalog.loadTable(identifier); + } catch (NoSuchTableException e) { + try { + org.apache.iceberg.Schema tableSchema = + IcebergUtils.beamSchemaToIcebergSchema(dataSchema); + // TODO(ahmedabu98): support creating a table with a specified partition spec + table = catalog.createTable(identifier, tableSchema); + LOG.info("Created Iceberg table '{}' with schema: {}", identifier, tableSchema); + } catch (AlreadyExistsException alreadyExistsException) { + // handle race condition where workers are concurrently creating the same table. + // if running into already exists exception, we perform one last load + table = catalog.loadTable(identifier); + } } + TABLE_CACHE.put(identifier, table); } - TABLE_CACHE.put(identifier, table); + } else { + // If fetching from cache, refresh the table to avoid working with stale metadata + // (e.g. 
partition spec) + table.refresh(); } return table; } @@ -254,15 +265,7 @@ public boolean write(WindowedValue icebergDestination, Row r icebergDestination, destination -> { TableIdentifier identifier = destination.getValue().getTableIdentifier(); - Table table; - try { - table = - TABLE_CACHE.get( - identifier, () -> getOrCreateTable(identifier, row.getSchema())); - } catch (ExecutionException e) { - throw new RuntimeException( - "Error while fetching or creating table: " + identifier, e); - } + Table table = getOrCreateTable(identifier, row.getSchema()); return new DestinationState(destination.getValue(), table); }); @@ -283,6 +286,17 @@ public void close() throws IOException { // removing writers from the state's cache will trigger the logic to collect each writer's // data file. state.writers.invalidateAll(); + // first check for any exceptions swallowed by the cache + if (!state.exceptions.isEmpty()) { + IllegalStateException exception = + new IllegalStateException( + String.format("Encountered %s failed writer(s).", state.exceptions.size())); + for (Exception e : state.exceptions) { + exception.addSuppressed(e); + } + throw exception; + } + if (state.dataFiles.isEmpty()) { continue; } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java index 699d4fa4dfd0..59b456162008 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.iceberg; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + import com.google.auto.value.AutoValue; import java.nio.ByteBuffer; import java.util.HashMap; @@ -24,7 +26,6 @@ import java.util.Map; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.iceberg.DataFile; import org.apache.iceberg.DataFiles; import org.apache.iceberg.FileFormat; @@ -141,12 +142,14 @@ static SerializableDataFile from(DataFile f, PartitionKey key) { * it from Beam-compatible types. */ @SuppressWarnings("nullness") - DataFile createDataFile(PartitionSpec partitionSpec) { - Preconditions.checkState( - partitionSpec.specId() == getPartitionSpecId(), - "Invalid partition spec id '%s'. This DataFile was originally created with spec id '%s'.", - partitionSpec.specId(), - getPartitionSpecId()); + DataFile createDataFile(Map partitionSpecs) { + PartitionSpec partitionSpec = + checkStateNotNull( + partitionSpecs.get(getPartitionSpecId()), + "This DataFile was originally created with spec id '%s'. 
Could not find " + + "this among table's partition specs: %s.", + getPartitionSpecId(), + partitionSpecs.keySet()); Metrics dataFileMetrics = new Metrics( diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java index a2d0c320f58f..fb3bf43f3515 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java @@ -74,7 +74,7 @@ public IcebergWriteResult expand(PCollection> input) { // Commit files to tables PCollection> snapshots = - writtenFiles.apply(new AppendFilesToTables(catalogConfig)); + writtenFiles.apply(new AppendFilesToTables(catalogConfig, filePrefix)); return new IcebergWriteResult(input.getPipeline(), snapshots); } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java index 7adf6defe520..2bce390e0992 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java @@ -19,24 +19,30 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.either; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.util.List; import java.util.Map; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.commons.lang3.RandomStringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileFormat; import org.apache.iceberg.PartitionKey; import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.hadoop.HadoopCatalog; import org.checkerframework.checker.nullness.qual.Nullable; @@ -44,6 +50,7 @@ import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.rules.TemporaryFolder; import org.junit.rules.TestName; import org.junit.runner.RunWith; @@ -73,6 +80,7 @@ public void setUp() { windowedDestination = getWindowedDestination("table_" + testName.getMethodName(), PARTITION_SPEC); catalog = new HadoopCatalog(new Configuration(), warehouse.location); + RecordWriterManager.TABLE_CACHE.invalidateAll(); } private WindowedValue getWindowedDestination( @@ -269,6 +277,25 @@ public void testRequireClosingBeforeFetchingDataFiles() { assertThrows(IllegalStateException.class, writerManager::getSerializableDataFiles); } + /** DataFile doesn't implement a .equals() method. Check equality manually. 
*/ + private static void checkDataFileEquality(DataFile d1, DataFile d2) { + assertEquals(d1.path(), d2.path()); + assertEquals(d1.format(), d2.format()); + assertEquals(d1.recordCount(), d2.recordCount()); + assertEquals(d1.partition(), d2.partition()); + assertEquals(d1.specId(), d2.specId()); + assertEquals(d1.keyMetadata(), d2.keyMetadata()); + assertEquals(d1.splitOffsets(), d2.splitOffsets()); + assertEquals(d1.columnSizes(), d2.columnSizes()); + assertEquals(d1.valueCounts(), d2.valueCounts()); + assertEquals(d1.nullValueCounts(), d2.nullValueCounts()); + assertEquals(d1.nanValueCounts(), d2.nanValueCounts()); + assertEquals(d1.equalityFieldIds(), d2.equalityFieldIds()); + assertEquals(d1.fileSequenceNumber(), d2.fileSequenceNumber()); + assertEquals(d1.dataSequenceNumber(), d2.dataSequenceNumber()); + assertEquals(d1.pos(), d2.pos()); + } + @Test public void testSerializableDataFileRoundTripEquality() throws IOException { PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); @@ -288,22 +315,161 @@ public void testSerializableDataFileRoundTripEquality() throws IOException { assertEquals(2L, datafile.recordCount()); DataFile roundTripDataFile = - SerializableDataFile.from(datafile, partitionKey).createDataFile(PARTITION_SPEC); - // DataFile doesn't implement a .equals() method. Check equality manually - assertEquals(datafile.path(), roundTripDataFile.path()); - assertEquals(datafile.format(), roundTripDataFile.format()); - assertEquals(datafile.recordCount(), roundTripDataFile.recordCount()); - assertEquals(datafile.partition(), roundTripDataFile.partition()); - assertEquals(datafile.specId(), roundTripDataFile.specId()); - assertEquals(datafile.keyMetadata(), roundTripDataFile.keyMetadata()); - assertEquals(datafile.splitOffsets(), roundTripDataFile.splitOffsets()); - assertEquals(datafile.columnSizes(), roundTripDataFile.columnSizes()); - assertEquals(datafile.valueCounts(), roundTripDataFile.valueCounts()); - assertEquals(datafile.nullValueCounts(), roundTripDataFile.nullValueCounts()); - assertEquals(datafile.nanValueCounts(), roundTripDataFile.nanValueCounts()); - assertEquals(datafile.equalityFieldIds(), roundTripDataFile.equalityFieldIds()); - assertEquals(datafile.fileSequenceNumber(), roundTripDataFile.fileSequenceNumber()); - assertEquals(datafile.dataSequenceNumber(), roundTripDataFile.dataSequenceNumber()); - assertEquals(datafile.pos(), roundTripDataFile.pos()); + SerializableDataFile.from(datafile, partitionKey) + .createDataFile(ImmutableMap.of(PARTITION_SPEC.specId(), PARTITION_SPEC)); + + checkDataFileEquality(datafile, roundTripDataFile); + } + + /** + * Users may update the table's spec while a write pipeline is running. Sometimes, this can happen + * after converting {@link DataFile} to {@link SerializableDataFile}s. When converting back to + * {@link DataFile} to commit in the {@link AppendFilesToTables} step, we need to make sure to use + * the same {@link PartitionSpec} it was originally created with. + * + *
+ * <p>
    This test checks that we're preserving the right {@link PartitionSpec} when such an update + * happens. + */ + @Test + public void testRecreateSerializableDataAfterUpdatingPartitionSpec() throws IOException { + PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); + + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build(); + // same partition for both records (name_trunc=abc, bool=true) + partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + + // write some rows + RecordWriter writer = + new RecordWriter(catalog, windowedDestination.getValue(), "test_file_name", partitionKey); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row2)); + writer.close(); + + // fetch data file and its serializable version + DataFile datafile = writer.getDataFile(); + SerializableDataFile serializableDataFile = SerializableDataFile.from(datafile, partitionKey); + + assertEquals(2L, datafile.recordCount()); + assertEquals(serializableDataFile.getPartitionSpecId(), datafile.specId()); + + // update spec + Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier()); + table.updateSpec().addField("id").removeField("bool").commit(); + + Map updatedSpecs = table.specs(); + DataFile roundTripDataFile = serializableDataFile.createDataFile(updatedSpecs); + + checkDataFileEquality(datafile, roundTripDataFile); + } + + @Test + public void testWriterKeepsUpWithUpdatingPartitionSpec() throws IOException { + Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier()); + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build(); + + // write some rows + RecordWriterManager writer = + new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, Integer.MAX_VALUE); + writer.write(windowedDestination, row); + writer.write(windowedDestination, row2); + writer.close(); + DataFile dataFile = + writer + .getSerializableDataFiles() + .get(windowedDestination) + .get(0) + .createDataFile(table.specs()); + + // check data file path contains the correct partition components + assertEquals(2L, dataFile.recordCount()); + assertEquals(dataFile.specId(), PARTITION_SPEC.specId()); + assertThat(dataFile.path().toString(), containsString("name_trunc=abc")); + assertThat(dataFile.path().toString(), containsString("bool=true")); + + // table is cached + assertEquals(1, RecordWriterManager.TABLE_CACHE.size()); + + // update spec + table.updateSpec().addField("id").removeField("bool").commit(); + + // write a second data file + // should refresh the table and use the new partition spec + RecordWriterManager writer2 = + new RecordWriterManager(catalog, "test_prefix_2", Long.MAX_VALUE, Integer.MAX_VALUE); + writer2.write(windowedDestination, row); + writer2.write(windowedDestination, row2); + writer2.close(); + + List serializableDataFiles = + writer2.getSerializableDataFiles().get(windowedDestination); + assertEquals(2, serializableDataFiles.size()); + for (SerializableDataFile serializableDataFile : serializableDataFiles) { + assertEquals(table.spec().specId(), serializableDataFile.getPartitionSpecId()); + dataFile = serializableDataFile.createDataFile(table.specs()); + assertEquals(1L, dataFile.recordCount()); + assertThat(dataFile.path().toString(), 
containsString("name_trunc=abc")); + assertThat( + dataFile.path().toString(), either(containsString("id=1")).or(containsString("id=2"))); + } + } + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testWriterExceptionGetsCaught() throws IOException { + RecordWriterManager writerManager = new RecordWriterManager(catalog, "test_file_name", 100, 2); + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); + partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + + writerManager.write(windowedDestination, row); + + RecordWriterManager.DestinationState state = + writerManager.destinations.get(windowedDestination); + // replace with a failing record writer + FailingRecordWriter failingWriter = + new FailingRecordWriter( + catalog, windowedDestination.getValue(), "test_failing_writer", partitionKey); + state.writers.put(partitionKey, failingWriter); + writerManager.write(windowedDestination, row); + + // this tests that we indeed enter the catch block + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Encountered 1 failed writer(s)"); + try { + writerManager.close(); + } catch (IllegalStateException e) { + // fetch underlying exceptions and validate + Throwable[] underlyingExceptions = e.getSuppressed(); + assertEquals(1, underlyingExceptions.length); + for (Throwable t : underlyingExceptions) { + assertThat( + t.getMessage(), + containsString("Encountered an error when closing data writer for table")); + assertThat( + t.getMessage(), + containsString(windowedDestination.getValue().getTableIdentifier().toString())); + assertThat(t.getMessage(), containsString(failingWriter.path())); + Throwable realCause = t.getCause(); + assertEquals("I am failing!", realCause.getMessage()); + } + + throw e; + } + } + + static class FailingRecordWriter extends RecordWriter { + FailingRecordWriter( + Catalog catalog, IcebergDestination destination, String filename, PartitionKey partitionKey) + throws IOException { + super(catalog, destination, filename, partitionKey); + } + + @Override + public void close() throws IOException { + throw new IOException("I am failing!"); + } } } diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java index eddcb0de5561..266d04342d1f 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java @@ -215,7 +215,7 @@ public void testPublishingThenReadingAll() throws IOException, JMSException { int unackRecords = countRemain(QUEUE); assertTrue( String.format("Too many unacknowledged messages: %d", unackRecords), - unackRecords < OPTIONS.getNumberOfRecords() * 0.002); + unackRecords < OPTIONS.getNumberOfRecords() * 0.003); // acknowledged records int ackRecords = OPTIONS.getNumberOfRecords() - unackRecords; diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle index ec4654bd88df..c2f056b0b7cb 100644 --- a/sdks/java/io/kafka/build.gradle +++ b/sdks/java/io/kafka/build.gradle @@ -35,9 +35,6 @@ ext { } def kafkaVersions = [ - '01103': "0.11.0.3", - '100': "1.0.0", - '111': "1.1.1", '201': "2.0.1", '211': "2.1.1", '222': "2.2.2", @@ -139,15 +136,13 @@ task kafkaVersionsCompatibilityTest { description = 'Runs KafkaIO with different Kafka client APIs' def testNames = 
createTestList(kafkaVersions, "Test") dependsOn testNames - dependsOn (":sdks:java:io:kafka:kafka-01103:kafkaVersion01103BatchIT") - dependsOn (":sdks:java:io:kafka:kafka-100:kafkaVersion100BatchIT") - dependsOn (":sdks:java:io:kafka:kafka-111:kafkaVersion111BatchIT") dependsOn (":sdks:java:io:kafka:kafka-201:kafkaVersion201BatchIT") dependsOn (":sdks:java:io:kafka:kafka-211:kafkaVersion211BatchIT") dependsOn (":sdks:java:io:kafka:kafka-222:kafkaVersion222BatchIT") dependsOn (":sdks:java:io:kafka:kafka-231:kafkaVersion231BatchIT") dependsOn (":sdks:java:io:kafka:kafka-241:kafkaVersion241BatchIT") dependsOn (":sdks:java:io:kafka:kafka-251:kafkaVersion251BatchIT") + dependsOn (":sdks:java:io:kafka:kafka-312:kafkaVersion312BatchIT") } static def createTestList(Map prefixMap, String suffix) { diff --git a/sdks/java/io/kafka/kafka-01103/build.gradle b/sdks/java/io/kafka/kafka-01103/build.gradle deleted file mode 100644 index 3a74bf04ef22..000000000000 --- a/sdks/java/io/kafka/kafka-01103/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -project.ext { - delimited="0.11.0.3" - undelimited="01103" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-100/build.gradle b/sdks/java/io/kafka/kafka-100/build.gradle deleted file mode 100644 index bd5fa67b1cfc..000000000000 --- a/sdks/java/io/kafka/kafka-100/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -project.ext { - delimited="1.0.0" - undelimited="100" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" diff --git a/sdks/java/io/kafka/kafka-111/build.gradle b/sdks/java/io/kafka/kafka-111/build.gradle deleted file mode 100644 index c2b0c8f82827..000000000000 --- a/sdks/java/io/kafka/kafka-111/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -project.ext { - delimited="1.1.1" - undelimited="111" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java new file mode 100644 index 000000000000..147a30dcdd1a --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import com.google.auto.value.AutoValue; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Stores and exports metrics for a batch of Kafka Client RPCs. */ +public interface KafkaMetrics { + + void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime); + + void updateKafkaMetrics(); + + /** No-op implementation of {@code KafkaResults}. */ + class NoOpKafkaMetrics implements KafkaMetrics { + private NoOpKafkaMetrics() {} + + @Override + public void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime) {} + + @Override + public void updateKafkaMetrics() {} + + private static NoOpKafkaMetrics singleton = new NoOpKafkaMetrics(); + + static NoOpKafkaMetrics getInstance() { + return singleton; + } + } + + /** + * Metrics of a batch of RPCs. Member variables are thread safe; however, this class does not have + * atomicity across member variables. + * + *
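A minimal usage sketch matching this contract (not part of the patch): reader threads record per-topic poll latencies, then a single thread exports once. The topic name and timings are made up, and the sketch assumes it sits in org.apache.beam.sdk.io.kafka so the package-private factory methods are visible.

```java
// Illustrative lifecycle of a KafkaMetrics container; topic and timings are made up.
package org.apache.beam.sdk.io.kafka;

import java.time.Duration;

class KafkaMetricsUsageSketch {
  void recordAndExport() {
    // Returns a real container only if KafkaSinkMetrics.setSupportKafkaMetrics(true)
    // was called; otherwise this is a no-op instance.
    KafkaMetrics metrics = KafkaSinkMetrics.kafkaMetrics();

    // Worker threads time their poll() calls and record the elapsed time per topic.
    metrics.updateSuccessfulRpcMetrics("orders", Duration.ofMillis(12));
    metrics.updateSuccessfulRpcMetrics("orders", Duration.ofMillis(7));

    // A single thread exports once; the container then refuses further writes.
    metrics.updateKafkaMetrics();
  }
}
```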

    Expected usage: A number of threads record metrics in an instance of this class with the + * member methods. Afterwards, a single thread should call {@code updateStreamingInsertsMetrics} + * which will export all counters metrics and RPC latency distribution metrics to the underlying + * {@code perWorkerMetrics} container. Afterwards, metrics should not be written/read from this + * object. + */ + @AutoValue + abstract class KafkaMetricsImpl implements KafkaMetrics { + + private static final Logger LOG = LoggerFactory.getLogger(KafkaMetricsImpl.class); + + static HashMap latencyHistograms = new HashMap(); + + abstract HashMap> perTopicRpcLatencies(); + + abstract AtomicBoolean isWritable(); + + public static KafkaMetricsImpl create() { + return new AutoValue_KafkaMetrics_KafkaMetricsImpl( + new HashMap>(), new AtomicBoolean(true)); + } + + /** Record the rpc status and latency of a successful Kafka poll RPC call. */ + @Override + public void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime) { + if (isWritable().get()) { + ConcurrentLinkedQueue latencies = perTopicRpcLatencies().get(topic); + if (latencies == null) { + latencies = new ConcurrentLinkedQueue(); + latencies.add(elapsedTime); + perTopicRpcLatencies().put(topic, latencies); + } else { + latencies.add(elapsedTime); + } + } + } + + /** Record rpc latency histogram metrics for all recorded topics. */ + private void recordRpcLatencyMetrics() { + for (Map.Entry> topicLatencies : + perTopicRpcLatencies().entrySet()) { + Histogram topicHistogram; + if (latencyHistograms.containsKey(topicLatencies.getKey())) { + topicHistogram = latencyHistograms.get(topicLatencies.getKey()); + } else { + topicHistogram = + KafkaSinkMetrics.createRPCLatencyHistogram( + KafkaSinkMetrics.RpcMethod.POLL, topicLatencies.getKey()); + latencyHistograms.put(topicLatencies.getKey(), topicHistogram); + } + // update all the latencies + for (Duration d : topicLatencies.getValue()) { + Preconditions.checkArgumentNotNull(topicHistogram); + topicHistogram.update(d.toMillis()); + } + } + } + + /** + * Export all metrics recorded in this instance to the underlying {@code perWorkerMetrics} + * containers. This function will only report metrics once per instance. Subsequent calls to + * this function will no-op. + */ + @Override + public void updateKafkaMetrics() { + if (!isWritable().compareAndSet(true, false)) { + LOG.warn("Updating stale Kafka metrics container"); + return; + } + recordRpcLatencyMetrics(); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java new file mode 100644 index 000000000000..f71926f97d27 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import org.apache.beam.sdk.metrics.DelegatingHistogram; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.LabeledMetricNameUtils; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.util.HistogramData; + +/** + * Helper class to create per worker metrics for Kafka Sink stages. + * + *

    Metrics will be in the namespace 'KafkaSink' and have their name formatted as: + * + *

'{baseName}*{metricLabelKey1}:{metricLabelVal1};...{metricLabelKeyN}:{metricLabelValN};' + */ + +// TODO, refactor out common parts for BQ sink, so it can be reused with other sinks, eg, GCS? +// @SuppressWarnings("unused") +public class KafkaSinkMetrics { + private static boolean supportKafkaMetrics = false; + + public static final String METRICS_NAMESPACE = "KafkaSink"; + + // Base Metric names + private static final String RPC_LATENCY = "RpcLatency"; + + // Kafka Consumer Method names + enum RpcMethod { + POLL, + } + + // Metric labels + private static final String TOPIC_LABEL = "topic_name"; + private static final String RPC_METHOD = "rpc_method"; + + /** + * Creates a Histogram metric to record RPC latency. The metric will have the name: + * + *

    'RpcLatency*rpc_method:{method};topic_name:{topic};' + * + * @param method Kafka method associated with this metric. + * @param topic Kafka topic associated with this metric. + * @return Histogram with exponential buckets with a sqrt(2) growth factor. + */ + public static Histogram createRPCLatencyHistogram(RpcMethod method, String topic) { + LabeledMetricNameUtils.MetricNameBuilder nameBuilder = + LabeledMetricNameUtils.MetricNameBuilder.baseNameBuilder(RPC_LATENCY); + nameBuilder.addLabel(RPC_METHOD, method.toString()); + nameBuilder.addLabel(TOPIC_LABEL, topic); + + MetricName metricName = nameBuilder.build(METRICS_NAMESPACE); + HistogramData.BucketType buckets = HistogramData.ExponentialBuckets.of(1, 17); + + return new DelegatingHistogram(metricName, buckets, false, true); + } + + /** + * Returns a container to store metrics for Kafka metrics in Unbounded Readed. If these metrics + * are disabled, then we return a no-op container. + */ + static KafkaMetrics kafkaMetrics() { + if (supportKafkaMetrics) { + return KafkaMetrics.KafkaMetricsImpl.create(); + } else { + return KafkaMetrics.NoOpKafkaMetrics.getInstance(); + } + } + + public static void setSupportKafkaMetrics(boolean supportKafkaMetrics) { + KafkaSinkMetrics.supportKafkaMetrics = supportKafkaMetrics; + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index fed03047cf16..d86a5d0ce686 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; @@ -137,14 +138,12 @@ public boolean start() throws IOException { name, spec.getOffsetConsumerConfig(), spec.getConsumerConfig()); offsetConsumer = spec.getConsumerFactoryFn().apply(offsetConsumerConfig); - ConsumerSpEL.evaluateAssign(offsetConsumer, topicPartitions); // Fetch offsets once before running periodically. updateLatestOffsets(); offsetFetcherThread.scheduleAtFixedRate( this::updateLatestOffsets, 0, OFFSET_UPDATE_INTERVAL_SECONDS, TimeUnit.SECONDS); - return advance(); } @@ -158,6 +157,9 @@ public boolean advance() throws IOException { */ while (true) { if (curBatch.hasNext()) { + // Initalize metrics container. + kafkaResults = KafkaSinkMetrics.kafkaMetrics(); + PartitionState pState = curBatch.next(); if (!pState.recordIter.hasNext()) { // -- (c) @@ -228,8 +230,10 @@ public boolean advance() throws IOException { for (Map.Entry backlogSplit : perPartitionBacklogMetrics.entrySet()) { backlogBytesOfSplit.set(backlogSplit.getValue()); } - return true; + // Pass metrics to container. 
+ kafkaResults.updateKafkaMetrics(); + return true; } else { // -- (b) nextBatch(); @@ -377,6 +381,7 @@ public long getSplitBacklogBytes() { .setDaemon(true) .setNameFormat("KafkaConsumerPoll-thread") .build()); + private AtomicReference consumerPollException = new AtomicReference<>(); private final SynchronousQueue> availableRecordsQueue = new SynchronousQueue<>(); @@ -399,6 +404,11 @@ public long getSplitBacklogBytes() { /** watermark before any records have been read. */ private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + // Created in each next batch, and updated at the end. + public KafkaMetrics kafkaResults = KafkaSinkMetrics.kafkaMetrics(); + private Stopwatch stopwatch = Stopwatch.createUnstarted(); + private String kafkaTopic = ""; + @Override public String toString() { return name; @@ -509,6 +519,13 @@ String name() { List partitions = Preconditions.checkArgumentNotNull(source.getSpec().getTopicPartitions()); + + // Each source has a single unique topic. + for (TopicPartition topicPartition : partitions) { + this.kafkaTopic = topicPartition.topic(); + break; + } + List> states = new ArrayList<>(partitions.size()); if (checkpointMark != null) { @@ -568,7 +585,16 @@ private void consumerPollLoop() { while (!closed.get()) { try { if (records.isEmpty()) { + // Each source has a single unique topic. + List topicPartitions = source.getSpec().getTopicPartitions(); + Preconditions.checkStateNotNull(topicPartitions); + + stopwatch.start(); records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); + stopwatch.stop(); + kafkaResults.updateSuccessfulRpcMetrics( + kafkaTopic, java.time.Duration.ofMillis(stopwatch.elapsed(TimeUnit.MILLISECONDS))); + } else if (availableRecordsQueue.offer( records, RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS)) { records = ConsumerRecords.empty(); @@ -592,7 +618,6 @@ private void consumerPollLoop() { private void commitCheckpointMark() { KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); - if (checkpointMark != null) { LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); Consumer consumer = Preconditions.checkStateNotNull(this.consumer); @@ -685,23 +710,28 @@ private void setupInitialOffset(PartitionState pState) { // Called from setupInitialOffset() at the start and then periodically from offsetFetcher thread. private void updateLatestOffsets() { Consumer offsetConsumer = Preconditions.checkStateNotNull(this.offsetConsumer); - for (PartitionState p : partitionStates) { - try { - Instant fetchTime = Instant.now(); - ConsumerSpEL.evaluateSeek2End(offsetConsumer, p.topicPartition); - long offset = offsetConsumer.position(p.topicPartition); - p.setLatestOffset(offset, fetchTime); - } catch (Exception e) { - if (closed.get()) { // Ignore the exception if the reader is closed. - break; - } + List topicPartitions = + Preconditions.checkStateNotNull(source.getSpec().getTopicPartitions()); + Instant fetchTime = Instant.now(); + try { + Map endOffsets = offsetConsumer.endOffsets(topicPartitions); + for (PartitionState p : partitionStates) { + p.setLatestOffset( + Preconditions.checkStateNotNull( + endOffsets.get(p.topicPartition), + "No end offset found for partition %s.", + p.topicPartition), + fetchTime); + } + } catch (Exception e) { + if (!closed.get()) { // Ignore the exception if the reader is closed. LOG.warn( - "{}: exception while fetching latest offset for partition {}. will be retried.", + "{}: exception while fetching latest offset for partitions {}. 
will be retried.", this, - p.topicPartition, + topicPartitions, e); - // Don't update the latest offset. } + // Don't update the latest offset. } LOG.debug("{}: backlog {}", this, getSplitBacklogBytes()); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 4bda8cf28d4e..4d7aa6b32aef 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -19,13 +19,14 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.TimeUnit; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg; @@ -247,19 +248,21 @@ private static class KafkaLatestOffsetEstimator private final Consumer offsetConsumer; private final TopicPartition topicPartition; private final Supplier memoizedBacklog; - private boolean closed; KafkaLatestOffsetEstimator( Consumer offsetConsumer, TopicPartition topicPartition) { this.offsetConsumer = offsetConsumer; this.topicPartition = topicPartition; - ConsumerSpEL.evaluateAssign(this.offsetConsumer, ImmutableList.of(this.topicPartition)); memoizedBacklog = Suppliers.memoizeWithExpiration( () -> { synchronized (offsetConsumer) { - ConsumerSpEL.evaluateSeek2End(offsetConsumer, topicPartition); - return offsetConsumer.position(topicPartition); + return Preconditions.checkStateNotNull( + offsetConsumer + .endOffsets(Collections.singleton(topicPartition)) + .get(topicPartition), + "No end offset found for partition %s.", + topicPartition); } }, 1, @@ -270,7 +273,6 @@ private static class KafkaLatestOffsetEstimator protected void finalize() { try { Closeables.close(offsetConsumer, true); - closed = true; LOG.info("Offset Estimator consumer was closed for {}", topicPartition); } catch (Exception anyException) { LOG.warn("Failed to close offset consumer for {}", topicPartition); @@ -281,10 +283,6 @@ protected void finalize() { public long estimate() { return memoizedBacklog.get(); } - - public boolean isClosed() { - return closed; - } } @GetInitialRestriction @@ -340,13 +338,18 @@ public WatermarkEstimator newWatermarkEstimator( public double getSize( @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange) throws Exception { + // If present, estimates the record size to offset gap ratio. Compacted topics may hold less + // records than the estimated offset range due to record deletion within a partition. final LoadingCache avgRecordSize = Preconditions.checkStateNotNull(this.avgRecordSize); - double numRecords = + // The tracker estimates the offset range by subtracting the last claimed position from the + // currently observed end offset for the partition belonging to this split. + double estimatedOffsetRange = restrictionTracker(kafkaSourceDescriptor, offsetRange).getProgress().getWorkRemaining(); // Before processing elements, we don't have a good estimated size of records and offset gap. 
+ // Return the estimated offset range without scaling by a size to gap ratio. if (!avgRecordSize.asMap().containsKey(kafkaSourceDescriptor.getTopicPartition())) { - return numRecords; + return estimatedOffsetRange; } if (offsetEstimatorCache != null) { for (Map.Entry tp : @@ -355,7 +358,12 @@ public double getSize( } } - return avgRecordSize.get(kafkaSourceDescriptor.getTopicPartition()).getTotalSize(numRecords); + // When processing elements, a moving average estimates the size of records and offset gap. + // Return the estimated offset range scaled by the estimated size to gap ratio. + return estimatedOffsetRange + * avgRecordSize + .get(kafkaSourceDescriptor.getTopicPartition()) + .estimateRecordByteSizeToOffsetCountRatio(); } @NewTracker @@ -373,7 +381,7 @@ public OffsetRangeTracker restrictionTracker( TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); KafkaLatestOffsetEstimator offsetEstimator = offsetEstimatorCacheInstance.get(topicPartition); - if (offsetEstimator == null || offsetEstimator.isClosed()) { + if (offsetEstimator == null) { Map updatedConsumerConfig = overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); @@ -453,7 +461,8 @@ public ProcessContinuation processElement( // and move to process the next element. if (rawRecords.isEmpty()) { if (!topicPartitionExists( - kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) { + kafkaSourceDescriptor.getTopicPartition(), + consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) { return ProcessContinuation.stop(); } if (timestampPolicy != null) { @@ -547,20 +556,10 @@ public ProcessContinuation processElement( } private boolean topicPartitionExists( - TopicPartition topicPartition, Map> topicListMap) { + TopicPartition topicPartition, List partitionInfos) { // Check if the current TopicPartition still exists. - Set existingTopicPartitions = new HashSet<>(); - for (List topicPartitionList : topicListMap.values()) { - topicPartitionList.forEach( - partitionInfo -> { - existingTopicPartitions.add( - new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); - }); - } - if (!existingTopicPartitions.contains(topicPartition)) { - return false; - } - return true; + return partitionInfos.stream() + .anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition())); } // see https://github.com/apache/beam/issues/25962 @@ -667,8 +666,15 @@ private Map overrideBootstrapServersConfig( return config; } + // TODO: Collapse the two moving average trackers into a single accumulator using a single Guava + // AtomicDouble. Note that this requires that a single thread will call update and that while get + // may be called by multiple threads the method must only load the accumulator itself. 
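A standalone sketch of the arithmetic these moving averages feed into (not part of the patch): getSize() scales the tracker's remaining offset range by avgRecordSize / (1 + avgRecordGap). The helper and variable names below are hypothetical.

```java
// Sketch of the backlog estimate in getSize(): remaining offsets are scaled by a
// bytes-per-offset ratio so compacted topics (offset gaps > 0) are not overcounted.
class BacklogEstimateSketch {
  static double estimateBacklogBytes(
      double remainingOffsets,    // tracker.getProgress().getWorkRemaining()
      double avgRecordSizeBytes,  // moving average of observed record sizes
      double avgOffsetGap) {      // moving average of offset gaps between records
    double bytesPerOffset = avgRecordSizeBytes / (1 + avgOffsetGap);
    return remainingOffsets * bytesPerOffset;
  }

  public static void main(String[] args) {
    // 1000 remaining offsets, ~500-byte records, average gap of 1 (heavily compacted):
    // only about half the offsets still hold a record, so ~250 KB rather than ~500 KB.
    System.out.println(estimateBacklogBytes(1000, 500, 1)); // 250000.0
  }
}
```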
+ @ThreadSafe private static class AverageRecordSize { + @GuardedBy("this") private MovingAvg avgRecordSize; + + @GuardedBy("this") private MovingAvg avgRecordGap; public AverageRecordSize() { @@ -676,13 +682,26 @@ public AverageRecordSize() { this.avgRecordGap = new MovingAvg(); } - public void update(int recordSize, long gap) { + public synchronized void update(int recordSize, long gap) { avgRecordSize.update(recordSize); avgRecordGap.update(gap); } - public double getTotalSize(double numRecords) { - return avgRecordSize.get() * numRecords / (1 + avgRecordGap.get()); + public double estimateRecordByteSizeToOffsetCountRatio() { + double avgRecordSize; + double avgRecordGap; + + synchronized (this) { + avgRecordSize = this.avgRecordSize.get(); + avgRecordGap = this.avgRecordGap.get(); + } + + // The offset increases between records in a batch fetched from a compacted topic may be + // greater than 1. Compacted topics only store records with the greatest offset per key per + // partition, the records in between are deleted and will not be observed by a consumer. + // The observed gap between offsets is used to estimate the number of records that are likely + // to be observed for the provided number of records. + return avgRecordSize / (1 + avgRecordGap); } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index 764e406f71cb..e614320db150 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -77,7 +77,7 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.io.kafka.KafkaIO.Read.FakeFlinkPipelineOptions; -import org.apache.beam.sdk.io.kafka.KafkaMocks.PositionErrorConsumerFactory; +import org.apache.beam.sdk.io.kafka.KafkaMocks.EndOffsetErrorConsumerFactory; import org.apache.beam.sdk.io.kafka.KafkaMocks.SendErrorProducerFactory; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.Lineage; @@ -267,10 +267,6 @@ private static MockConsumer mkMockConsumer( public synchronized void assign(final Collection assigned) { super.assign(assigned); assignedPartitions.set(ImmutableList.copyOf(assigned)); - for (TopicPartition tp : assigned) { - updateBeginningOffsets(ImmutableMap.of(tp, 0L)); - updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size())); - } } // Override offsetsForTimes() in order to look up the offsets by timestamp. @Override @@ -290,9 +286,12 @@ public synchronized Map offsetsForTimes( } }; - for (String topic : topics) { - consumer.updatePartitions(topic, partitionMap.get(topic)); - } + partitionMap.forEach(consumer::updatePartitions); + consumer.updateBeginningOffsets( + records.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L))); + consumer.updateEndOffsets( + records.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()))); // MockConsumer does not maintain any relationship between partition seek position and the // records added. e.g. 
if we add 10 records to a partition and then seek to end of the @@ -1525,13 +1524,14 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { List topics = ImmutableList.of("topic_a"); - PositionErrorConsumerFactory positionErrorConsumerFactory = new PositionErrorConsumerFactory(); + EndOffsetErrorConsumerFactory endOffsetErrorConsumerFactory = + new EndOffsetErrorConsumerFactory(); UnboundedSource, KafkaCheckpointMark> source = KafkaIO.read() .withBootstrapServers("myServer1:9092,myServer2:9092") .withTopics(topics) - .withConsumerFactoryFn(positionErrorConsumerFactory) + .withConsumerFactoryFn(endOffsetErrorConsumerFactory) .withKeyDeserializer(IntegerDeserializer.class) .withValueDeserializer(LongDeserializer.class) .makeSource(); @@ -1540,7 +1540,7 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { reader.start(); - unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partition"); + unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partitions"); reader.close(); } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java new file mode 100644 index 000000000000..b84e143be773 --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.runners.core.metrics.MetricsContainerImpl; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.util.HistogramData; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link KafkaSinkMetrics}. 
*/ +// TODO:Naireen - Refactor to remove duplicate code between the two sinks +@RunWith(JUnit4.class) +public class KafkaMetricsTest { + public static class TestHistogram implements Histogram { + public List values = Lists.newArrayList(); + private MetricName metricName = MetricName.named("KafkaSink", "name"); + + @Override + public void update(double value) { + values.add(value); + } + + @Override + public MetricName getName() { + return metricName; + } + } + + public static class TestMetricsContainer extends MetricsContainerImpl { + public ConcurrentHashMap, TestHistogram> + perWorkerHistograms = + new ConcurrentHashMap, TestHistogram>(); + + public TestMetricsContainer() { + super("TestStep"); + } + + @Override + public Histogram getPerWorkerHistogram( + MetricName metricName, HistogramData.BucketType bucketType) { + perWorkerHistograms.computeIfAbsent(KV.of(metricName, bucketType), kv -> new TestHistogram()); + return perWorkerHistograms.get(KV.of(metricName, bucketType)); + } + + @Override + public void reset() { + perWorkerHistograms.clear(); + } + } + + @Test + public void testNoOpKafkaMetrics() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaMetrics results = KafkaMetrics.NoOpKafkaMetrics.getInstance(); + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + + assertThat(testContainer.perWorkerHistograms.size(), equalTo(0)); + } + + @Test + public void testKafkaRPCLatencyMetrics() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaSinkMetrics.setSupportKafkaMetrics(true); + + KafkaMetrics results = KafkaSinkMetrics.kafkaMetrics(); + + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + // RpcLatency*rpc_method:POLL;topic_name:test-topic + MetricName histogramName = + MetricName.named("KafkaSink", "RpcLatency*rpc_method:POLL;topic_name:test-topic;"); + HistogramData.BucketType bucketType = HistogramData.ExponentialBuckets.of(1, 17); + + assertThat(testContainer.perWorkerHistograms.size(), equalTo(1)); + assertThat( + testContainer.perWorkerHistograms.get(KV.of(histogramName, bucketType)).values, + containsInAnyOrder(Double.valueOf(10.0))); + } + + @Test + public void testKafkaRPCLatencyMetricsAreNotRecorded() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaSinkMetrics.setSupportKafkaMetrics(false); + + KafkaMetrics results = KafkaSinkMetrics.kafkaMetrics(); + + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + assertThat(testContainer.perWorkerHistograms.size(), equalTo(0)); + } +} diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java index 0844d71e7105..1303f1da3bcd 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.kafka; import java.io.Serializable; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -27,8 +28,8 @@ import org.apache.beam.sdk.values.KV; import 
org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; -import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; import org.apache.kafka.clients.producer.Callback; import org.apache.kafka.clients.producer.MockProducer; import org.apache.kafka.clients.producer.Producer; @@ -66,51 +67,33 @@ public Producer apply(Map input) { } } - public static final class PositionErrorConsumer extends MockConsumer { - - public PositionErrorConsumer() { - super(null); - } - - @Override - public synchronized long position(TopicPartition partition) { - throw new KafkaException("fakeException"); - } - - @Override - public synchronized List partitionsFor(String topic) { - return Collections.singletonList( - new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null)); - } - } - - public static final class PositionErrorConsumerFactory + public static final class EndOffsetErrorConsumerFactory implements SerializableFunction, Consumer> { - public PositionErrorConsumerFactory() {} + public EndOffsetErrorConsumerFactory() {} @Override public MockConsumer apply(Map input) { + final MockConsumer consumer; if (input.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - return new PositionErrorConsumer(); - } else { - MockConsumer consumer = - new MockConsumer(null) { + consumer = + new MockConsumer(OffsetResetStrategy.EARLIEST) { @Override - public synchronized long position(TopicPartition partition) { - return 1L; - } - - @Override - public synchronized ConsumerRecords poll(long timeout) { - return ConsumerRecords.empty(); + public synchronized Map endOffsets( + Collection partitions) { + throw new KafkaException("fakeException"); } }; - consumer.updatePartitions( - "topic_a", - Collections.singletonList( - new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null))); - return consumer; + } else { + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST); } + consumer.updatePartitions( + "topic_a", + Collections.singletonList( + new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null))); + consumer.updateBeginningOffsets( + Collections.singletonMap(new TopicPartition("topic_a", 1), 0L)); + consumer.updateEndOffsets(Collections.singletonMap(new TopicPartition("topic_a", 1), 0L)); + return consumer; } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java new file mode 100644 index 000000000000..625a75c5316b --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link KafkaSinkMetrics}. */ +// TODO:Naireen - Refactor to remove duplicate code between the Kafka and BigQuery sinks +@RunWith(JUnit4.class) +public class KafkaSinkMetricsTest { + @Test + public void testCreatingHistogram() throws Exception { + + Histogram histogram = + KafkaSinkMetrics.createRPCLatencyHistogram(KafkaSinkMetrics.RpcMethod.POLL, "topic1"); + + MetricName histogramName = + MetricName.named("KafkaSink", "RpcLatency*rpc_method:POLL;topic_name:topic1;"); + assertThat(histogram.getName(), equalTo(histogramName)); + } +} diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index 3189bbb140f0..cbff0f896619 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -205,6 +205,8 @@ public SimpleMockKafkaConsumer( OffsetResetStrategy offsetResetStrategy, TopicPartition topicPartition) { super(offsetResetStrategy); this.topicPartition = topicPartition; + updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L)); + updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE)); } public void reset() { @@ -214,6 +216,8 @@ public void reset() { this.startOffsetForTime = KV.of(0L, Instant.now()); this.stopOffsetForTime = KV.of(Long.MAX_VALUE, null); this.numOfRecordsPerPoll = 0L; + updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L)); + updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE)); } public void setRemoved() { @@ -248,6 +252,17 @@ public synchronized Map> listTopics() { topicPartition.topic(), topicPartition.partition(), null, null, null))); } + @Override + public synchronized List partitionsFor(String partition) { + if (this.isRemoved) { + return ImmutableList.of(); + } else { + return ImmutableList.of( + new PartitionInfo( + topicPartition.topic(), topicPartition.partition(), null, null, null)); + } + } + @Override public synchronized void assign(Collection partitions) { assertTrue(Iterables.getOnlyElement(partitions).equals(this.topicPartition)); diff --git a/sdks/java/io/parquet/build.gradle b/sdks/java/io/parquet/build.gradle index e8f1603f0b58..d5f22b31cc56 100644 --- a/sdks/java/io/parquet/build.gradle +++ b/sdks/java/io/parquet/build.gradle @@ -27,10 +27,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Parquet" ext.summary = "IO to read and write on Parquet storage format." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/snowflake/build.gradle b/sdks/java/io/snowflake/build.gradle index 2bdb9a867a34..9be257033edb 100644 --- a/sdks/java/io/snowflake/build.gradle +++ b/sdks/java/io/snowflake/build.gradle @@ -30,7 +30,7 @@ dependencies { implementation project(path: ":sdks:java:extensions:google-cloud-platform-core") permitUnusedDeclared project(path: ":sdks:java:extensions:google-cloud-platform-core") implementation library.java.slf4j_api - implementation group: 'net.snowflake', name: 'snowflake-jdbc', version: '3.12.11' + implementation group: 'net.snowflake', name: 'snowflake-jdbc', version: '3.20.0' implementation group: 'com.opencsv', name: 'opencsv', version: '5.0' implementation 'net.snowflake:snowflake-ingest-sdk:0.9.9' implementation "org.bouncycastle:bcprov-jdk15on:1.70" diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 73b3709da832..3094ea47d6ad 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -202,10 +202,10 @@ private Schema.Field beamField(FieldMetaData fieldDescriptor) { @SuppressWarnings("rawtypes") @Override - public @NonNull List fieldValueGetters( - @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { + public @NonNull List> fieldValueGetters( + @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { return schemaFieldDescriptors(targetTypeDescriptor.getRawType(), schema).keySet().stream() - .map(FieldExtractor::new) + .>map(FieldExtractor::new) .collect(Collectors.toList()); } @@ -242,11 +242,12 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return FieldValueTypeInformation.forSetter(factoryMethods.get(0), "", Collections.emptyMap()); + return FieldValueTypeInformation.forSetter( + TypeDescriptor.of(type), factoryMethods.get(0), ""); } else { try { return FieldValueTypeInformation.forField( - type.getDeclaredField(fieldName), 0, Collections.emptyMap()); + TypeDescriptor.of(type), type.getDeclaredField(fieldName), 0); } catch (NoSuchFieldException e) { throw new IllegalArgumentException(e); } @@ -374,7 +375,7 @@ private & TEnum> FieldType beamType(FieldValueMetaDat } } - private static class FieldExtractor> + private static class FieldExtractor implements FieldValueGetter { private final FieldT field; @@ -384,8 +385,9 @@ private FieldExtractor(FieldT field) { @Override public @Nullable Object get(T thrift) { - if (!(thrift instanceof TUnion) || thrift.isSet(field)) { - final Object value = thrift.getFieldValue(field); + TBase t = (TBase) thrift; + if (!(thrift instanceof TUnion) || t.isSet(field)) { + final Object value = t.getFieldValue(field); if (value instanceof Enum) { return ((Enum) value).ordinal(); } else { diff --git a/sdks/java/javadoc/build.gradle b/sdks/java/javadoc/build.gradle index c0622b173043..284cef130bd3 100644 --- a/sdks/java/javadoc/build.gradle +++ b/sdks/java/javadoc/build.gradle @@ -62,7 +62,7 @@ task aggregateJavadoc(type: Javadoc) { source exportedJavadocProjects.collect { 
project(it).sourceSets.main.allJava } classpath = files(exportedJavadocProjects.collect { project(it).sourceSets.main.compileClasspath }) destinationDir = file("${buildDir}/docs/javadoc") - failOnError = true + failOnError = false exclude "org/apache/beam/examples/*" exclude "org/apache/beam/fn/harness/*" diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index 8477726686ee..8e7e0862eff4 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -86,17 +86,20 @@ public class Managed { // TODO: Dynamically generate a list of supported transforms public static final String ICEBERG = "iceberg"; public static final String KAFKA = "kafka"; + public static final String BIGQUERY = "bigquery"; // Supported SchemaTransforms public static final Map READ_TRANSFORMS = ImmutableMap.builder() .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_READ)) .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ)) + .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ)) .build(); public static final Map WRITE_TRANSFORMS = ImmutableMap.builder() .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_WRITE)) .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE)) + .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE)) .build(); /** @@ -104,7 +107,9 @@ public class Managed { * supported managed sources are: * *

   * <ul>
-  *   <li>{@link Managed#ICEBERG} : Read from Apache Iceberg
+  *   <li>{@link Managed#ICEBERG} : Read from Apache Iceberg tables
+  *   <li>{@link Managed#KAFKA} : Read from Apache Kafka topics
+  *   <li>{@link Managed#BIGQUERY} : Read from GCP BigQuery tables
   * </ul>
    */ public static ManagedTransform read(String source) { @@ -124,7 +129,9 @@ public static ManagedTransform read(String source) { * managed sinks are: * *
   * <ul>
-  *   <li>{@link Managed#ICEBERG} : Write to Apache Iceberg
+  *   <li>{@link Managed#ICEBERG} : Write to Apache Iceberg tables
+  *   <li>{@link Managed#KAFKA} : Write to Apache Kafka topics
+  *   <li>{@link Managed#BIGQUERY} : Write to GCP BigQuery tables
   * </ul>
    */ public static ManagedTransform write(String sink) { diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java index 6f97983d3260..b705306b9478 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java @@ -117,7 +117,7 @@ protected void validate() { "Please specify a config or a config URL, but not both."); } - public @Nullable String resolveUnderlyingConfig() { + private Map resolveUnderlyingConfig() { String yamlTransformConfig = getConfig(); // If YAML string is empty, then attempt to read from YAML file if (Strings.isNullOrEmpty(yamlTransformConfig)) { @@ -131,7 +131,8 @@ protected void validate() { throw new RuntimeException(e); } } - return yamlTransformConfig; + + return YamlUtils.yamlStringToMap(yamlTransformConfig); } } @@ -152,34 +153,34 @@ protected SchemaTransform from(ManagedConfig managedConfig) { static class ManagedSchemaTransform extends SchemaTransform { private final ManagedConfig managedConfig; - private final Row underlyingTransformConfig; + private final Row underlyingRowConfig; private final SchemaTransformProvider underlyingTransformProvider; ManagedSchemaTransform( ManagedConfig managedConfig, SchemaTransformProvider underlyingTransformProvider) { // parse config before expansion to check if it matches underlying transform's config schema Schema transformConfigSchema = underlyingTransformProvider.configurationSchema(); - Row underlyingTransformConfig; + Row underlyingRowConfig; try { - underlyingTransformConfig = getRowConfig(managedConfig, transformConfigSchema); + underlyingRowConfig = getRowConfig(managedConfig, transformConfigSchema); } catch (Exception e) { throw new IllegalArgumentException( "Encountered an error when retrieving a Row configuration", e); } - this.managedConfig = managedConfig; - this.underlyingTransformConfig = underlyingTransformConfig; + this.underlyingRowConfig = underlyingRowConfig; this.underlyingTransformProvider = underlyingTransformProvider; + this.managedConfig = managedConfig; } @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { LOG.debug( - "Building transform \"{}\" with Row configuration: {}", + "Building transform \"{}\" with configuration: {}", underlyingTransformProvider.identifier(), - underlyingTransformConfig); + underlyingRowConfig); - return input.apply(underlyingTransformProvider.from(underlyingTransformConfig)); + return input.apply(underlyingTransformProvider.from(underlyingRowConfig)); } public ManagedConfig getManagedConfig() { @@ -201,16 +202,14 @@ Row getConfigurationRow() { } } + // May return an empty row (perhaps the underlying transform doesn't have any required + // parameters) @VisibleForTesting static Row getRowConfig(ManagedConfig config, Schema transformSchema) { - // May return an empty row (perhaps the underlying transform doesn't have any required - // parameters) - String yamlConfig = config.resolveUnderlyingConfig(); - Map configMap = YamlUtils.yamlStringToMap(yamlConfig); - - // The config Row object will be used to build the underlying SchemaTransform. 
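A sketch of how the new BIGQUERY entries and the mappings added below could fit together (not part of the patch): user-facing keys such as "table" and "fields" are remapped to the underlying transform's "table_spec" and "selected_fields". The table and column names are hypothetical, and the builder calls assume the existing Managed.ManagedTransform#withConfig(Map) and PCollectionRowTuple#getSinglePCollection() APIs.

```java
// Illustrative Managed BigQuery read; table and column names are hypothetical.
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;

class ManagedBigQueryReadSketch {
  static PCollection<Row> readUsers(Pipeline pipeline) {
    // User-facing keys; the BIGQUERY_READ_MAPPINGS below remap "table" -> "table_spec"
    // and "fields" -> "selected_fields" before the underlying provider sees them.
    Map<String, Object> config = new LinkedHashMap<>();
    config.put("table", "my-project:my_dataset.users");
    config.put("fields", Arrays.asList("id", "name"));

    return PCollectionRowTuple.empty(pipeline)
        .apply(Managed.read(Managed.BIGQUERY).withConfig(config))
        .getSinglePCollection();
  }
}
```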
- // If a mapping for the SchemaTransform exists, we use it to update parameter names and align - // with the underlying config schema + Map configMap = config.resolveUnderlyingConfig(); + // Build a config Row that will be used to build the underlying SchemaTransform. + // If a mapping for the SchemaTransform exists, we use it to update parameter names to align + // with the underlying SchemaTransform config schema Map mapping = MAPPINGS.get(config.getTransformIdentifier()); if (mapping != null && configMap != null) { Map remappedConfig = new HashMap<>(); @@ -227,7 +226,7 @@ static Row getRowConfig(ManagedConfig config, Schema transformSchema) { return YamlUtils.toBeamRow(configMap, transformSchema, false); } - // We load providers seperately, after construction, to prevent the + // We load providers separately, after construction, to prevent the // 'ManagedSchemaTransformProvider' from being initialized in a recursive loop // when being loaded using 'AutoValue'. synchronized Map getAllProviders() { diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java index 4cf752747be5..30476a30d373 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java @@ -50,9 +50,27 @@ public class ManagedTransformConstants { private static final Map KAFKA_WRITE_MAPPINGS = ImmutableMap.builder().put("data_format", "format").build(); + private static final Map BIGQUERY_READ_MAPPINGS = + ImmutableMap.builder() + .put("table", "table_spec") + .put("fields", "selected_fields") + .build(); + + private static final Map BIGQUERY_WRITE_MAPPINGS = + ImmutableMap.builder() + .put("at_least_once", "use_at_least_once_semantics") + .put("triggering_frequency", "triggering_frequency_seconds") + .build(); + public static final Map> MAPPINGS = ImmutableMap.>builder() .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ), KAFKA_READ_MAPPINGS) .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE), KAFKA_WRITE_MAPPINGS) + .put( + getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ), + BIGQUERY_READ_MAPPINGS) + .put( + getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE), + BIGQUERY_WRITE_MAPPINGS) .build(); } diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java index e9edf8751e34..a287ec6260ce 100644 --- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java +++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java @@ -88,8 +88,7 @@ public void testGetConfigRowFromYamlFile() throws URISyntaxException { .withFieldValue("extra_integer", 123) .build(); Row configRow = - ManagedSchemaTransformProvider.getRowConfig( - config, new TestSchemaTransformProvider().configurationSchema()); + ManagedSchemaTransformProvider.getRowConfig(config, TestSchemaTransformProvider.SCHEMA); assertEquals(expectedRow, configRow); } diff --git a/sdks/python/apache_beam/__init__.py b/sdks/python/apache_beam/__init__.py index 6e08083bc0de..af88934b0e71 100644 --- a/sdks/python/apache_beam/__init__.py +++ b/sdks/python/apache_beam/__init__.py @@ -70,17 +70,11 @@ import warnings if 
sys.version_info.major == 3: - if sys.version_info.minor <= 7 or sys.version_info.minor >= 13: + if sys.version_info.minor <= 8 or sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' % (sys.version_info.major, sys.version_info.minor)) - elif sys.version_info.minor == 8: - warnings.warn( - 'Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') pass else: raise RuntimeError( diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index ff5fb5bef7ac..dfdb247d781d 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -1975,7 +1975,7 @@ class DecimalCoderImpl(StreamCoderImpl): def encode_to_stream(self, value, out, nested): # type: (decimal.Decimal, create_OutputStream, bool) -> None - scale = -value.as_tuple().exponent + scale = -value.as_tuple().exponent # type: ignore[operator] int_value = int(value.scaleb(scale)) out.write_var_int64(scale) self.BIG_INT_CODER_IMPL.encode_to_stream(int_value, out, nested) diff --git a/sdks/python/apache_beam/coders/coders_property_based_test.py b/sdks/python/apache_beam/coders/coders_property_based_test.py index be18dd3586b0..9279fc31c099 100644 --- a/sdks/python/apache_beam/coders/coders_property_based_test.py +++ b/sdks/python/apache_beam/coders/coders_property_based_test.py @@ -141,7 +141,7 @@ def test_row_coder(self, data: st.DataObject): coders_registry.register_coder(RowType, RowCoder) # TODO(https://github.com/apache/beam/issues/23002): Apply nulls for these - row = RowType( # type: ignore + row = RowType( **{ name: data.draw(SCHEMA_TYPES_TO_STRATEGY[type_]) for name, diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 7dcfae83f10e..4bd9698dd57b 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -53,7 +53,7 @@ except ImportError: dataclasses = None # type: ignore -MyNamedTuple = collections.namedtuple('A', ['x', 'y']) +MyNamedTuple = collections.namedtuple('A', ['x', 'y']) # type: ignore[name-match] MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)]) diff --git a/sdks/python/apache_beam/examples/complete/estimate_pi.py b/sdks/python/apache_beam/examples/complete/estimate_pi.py index 089767d2a99e..530a270308d9 100644 --- a/sdks/python/apache_beam/examples/complete/estimate_pi.py +++ b/sdks/python/apache_beam/examples/complete/estimate_pi.py @@ -30,9 +30,8 @@ import json import logging import random +from collections.abc import Iterable from typing import Any -from typing import Iterable -from typing import Tuple import apache_beam as beam from apache_beam.io import WriteToText @@ -40,7 +39,7 @@ from apache_beam.options.pipeline_options import SetupOptions -@beam.typehints.with_output_types(Tuple[int, int, int]) +@beam.typehints.with_output_types(tuple[int, int, int]) @beam.typehints.with_input_types(int) def run_trials(runs): """Run trials and return a 3-tuple representing the results. 
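
The pattern repeated throughout these example changes is the move from deprecated typing aliases (Tuple, Dict, List, Iterable, ...) to PEP 585 built-in generics and collections.abc protocols, which lines up with the Python 3.8 support dropped in apache_beam/__init__.py above: evaluated at runtime, these hints need Python 3.9 or newer. A minimal sketch of the before/after style, with hypothetical function names and placeholder bodies rather than the real estimate_pi logic:

# Illustrative sketch, not part of the patch: the typing migration shown above,
# assuming Python 3.9+ so built-in and collections.abc generics work at runtime.
from collections.abc import Iterable  # replaces: from typing import Iterable
from typing import Any  # Any still comes from typing; it has no builtin form

import apache_beam as beam


@beam.typehints.with_output_types(tuple[int, int, int])  # was typing.Tuple[...]
@beam.typehints.with_input_types(int)
def run_trials_sketch(runs):
  # Placeholder body; the real example samples random points to estimate pi.
  return runs, runs, 0


@beam.typehints.with_output_types(tuple[int, int, float])
@beam.typehints.with_input_types(Iterable[tuple[int, int, Any]])
def combine_results_sketch(results):
  # Sums per-bundle counts and derives the Monte Carlo estimate of pi.
  rows = list(results)  # materialize so the iterable can be read twice
  total = sum(r[0] for r in rows)
  inside = sum(r[1] for r in rows)
  return total, inside, 4.0 * inside / total if total else 0.0
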
@@ -62,8 +61,8 @@ def run_trials(runs): return runs, inside_runs, 0 -@beam.typehints.with_output_types(Tuple[int, int, float]) -@beam.typehints.with_input_types(Iterable[Tuple[int, int, Any]]) +@beam.typehints.with_output_types(tuple[int, int, float]) +@beam.typehints.with_input_types(Iterable[tuple[int, int, Any]]) def combine_results(results): """Combiner function to sum up trials and compute the estimate. diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index 6b5573aa4569..0a8c55d17d3a 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -25,7 +25,6 @@ import unittest import uuid from typing import TYPE_CHECKING -from typing import List import pytest import pytz @@ -53,7 +52,7 @@ if TYPE_CHECKING: import google.cloud.bigtable.instance -EXISTING_INSTANCES: List['google.cloud.bigtable.instance.Instance'] = [] +EXISTING_INSTANCES: list['google.cloud.bigtable.instance.Instance'] = [] LABEL_KEY = 'python-bigtable-beam' label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC) label_stamp_micros = _microseconds_from_datetime(label_stamp) diff --git a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py index 65ea7990a2d8..9d71ac32aff2 100644 --- a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py +++ b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py @@ -59,7 +59,7 @@ import logging import re import sys -from typing import Iterable +from collections.abc import Iterable from typing import Optional from typing import Text import uuid diff --git a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py index 3ce7836b491a..8a959138d3da 100644 --- a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py +++ b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py @@ -30,7 +30,6 @@ import argparse import logging import sys -import typing import apache_beam as beam from apache_beam import coders @@ -71,7 +70,7 @@ def is_deterministic(self): # Annotate the get_players function so that the typehint system knows that the # input to the CombinePerKey operation is a key-value pair of a Player object # and an integer. 
-@with_output_types(typing.Tuple[Player, int]) +@with_output_types(tuple[Player, int]) def get_players(descriptor): name, points = descriptor.split(',') return Player(name), int(points) diff --git a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py index 5eb57c8fc080..69c2eacc593d 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py @@ -27,10 +27,8 @@ import argparse import logging -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import apache_beam as beam import torch @@ -45,14 +43,14 @@ from transformers import AutoTokenizer -def add_mask_to_last_word(text: str) -> Tuple[str, str]: +def add_mask_to_last_word(text: str) -> tuple[str, str]: text_list = text.split() return text, ' '.join(text_list[:-2] + ['', text_list[-1]]) def tokenize_sentence( - text_and_mask: Tuple[str, str], - tokenizer: AutoTokenizer) -> Tuple[str, Dict[str, torch.Tensor]]: + text_and_mask: tuple[str, str], + tokenizer: AutoTokenizer) -> tuple[str, dict[str, torch.Tensor]]: text, masked_text = text_and_mask tokenized_sentence = tokenizer.encode_plus(masked_text, return_tensors="pt") @@ -81,7 +79,7 @@ def __init__(self, tokenizer: AutoTokenizer): super().__init__() self.tokenizer = tokenizer - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction_result = element inputs = prediction_result.example logits = prediction_result.inference['logits'] diff --git a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py index 9005ea5d11d7..7d4899cc38d9 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py @@ -28,8 +28,7 @@ import argparse import logging -from typing import Iterable -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import KeyedModelHandler @@ -49,7 +48,7 @@ class PostProcessor(beam.DoFn): Hugging Face Pipeline for Question Answering returns a dictionary with score, start and end index of answer and the answer. 
""" - def process(self, result: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, result: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction = result predicted_answer = prediction.inference['answer'] yield text + ';' + predicted_answer diff --git a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py index 18f697f673bf..0e62ab865431 100644 --- a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py +++ b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py @@ -28,9 +28,8 @@ import argparse import logging -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import numpy as np @@ -47,7 +46,7 @@ def tokenize_sentence(text: str, - tokenizer: RobertaTokenizer) -> Tuple[str, torch.Tensor]: + tokenizer: RobertaTokenizer) -> tuple[str, torch.Tensor]: tokenized_sentence = tokenizer.encode(text, add_special_tokens=True) # Workaround to manually remove batch dim until we have the feature to @@ -63,7 +62,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = np.argmax(prediction_result.inference, axis=0) yield filename + ';' + str(prediction) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py index d627001bcb82..c24a6d0a910e 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py @@ -21,9 +21,8 @@ import io import logging import os -from typing import Iterator +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -41,7 +40,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -122,13 +121,13 @@ def run( model_class = models.mobilenet_v2 model_params = {'num_classes': 1000} - def preprocess(image_name: str) -> Tuple[str, torch.Tensor]: + def preprocess(image_name: str) -> tuple[str, torch.Tensor]: image_name, image = read_image( image_file_name=image_name, path_to_dir=known_args.images_dir) return (image_name, preprocess_image(image)) - def postprocess(element: Tuple[str, PredictionResult]) -> str: + def postprocess(element: tuple[str, PredictionResult]) -> str: filename, prediction_result = element prediction = torch.argmax(prediction_result.inference, dim=0) return filename + ',' + str(prediction.item()) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py index 2a4e6e9a9bc6..787341263fde 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py +++ 
b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py @@ -62,10 +62,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -84,7 +83,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -116,7 +115,7 @@ class PostProcessor(beam.DoFn): Return filename, prediction and the model id used to perform the prediction """ - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = torch.argmax(prediction_result.inference, dim=0) yield filename, prediction, prediction_result.model_id diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py index cdecb826d6e3..5e5f77a679c3 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py @@ -21,10 +21,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -138,7 +137,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -161,7 +160,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction_labels = prediction_result.inference['labels'] classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py index 9de10e73e11b..a616998d2c73 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py @@ -26,10 +26,8 @@ import argparse import logging -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import apache_beam as beam import torch @@ -45,14 +43,14 @@ from transformers import BertTokenizer -def add_mask_to_last_word(text: str) -> Tuple[str, str]: +def add_mask_to_last_word(text: str) -> tuple[str, str]: text_list = text.split() return text, ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]]) def tokenize_sentence( - text_and_mask: Tuple[str, str], - bert_tokenizer: BertTokenizer) -> Tuple[str, 
Dict[str, torch.Tensor]]: + text_and_mask: tuple[str, str], + bert_tokenizer: BertTokenizer) -> tuple[str, dict[str, torch.Tensor]]: text, masked_text = text_and_mask tokenized_sentence = bert_tokenizer.encode_plus( masked_text, return_tensors="pt") @@ -84,7 +82,7 @@ def __init__(self, bert_tokenizer: BertTokenizer): super().__init__() self.bert_tokenizer = bert_tokenizer - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction_result = element inputs = prediction_result.example logits = prediction_result.inference['logits'] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py index f0b5462d5335..18c4c3e653b4 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py @@ -24,10 +24,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -143,7 +142,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -168,15 +167,15 @@ def filter_empty_lines(text: str) -> Iterator[str]: class KeyExamplesForEachModelType(beam.DoFn): """Duplicate data to run against each model type""" def process( - self, element: Tuple[torch.Tensor, - str]) -> Iterable[Tuple[str, torch.Tensor]]: + self, element: tuple[torch.Tensor, + str]) -> Iterable[tuple[str, torch.Tensor]]: yield 'v1', element[0] yield 'v2', element[0] class PostProcessor(beam.DoFn): def process( - self, element: Tuple[str, PredictionResult]) -> Tuple[torch.Tensor, str]: + self, element: tuple[str, PredictionResult]) -> tuple[torch.Tensor, str]: model, prediction_result = element prediction_labels = prediction_result.inference['labels'] classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels] diff --git a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py index a6e4dc2bdb03..755eff17c163 100644 --- a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py +++ b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py @@ -22,9 +22,9 @@ import argparse import logging import time -from typing import Iterable +from collections.abc import Iterable +from collections.abc import Sequence from typing import Optional -from typing import Sequence import apache_beam as beam from apache_beam.ml.inference import base diff --git a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py index 3aa2f362fa64..0a527e88dec2 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py @@ -31,7 +31,7 @@ import argparse import os -from typing import Iterable 
+from collections.abc import Iterable import pandas diff --git a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py index 5392cdf7ddae..d7d08e294e9d 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py @@ -27,9 +27,7 @@ import argparse import logging import os -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import KeyedModelHandler @@ -42,7 +40,7 @@ from apache_beam.runners.runner import PipelineResult -def process_input(row: str) -> Tuple[int, List[int]]: +def process_input(row: str) -> tuple[int, list[int]]: data = row.split(',') label, pixels = int(data[0]), data[1:] pixels = [int(pixel) for pixel in pixels] @@ -53,7 +51,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference yield '{},{}'.format(label, prediction) diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py index a0f249dcfbf0..b44d775f4ad3 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py @@ -17,8 +17,8 @@ import argparse import logging -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator import numpy diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py index 6cf746e77cd2..bf85bb1aef16 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py @@ -17,8 +17,7 @@ import argparse import logging -from typing import Iterable -from typing import Tuple +from collections.abc import Iterable import numpy @@ -33,7 +32,7 @@ from apache_beam.runners.runner import PipelineResult -def process_input(row: str) -> Tuple[int, numpy.ndarray]: +def process_input(row: str) -> tuple[int, numpy.ndarray]: data = row.split(',') label, pixels = int(data[0]), data[1:] pixels = [int(pixel) for pixel in pixels] @@ -46,7 +45,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. 
""" - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = numpy.argmax(prediction_result.inference, axis=0) yield '{},{}'.format(label, prediction) diff --git a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py index 1faf502c71af..677d36b9b767 100644 --- a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py +++ b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py @@ -22,9 +22,8 @@ import argparse import io import os -from typing import Iterable +from collections.abc import Iterable from typing import Optional -from typing import Tuple import numpy as np @@ -134,14 +133,14 @@ def attach_im_size_to_key( - data: Tuple[str, Image.Image]) -> Tuple[Tuple[str, int, int], Image.Image]: + data: tuple[str, Image.Image]) -> tuple[tuple[str, int, int], Image.Image]: filename, image = data width, height = image.size return ((filename, width, height), image) def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -168,7 +167,7 @@ class PostProcessor(beam.DoFn): an integer that we can transform into actual string class using COCO_OBJ_DET_CLASSES as reference. """ - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: key, prediction_result = element filename, im_width, im_height = key num_detections = prediction_result.inference[0] diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py index 09a70caa4ede..5df0b51e36d7 100644 --- a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py @@ -32,10 +32,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import tensorflow as tf @@ -60,7 +59,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: def read_and_process_image( image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, tf.Tensor]: + path_to_dir: Optional[str] = None) -> tuple[str, tf.Tensor]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -97,7 +96,7 @@ def convert_image_to_example_proto(tensor: tf.Tensor) -> tf.train.Example: class ProcessInferenceToString(beam.DoFn): def process( - self, element: Tuple[str, + self, element: tuple[str, prediction_log_pb2.PredictionLog]) -> Iterable[str]: """ Args: diff --git a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py index 73126569e988..20312e7d3c88 100644 --- a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py +++ 
b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py @@ -27,9 +27,7 @@ import argparse import io import logging -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam import tensorflow as tf @@ -102,13 +100,13 @@ def parse_known_args(argv): COLUMNS = ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses'] -def read_image(image_file_name: str) -> Tuple[str, bytes]: +def read_image(image_file_name: str) -> tuple[str, bytes]: with FileSystems().open(image_file_name, 'r') as file: data = io.BytesIO(file.read()).getvalue() return image_file_name, data -def preprocess_image(data: bytes) -> List[float]: +def preprocess_image(data: bytes) -> list[float]: """Preprocess the image, resizing it and normalizing it before converting to a list. """ @@ -119,7 +117,7 @@ def preprocess_image(data: bytes) -> List[float]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: img_name, prediction_result = element prediction_vals = prediction_result.inference index = prediction_vals.index(max(prediction_vals)) diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 3cf7d04cb03e..2708c0f3d1a1 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -25,7 +25,7 @@ import argparse import logging -from typing import Iterable +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import PredictionResult diff --git a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py index 963187fd210d..498511a5a2cf 100644 --- a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py +++ b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py @@ -17,10 +17,8 @@ import argparse import logging -from typing import Callable -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Callable +from collections.abc import Iterable from typing import Union import numpy @@ -48,7 +46,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference yield '{},{}'.format(label, prediction) @@ -89,7 +87,7 @@ def parse_known_args(argv): def load_sklearn_iris_test_data( data_type: Callable, split: bool = True, - seed: int = 999) -> List[Union[numpy.array, pandas.DataFrame]]: + seed: int = 999) -> list[Union[numpy.array, pandas.DataFrame]]: """ Loads test data from the sklearn Iris dataset in a given format, either in a single or multiple batches. 
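
The inference examples above all share the same post-processing shape: RunInference with a KeyedModelHandler emits (key, PredictionResult) pairs, and a DoFn formats each pair into a string. A minimal sketch of that shape using the modernized hints (the class name and output format here are illustrative, not taken from any one example):

# Illustrative sketch, not part of the patch: the common PostProcessor shape
# used by the inference examples above, written with built-in generic hints.
from collections.abc import Iterable

import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult


class SketchPostProcessor(beam.DoFn):
  """Formats each keyed prediction as 'key,inference'."""
  def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]:
    key, prediction_result = element
    yield '{},{}'.format(key, prediction_result.inference)
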
diff --git a/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py b/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py index 1cdd266c3df4..9b4889017077 100644 --- a/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py +++ b/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py @@ -26,7 +26,6 @@ import logging import sys -import typing import apache_beam as beam from apache_beam.io.kafka import ReadFromKafka @@ -97,7 +96,7 @@ def convert_kafka_record_to_dictionary(record): topic='projects/pubsub-public-data/topics/taxirides-realtime'). with_output_types(bytes) | beam.Map(lambda x: (b'', x)).with_output_types( - typing.Tuple[bytes, bytes]) # Kafka write transforms expects KVs. + tuple[bytes, bytes]) # Kafka write transforms expects KVs. | beam.WindowInto(beam.window.FixedWindows(window_size)) | WriteToKafka( producer_config={'bootstrap.servers': bootstrap_servers}, diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index 715011d302d2..c849af4a00b3 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -1143,6 +1143,60 @@ def model_multiple_pcollections_flatten(contents, output_path): merged | beam.io.WriteToText(output_path) +def model_multiple_pcollections_flatten_with(contents, output_path): + """Merging a PCollection with FlattenWith.""" + some_hash_fn = lambda s: ord(s[0]) + partition_fn = lambda element, partitions: some_hash_fn(element) % partitions + import apache_beam as beam + with TestPipeline() as pipeline: # Use TestPipeline for testing. + + # Partition into deciles + partitioned = pipeline | beam.Create(contents) | beam.Partition( + partition_fn, 3) + pcoll1 = partitioned[0] + pcoll2 = partitioned[1] + pcoll3 = partitioned[2] + SomeTransform = lambda: beam.Map(lambda x: x) + SomeOtherTransform = lambda: beam.Map(lambda x: x) + + # Flatten them back into 1 + + # A collection of PCollection objects can be represented simply + # as a tuple (or list) of PCollections. + # (The SDK for Python has no separate type to store multiple + # PCollection objects, whether containing the same or different + # types.) + # [START model_multiple_pcollections_flatten_with] + merged = ( + pcoll1 + | SomeTransform() + | beam.FlattenWith(pcoll2, pcoll3) + | SomeOtherTransform()) + # [END model_multiple_pcollections_flatten_with] + merged | beam.io.WriteToText(output_path) + + +def model_multiple_pcollections_flatten_with_transform(contents, output_path): + """Merging output of PTransform with FlattenWith.""" + some_hash_fn = lambda s: ord(s[0]) + partition_fn = lambda element, partitions: some_hash_fn(element) % partitions + import apache_beam as beam + with TestPipeline() as pipeline: # Use TestPipeline for testing. 
+ + pcoll = pipeline | beam.Create(contents) + SomeTransform = lambda: beam.Map(lambda x: x) + SomeOtherTransform = lambda: beam.Map(lambda x: x) + + # [START model_multiple_pcollections_flatten_with_transform] + merged = ( + pcoll + | SomeTransform() + | beam.FlattenWith(beam.Create(['x', 'y', 'z'])) + | SomeOtherTransform()) + # [END model_multiple_pcollections_flatten_with_transform] + merged | beam.io.WriteToText(output_path) + + def model_multiple_pcollections_partition(contents, output_path): """Splitting a PCollection with Partition.""" some_hash_fn = lambda s: ord(s[0]) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index e8cb8960cf4d..54a57673b5f4 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -917,6 +917,19 @@ def test_model_multiple_pcollections_flatten(self): snippets.model_multiple_pcollections_flatten(contents, result_path) self.assertEqual(contents, self.get_output(result_path)) + def test_model_multiple_pcollections_flatten_with(self): + contents = ['a', 'b', 'c', 'd', 'e', 'f'] + result_path = self.create_temp_file() + snippets.model_multiple_pcollections_flatten_with(contents, result_path) + self.assertEqual(contents, self.get_output(result_path)) + + def test_model_multiple_pcollections_flatten_with_transform(self): + contents = ['a', 'b', 'c', 'd', 'e', 'f'] + result_path = self.create_temp_file() + snippets.model_multiple_pcollections_flatten_with_transform( + contents, result_path) + self.assertEqual(contents + ['x', 'y', 'z'], self.get_output(result_path)) + def test_model_multiple_pcollections_partition(self): contents = [17, 42, 64, 32, 0, 99, 53, 89] result_path = self.create_temp_file() diff --git a/sdks/python/apache_beam/examples/wordcount_xlang_sql.py b/sdks/python/apache_beam/examples/wordcount_xlang_sql.py index 9d7d756f223f..632e90303010 100644 --- a/sdks/python/apache_beam/examples/wordcount_xlang_sql.py +++ b/sdks/python/apache_beam/examples/wordcount_xlang_sql.py @@ -24,7 +24,7 @@ import argparse import logging import re -import typing +from typing import NamedTuple import apache_beam as beam from apache_beam import coders @@ -41,7 +41,7 @@ # # Here we create and register a simple NamedTuple with a single str typed # field named 'word' which we will use below. -MyRow = typing.NamedTuple('MyRow', [('word', str)]) +MyRow = NamedTuple('MyRow', [('word', str)]) coders.registry.register_coder(MyRow, coders.RowCoder) diff --git a/sdks/python/apache_beam/internal/dill_pickler.py b/sdks/python/apache_beam/internal/dill_pickler.py index 7f7ac5b214fa..e1d6b7e74e49 100644 --- a/sdks/python/apache_beam/internal/dill_pickler.py +++ b/sdks/python/apache_beam/internal/dill_pickler.py @@ -46,9 +46,15 @@ settings = {'dill_byref': None} -if sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1": - # Let's make dill 0.3.1.1 support Python 3.11. +patch_save_code = sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1" + +def get_normalized_path(path): + """Returns a normalized path. This function is intended to be overridden.""" + return path + + +if patch_save_code: # The following function is based on 'save_code' from 'dill' # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) # Copyright (c) 2008-2015 California Institute of Technology. 
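
The first dill_pickler hunk above introduces get_normalized_path as an overridable hook that the patched save_code applies to co_filename before pickling a code object. A sketch of how it might be overridden, assuming override-by-assignment is the intended mechanism (the docstring only says the function is meant to be overridden); this only takes effect on the patched path, i.e. Python 3.10+ with dill 0.3.1.1:

# Illustrative sketch, not part of the patch: overriding the new hook so
# pickled code objects carry paths without a machine-specific prefix. The
# assignment below is an assumption for illustration, not an official API.
import os

from apache_beam.internal import dill_pickler


def _relative_to_site_packages(path):
  """Drop any prefix up to and including 'site-packages' when present."""
  marker = os.sep + 'site-packages' + os.sep
  idx = path.find(marker)
  return path[idx + len(marker):] if idx >= 0 else path


dill_pickler.get_normalized_path = _relative_to_site_packages
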
@@ -66,6 +72,7 @@ @dill.register(CodeType) def save_code(pickler, obj): + co_filename = get_normalized_path(obj.co_filename) if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args) args = ( obj.co_argcount, @@ -78,7 +85,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_qualname, obj.co_firstlineno, @@ -100,7 +107,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_qualname, obj.co_firstlineno, @@ -120,7 +127,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_firstlineno, obj.co_linetable, @@ -138,7 +145,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, diff --git a/sdks/python/apache_beam/internal/pickler_test.py b/sdks/python/apache_beam/internal/pickler_test.py index 824c4c59c0ce..c26a8ee3e653 100644 --- a/sdks/python/apache_beam/internal/pickler_test.py +++ b/sdks/python/apache_beam/internal/pickler_test.py @@ -94,6 +94,11 @@ def test_pickle_rlock(self): self.assertIsInstance(loads(dumps(rlock_instance)), rlock_type) + def test_save_paths(self): + f = loads(dumps(lambda x: x)) + co_filename = f.__code__.co_filename + self.assertTrue(co_filename.endswith('pickler_test.py')) + @unittest.skipIf(NO_MAPPINGPROXYTYPE, 'test if MappingProxyType introduced') def test_dump_and_load_mapping_proxy(self): self.assertEqual( diff --git a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py index 7829baba8b69..abe9530787e8 100644 --- a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py @@ -107,7 +107,7 @@ def start_db_container(self, retries): for i in range(retries): try: self.db = PostgresContainer( - 'debezium/example-postgres:latest', + 'quay.io/debezium/example-postgres:latest', user=self.username, password=self.password, dbname=self.database) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index a7311ad6d063..3145fb511068 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -777,26 +777,10 @@ def process( GlobalWindows.windowed_value((destination, job_reference))) def finish_bundle(self): - dataset_locations = {} - for windowed_value in self.pending_jobs: - table_ref = bigquery_tools.parse_table_reference(windowed_value.value[0]) - project_dataset = (table_ref.projectId, table_ref.datasetId) - job_ref = windowed_value.value[1] - # In some cases (e.g. when the load job op returns a 409 ALREADY_EXISTS), - # the returned job reference may not include a location. In such cases, - # we need to override with the dataset's location. 
- job_location = job_ref.location - if not job_location and project_dataset not in dataset_locations: - job_location = self.bq_wrapper.get_table_location( - table_ref.projectId, table_ref.datasetId, table_ref.tableId) - dataset_locations[project_dataset] = job_location - self.bq_wrapper.wait_for_bq_job( - job_ref, - sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS, - location=job_location) + job_ref, sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS) return self.pending_jobs diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index e4c0e34d9c1f..10453d9c8baf 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -427,7 +427,6 @@ def test_records_traverse_transform_with_mocks(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference @@ -483,7 +482,6 @@ def test_load_job_id_used(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'loadJobProject' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference @@ -521,7 +519,6 @@ def test_load_job_id_use_for_copy_job(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'loadJobProject' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -577,12 +574,10 @@ def test_wait_for_load_job_completion(self, sleep_mock): job_1.jobReference = bigquery_api.JobReference() job_1.jobReference.projectId = 'project1' job_1.jobReference.jobId = 'jobId1' - job_1.jobReference.location = 'US' job_2 = bigquery_api.Job() job_2.jobReference = bigquery_api.JobReference() job_2.jobReference.projectId = 'project1' job_2.jobReference.jobId = 'jobId2' - job_2.jobReference.location = 'US' job_1_waiting = mock.Mock() job_1_waiting.status.state = 'RUNNING' @@ -622,12 +617,10 @@ def test_one_load_job_failed_after_waiting(self, sleep_mock): job_1.jobReference = bigquery_api.JobReference() job_1.jobReference.projectId = 'project1' job_1.jobReference.jobId = 'jobId1' - job_1.jobReference.location = 'US' job_2 = bigquery_api.Job() job_2.jobReference = bigquery_api.JobReference() job_2.jobReference.projectId = 'project1' job_2.jobReference.jobId = 'jobId2' - job_2.jobReference.location = 'US' job_1_waiting = mock.Mock() job_1_waiting.status.state = 'RUNNING' @@ -664,7 +657,6 @@ def test_multiple_partition_files(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -750,7 +742,6 @@ def test_multiple_partition_files_write_dispositions( job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -793,7 +784,6 @@ def test_triggering_frequency(self, is_streaming, with_auto_sharding): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference diff --git 
a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index c7128e7899ec..b31f6449fe90 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -32,6 +32,7 @@ import io import json import logging +import re import sys import time import uuid @@ -558,6 +559,19 @@ def _insert_load_job( )) return self._start_job(request, stream=source_stream).jobReference + @staticmethod + def _parse_location_from_exc(content, job_id): + """Parse job location from Exception content.""" + if isinstance(content, bytes): + content = content.decode('ascii', 'replace') + # search for "Already Exists: Job :." + m = re.search(r"Already Exists: Job \S+\:(\S+)\." + job_id, content) + if not m: + _LOGGER.warning( + "Not able to parse BigQuery load job location for %s", job_id) + return None + return m.group(1) + def _start_job( self, request, # type: bigquery.BigqueryJobsInsertRequest @@ -585,11 +599,17 @@ def _start_job( return response except HttpError as exn: if exn.status_code == 409: + jobId = request.job.jobReference.jobId _LOGGER.info( "BigQuery job %s already exists, will not retry inserting it: %s", request.job.jobReference, exn) - return request.job + job_location = self._parse_location_from_exc(exn.content, jobId) + response = request.job + if not response.jobReference.location and job_location: + # Request not constructed with location + response.jobReference.location = job_location + return response else: _LOGGER.info( "Failed to insert job %s: %s", request.job.jobReference, exn) @@ -631,8 +651,7 @@ def _start_query_job( return self._start_job(request) - def wait_for_bq_job( - self, job_reference, sleep_duration_sec=5, max_retries=0, location=None): + def wait_for_bq_job(self, job_reference, sleep_duration_sec=5, max_retries=0): """Poll job until it is DONE. Args: @@ -640,7 +659,6 @@ def wait_for_bq_job( sleep_duration_sec: Specifies the delay in seconds between retries. max_retries: The total number of times to retry. If equals to 0, the function waits forever. - location: Fall back on this location if job_reference doesn't have one. Raises: `RuntimeError`: If the job is FAILED or the number of retries has been @@ -650,9 +668,7 @@ def wait_for_bq_job( while True: retry += 1 job = self.get_job( - job_reference.projectId, - job_reference.jobId, - job_reference.location or location) + job_reference.projectId, job_reference.jobId, job_reference.location) _LOGGER.info('Job %s status: %s', job.id, job.status.state) if job.status.state == 'DONE' and job.status.errorResult: raise RuntimeError( diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index b6f801c63f79..9e006dbeda93 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -126,7 +126,7 @@ def _from_proto_str(proto_msg: bytes) -> 'PubsubMessage': """ msg = pubsub.types.PubsubMessage.deserialize(proto_msg) # Convert ScalarMapContainer to dict. - attributes = dict((key, msg.attributes[key]) for key in msg.attributes) + attributes = dict(msg.attributes) return PubsubMessage( msg.data, attributes, @@ -151,10 +151,8 @@ def _to_proto_str(self, for_publish=False): https://cloud.google.com/pubsub/docs/reference/rpc/google.pubsub.v1#google.pubsub.v1.PubsubMessage containing the payload of this object. 
""" - msg = pubsub.types.PubsubMessage() if len(self.data) > (10_000_000): raise ValueError('A pubsub message data field must not exceed 10MB') - msg.data = self.data if self.attributes: if len(self.attributes) > 100: @@ -167,19 +165,25 @@ def _to_proto_str(self, for_publish=False): if len(value) > 1024: raise ValueError( 'A pubsub message attribute value must not exceed 1024 bytes') - msg.attributes[key] = value + message_id = None + publish_time = None if not for_publish: if self.message_id: - msg.message_id = self.message_id + message_id = self.message_id if self.publish_time: - msg.publish_time = self.publish_time + publish_time = self.publish_time if len(self.ordering_key) > 1024: raise ValueError( 'A pubsub message ordering key must not exceed 1024 bytes.') - msg.ordering_key = self.ordering_key + msg = pubsub.types.PubsubMessage( + data=self.data, + attributes=self.attributes, + message_id=message_id, + publish_time=publish_time, + ordering_key=self.ordering_key) serialized = pubsub.types.PubsubMessage.serialize(msg) if len(serialized) > (10_000_000): raise ValueError( @@ -193,7 +197,7 @@ def _from_message(msg: Any) -> 'PubsubMessage': https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html """ # Convert ScalarMapContainer to dict. - attributes = dict((key, msg.attributes[key]) for key in msg.attributes) + attributes = dict(msg.attributes) pubsubmessage = PubsubMessage(msg.data, attributes) if msg.message_id: pubsubmessage.message_id = msg.message_id diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 2e3e9b301618..73ba8d6abdb6 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -901,7 +901,8 @@ def test_write_messages_with_attributes_error(self, mock_pubsub): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True - with self.assertRaisesRegex(Exception, r'Type hint violation'): + with self.assertRaisesRegex(Exception, + r'requires.*PubsubMessage.*applied.*str'): with TestPipeline(options=options) as p: _ = ( p diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index f402c0acab2f..3e665dd805ea 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -140,7 +140,7 @@ class DelegatingCounter(Counter): def __init__( self, metric_name: MetricName, process_wide: bool = False) -> None: super().__init__(metric_name) - self.inc = MetricUpdater( # type: ignore[assignment] + self.inc = MetricUpdater( # type: ignore[method-assign] cells.CounterCell, metric_name, default_value=1, @@ -150,19 +150,19 @@ class DelegatingDistribution(Distribution): """Metrics Distribution Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[assignment] + self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[method-assign] class DelegatingGauge(Gauge): """Metrics Gauge that Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[assignment] + self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[method-assign] class DelegatingStringSet(StringSet): """Metrics StringSet that Delegates 
functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[assignment] + self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[method-assign] class MetricResults(object): diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index e7af114ad431..4ac856456748 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -116,7 +116,15 @@ def load_model(self) -> ort.InferenceSession: # when path is remote, we should first load into memory then deserialize f = FileSystems.open(self._model_uri, "rb") model_proto = onnx.load(f) - model_proto_bytes = onnx._serialize(model_proto) + model_proto_bytes = model_proto + if not isinstance(model_proto, bytes): + if (hasattr(model_proto, "SerializeToString") and + callable(model_proto.SerializeToString)): + model_proto_bytes = model_proto.SerializeToString() + else: + raise TypeError( + "No SerializeToString method is detected on loaded model. " + + f"Type of model: {type(model_proto)}") ort_session = ort.InferenceSession( model_proto_bytes, sess_options=self._session_options, diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile index 5abbffdc5a2a..f27abbfd0051 100644 --- a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -40,7 +40,7 @@ RUN pip install openai vllm RUN apt install libcairo2-dev pkg-config python3-dev -y RUN pip install pycairo -# Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. +# Copy the Apache Beam worker dependencies from the Beam Python 3.12 SDK image. COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam # Set the entrypoint to Apache Beam SDK worker launcher. 
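
The onnx_inference change above replaces the private onnx._serialize call with the model's public SerializeToString() method before handing bytes to onnxruntime. A condensed sketch of that defensive pattern, with an illustrative helper name (the real code lives inside load_model and also applies session options):

# Illustrative sketch, not part of the patch: convert a loaded ONNX model to
# bytes via its public protobuf API, since ort.InferenceSession accepts
# serialized model bytes directly.
import onnxruntime as ort


def session_from_loaded_model(model) -> ort.InferenceSession:
  if isinstance(model, bytes):
    model_bytes = model
  elif hasattr(model, 'SerializeToString') and callable(model.SerializeToString):
    model_bytes = model.SerializeToString()
  else:
    raise TypeError(
        'No SerializeToString method is detected on loaded model. '
        f'Type of model: {type(model)}')
  return ort.InferenceSession(model_bytes)
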
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index e1ba4f49b8fd..799083d16ceb 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -21,6 +21,7 @@ import logging import os import subprocess +import sys import threading import time import uuid @@ -118,7 +119,7 @@ def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]): def start_server(self, retries=3): if not self._server_started: server_cmd = [ - 'python', + sys.executable, '-m', 'vllm.entrypoints.openai.api_server', '--model', @@ -131,7 +132,7 @@ def start_server(self, retries=3): server_cmd.append(v) self._server_process, self._server_port = start_process(server_cmd) - self.check_connectivity() + self.check_connectivity(retries) def get_server_port(self) -> int: if not self._server_started: diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py index eab547b1c17b..e09f116dfb38 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py @@ -53,7 +53,7 @@ def _compare_prediction_result(a: PredictionResult, b: PredictionResult): example_equal = numpy.array_equal(a.example.todense(), b.example.todense()) else: - example_equal = numpy.array_equal(a.example, b.example) + example_equal = numpy.array_equal(a.example, b.example) # type: ignore[arg-type] if isinstance(a.inference, dict): return all( x == y for x, y in zip(a.inference.values(), diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 678ab0882d24..a963f602a06d 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -20,14 +20,11 @@ import os import tempfile import uuid +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any -from typing import Dict from typing import Generic -from typing import List -from typing import Mapping from typing import Optional -from typing import Sequence -from typing import Tuple from typing import TypeVar from typing import Union @@ -67,7 +64,7 @@ def _convert_list_of_dicts_to_dict_of_lists( - list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]: + list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]: keys_to_element_list = collections.defaultdict(list) input_keys = list_of_dicts[0].keys() for d in list_of_dicts: @@ -83,9 +80,9 @@ def _convert_list_of_dicts_to_dict_of_lists( def _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists: Dict[str, List[Any]]) -> List[Dict[str, Any]]: + dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: batch_length = len(next(iter(dict_of_lists.values()))) - result: List[Dict[str, Any]] = [{} for _ in range(batch_length)] + result: list[dict[str, Any]] = [{} for _ in range(batch_length)] # all the values in the dict_of_lists should have same length for key, values in dict_of_lists.items(): assert len(values) == batch_length, ( @@ -140,7 +137,7 @@ def get_counter(self): class BaseOperation(Generic[OperationInputT, OperationOutputT], MLTransformProvider, abc.ABC): - def __init__(self, columns: List[str]) -> None: + def __init__(self, columns: list[str]) -> None: """ Base Opertation class data processing transformations. 
Args: @@ -150,7 +147,7 @@ def __init__(self, columns: List[str]) -> None: @abc.abstractmethod def apply_transform(self, data: OperationInputT, - output_column_name: str) -> Dict[str, OperationOutputT]: + output_column_name: str) -> dict[str, OperationOutputT]: """ Define any processing logic in the apply_transform() method. processing logics are applied on inputs and returns a transformed @@ -160,7 +157,7 @@ def apply_transform(self, data: OperationInputT, """ def __call__(self, data: OperationInputT, - output_column_name: str) -> Dict[str, OperationOutputT]: + output_column_name: str) -> dict[str, OperationOutputT]: """ This method is called when the instance of the class is called. This method will invoke the apply() method of the class. @@ -172,7 +169,7 @@ def __call__(self, data: OperationInputT, class ProcessHandler( beam.PTransform[beam.PCollection[ExampleT], Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]], abc.ABC): """ @@ -190,10 +187,10 @@ def append_transform(self, transform: BaseOperation): class EmbeddingsManager(MLTransformProvider): def __init__( self, - columns: List[str], + columns: list[str], *, # common args for all ModelHandlers. - load_model_args: Optional[Dict[str, Any]] = None, + load_model_args: Optional[dict[str, Any]] = None, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, large_model: bool = False, @@ -222,7 +219,7 @@ def get_columns_to_apply(self): class MLTransform( beam.PTransform[beam.PCollection[ExampleT], Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]], Generic[ExampleT, MLTransformOutputT]): def __init__( @@ -230,7 +227,7 @@ def __init__( *, write_artifact_location: Optional[str] = None, read_artifact_location: Optional[str] = None, - transforms: Optional[List[MLTransformProvider]] = None): + transforms: Optional[list[MLTransformProvider]] = None): """ MLTransform is a Beam PTransform that can be used to apply transformations to the data. MLTransform is used to wrap the @@ -304,12 +301,12 @@ def __init__( self._counter = Metrics.counter( MLTransform, f'BeamML_{self.__class__.__name__}') self._with_exception_handling = False - self._exception_handling_args: Dict[str, Any] = {} + self._exception_handling_args: dict[str, Any] = {} def expand( self, pcoll: beam.PCollection[ExampleT] ) -> Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]: """ This is the entrypoint for the MLTransform. This method will @@ -533,7 +530,7 @@ class _MLTransformToPTransformMapper: """ def __init__( self, - transforms: List[MLTransformProvider], + transforms: list[MLTransformProvider], artifact_location: str, artifact_mode: str = ArtifactMode.PRODUCE, pipeline_options: Optional[PipelineOptions] = None, @@ -595,7 +592,7 @@ class _EmbeddingHandler(ModelHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. 
_EmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -619,7 +616,7 @@ def load_model(self): def _validate_column_data(self, batch): pass - def _validate_batch(self, batch: Sequence[Dict[str, Any]]): + def _validate_batch(self, batch: Sequence[dict[str, Any]]): if not batch or not isinstance(batch[0], dict): raise TypeError( 'Expected data to be dicts, got ' @@ -627,10 +624,10 @@ def _validate_batch(self, batch: Sequence[Dict[str, Any]]): def _process_batch( self, - dict_batch: Dict[str, List[Any]], + dict_batch: dict[str, list[Any]], model: ModelT, - inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]: - result: Dict[str, List[Any]] = collections.defaultdict(list) + inference_args: Optional[dict[str, Any]]) -> dict[str, list[Any]]: + result: dict[str, list[Any]] = collections.defaultdict(list) input_keys = dict_batch.keys() missing_columns_in_data = set(self.columns) - set(input_keys) if missing_columns_in_data: @@ -653,10 +650,10 @@ def _process_batch( def run_inference( self, - batch: Sequence[Dict[str, List[str]]], + batch: Sequence[dict[str, list[str]]], model: ModelT, - inference_args: Optional[Dict[str, Any]] = None, - ) -> List[Dict[str, Union[List[float], List[str]]]]: + inference_args: Optional[dict[str, Any]] = None, + ) -> list[dict[str, Union[list[float], list[str]]]]: """ Runs inference on a batch of text inputs. The inputs are expected to be a list of dicts. Each dict should have the same keys, and the shape @@ -696,7 +693,7 @@ class _TextEmbeddingHandler(_EmbeddingHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. _TextEmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -713,8 +710,8 @@ class _TextEmbeddingHandler(_EmbeddingHandler): def _validate_column_data(self, batch): if not isinstance(batch[0], (str, bytes)): raise TypeError( - 'Embeddings can only be generated on Dict[str, str].' - f'Got Dict[str, {type(batch[0])}] instead.') + 'Embeddings can only be generated on dict[str, str].' + f'Got dict[str, {type(batch[0])}] instead.') def get_metrics_namespace(self) -> str: return ( @@ -730,7 +727,7 @@ class _ImageEmbeddingHandler(_EmbeddingHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. _ImageEmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -750,8 +747,8 @@ def _validate_column_data(self, batch): # here, so just catch columns of primatives for now. if isinstance(batch[0], (int, str, float, bool)): raise TypeError( - 'Embeddings can only be generated on Dict[str, Image].' - f'Got Dict[str, {type(batch[0])}] instead.') + 'Embeddings can only be generated on dict[str, Image].' 
+ f'Got dict[str, {type(batch[0])}] instead.') def get_metrics_namespace(self) -> str: return ( diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 743c3683ce8e..df5bf826742e 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -21,11 +21,9 @@ import tempfile import typing import unittest +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import List from typing import Optional -from typing import Sequence import numpy as np from parameterized import param @@ -162,7 +160,7 @@ def test_ml_transform_on_list_dict(self): 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] }], input_types={ - 'x': List[int], 'y': List[float] + 'x': list[int], 'y': list[float] }, expected_dtype={ 'x': typing.Sequence[np.float32], @@ -320,7 +318,7 @@ def test_read_mode_with_transforms(self): class FakeModel: - def __call__(self, example: List[str]) -> List[str]: + def __call__(self, example: list[str]) -> list[str]: for i in range(len(example)): if not isinstance(example[i], str): raise TypeError('Input must be a string') @@ -333,7 +331,7 @@ def run_inference( self, batch: Sequence[str], model: Any, - inference_args: Optional[Dict[str, Any]] = None): + inference_args: Optional[dict[str, Any]] = None): return model(batch) def load_model(self): @@ -345,7 +343,7 @@ def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) def get_model_handler(self) -> ModelHandler: - FakeModelHandler.__repr__ = lambda x: 'FakeEmbeddingsManager' # type: ignore[assignment] + FakeModelHandler.__repr__ = lambda x: 'FakeEmbeddingsManager' # type: ignore[method-assign] return FakeModelHandler() def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: @@ -508,7 +506,7 @@ def test_handler_with_inconsistent_keys(self): class FakeImageModel: - def __call__(self, example: List[PIL_Image]) -> List[PIL_Image]: + def __call__(self, example: list[PIL_Image]) -> list[PIL_Image]: for i in range(len(example)): if not isinstance(example[i], PIL_Image): raise TypeError('Input must be an Image') @@ -520,7 +518,7 @@ def run_inference( self, batch: Sequence[PIL_Image], model: Any, - inference_args: Optional[Dict[str, Any]] = None): + inference_args: Optional[dict[str, Any]] = None): return model(batch) def load_model(self): @@ -532,7 +530,7 @@ def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) def get_model_handler(self) -> ModelHandler: - FakeModelHandler.__repr__ = lambda x: 'FakeImageEmbeddingsManager' # type: ignore[assignment] + FakeModelHandler.__repr__ = lambda x: 'FakeImageEmbeddingsManager' # type: ignore[method-assign] return FakeImageModelHandler() def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py index 46b4ef9cf7d6..2162ed050c42 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -18,13 +18,11 @@ import logging import os +from collections.abc import Callable +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping from typing import Optional -from typing import Sequence import requests @@ -80,7 +78,7 
@@ def run_inference( self, batch: Sequence[str], model: SentenceTransformer, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ): inference_args = inference_args or {} return model.encode(batch, **inference_args) @@ -113,7 +111,7 @@ class SentenceTransformerEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], max_seq_length: Optional[int] = None, image_model: bool = False, **kwargs): @@ -216,7 +214,7 @@ class InferenceAPIEmbeddings(EmbeddingsManager): def __init__( self, hf_token: Optional[str], - columns: List[str], + columns: list[str], model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long api_url: Optional[str] = None, **kwargs, diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py index f78ddf3ff04a..c14904df7c2c 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -14,8 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable -from typing import List +from collections.abc import Iterable from typing import Optional import apache_beam as beam @@ -90,7 +89,7 @@ def run_inference(self, batch, model, inference_args, model_id=None): class TensorflowHubTextEmbeddings(EmbeddingsManager): def __init__( self, - columns: List[str], + columns: list[str], hub_url: str, preprocessing_url: Optional[str] = None, **kwargs): @@ -136,7 +135,7 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: class TensorflowHubImageEmbeddings(EmbeddingsManager): - def __init__(self, columns: List[str], hub_url: str, **kwargs): + def __init__(self, columns: list[str], hub_url: str, **kwargs): """ Embedding config for tensorflow hub models. This config can be used with MLTransform to embed image data. Models are loaded using the RunInference diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index fbefeec231f1..6fe8320e758b 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -19,12 +19,10 @@ # Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long # to install Vertex AI Python SDK. 
+from collections.abc import Iterable +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional -from typing import Sequence from google.auth.credentials import Credentials @@ -80,7 +78,7 @@ def run_inference( self, batch: Sequence[str], model: Any, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable: embeddings = [] batch_size = _BATCH_SIZE @@ -110,7 +108,7 @@ class VertexAITextEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], title: Optional[str] = None, task_type: str = DEFAULT_TASK_TYPE, project: Optional[str] = None, @@ -179,7 +177,7 @@ def run_inference( self, batch: Sequence[Image], model: MultiModalEmbeddingModel, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable: embeddings = [] # Maximum request size for muli-model embedding models is 1. @@ -204,7 +202,7 @@ class VertexAIImageEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], dimension: Optional[int], project: Optional[str] = None, location: Optional[str] = None, diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py index 7a912f2d88ea..1e752049f6e5 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers.py +++ b/sdks/python/apache_beam/ml/transforms/handlers.py @@ -20,11 +20,10 @@ import copy import os import typing +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import List +from typing import NamedTuple from typing import Optional -from typing import Sequence from typing import Union import numpy as np @@ -71,18 +70,18 @@ np.str_: tf.string, } _primitive_types_to_typing_container_type = { - int: List[int], float: List[float], str: List[str], bytes: List[bytes] + int: list[int], float: list[float], str: list[str], bytes: list[bytes] } -tft_process_handler_input_type = typing.Union[typing.NamedTuple, - beam.Row, - Dict[str, - typing.Union[str, - float, - int, - bytes, - np.ndarray]]] -tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]] +tft_process_handler_input_type = Union[NamedTuple, + beam.Row, + dict[str, + Union[str, + float, + int, + bytes, + np.ndarray]]] +tft_process_handler_output_type = Union[beam.Row, dict[str, np.ndarray]] class _DataCoder: @@ -131,15 +130,15 @@ def process( class _ConvertNamedTupleToDict( - beam.PTransform[beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]], - beam.PCollection[Dict[str, + beam.PTransform[beam.PCollection[Union[beam.Row, NamedTuple]], + beam.PCollection[dict[str, common_types.InstanceDictType]]]): """ A PTransform that converts a collection of NamedTuples or Rows into a collection of dictionaries. """ def expand( - self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]] + self, pcoll: beam.PCollection[Union[beam.Row, NamedTuple]] ) -> beam.PCollection[common_types.InstanceDictType]: """ Args: @@ -163,7 +162,7 @@ def __init__( operations. 
""" self.transforms = transforms if transforms else [] - self.transformed_schema: Dict[str, type] = {} + self.transformed_schema: dict[str, type] = {} self.artifact_location = artifact_location self.artifact_mode = artifact_mode if artifact_mode not in ['produce', 'consume']: @@ -217,7 +216,7 @@ def _map_column_names_to_types_from_transforms(self): return column_type_mapping def get_raw_data_feature_spec( - self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]: + self, input_types: dict[str, type]) -> dict[str, tf.io.VarLenFeature]: """ Return a DatasetMetadata object to be used with tft_beam.AnalyzeAndTransformDataset. @@ -265,7 +264,7 @@ def _get_raw_data_feature_spec_per_column( f"Union type is not supported for column: {col_name}. " f"Please pass a PCollection with valid schema for column " f"{col_name} by passing a single type " - "in container. For example, List[int].") + "in container. For example, list[int].") elif issubclass(typ, np.generic) or typ in _default_type_to_tensor_type_map: dtype = typ else: @@ -276,7 +275,7 @@ def _get_raw_data_feature_spec_per_column( return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) def get_raw_data_metadata( - self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + self, input_types: dict[str, type]) -> dataset_metadata.DatasetMetadata: raw_data_feature_spec = self.get_raw_data_feature_spec(input_types) raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string) return self.convert_raw_data_feature_spec_to_dataset_metadata( @@ -305,8 +304,8 @@ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection): "to convert your PCollection to GlobalWindow.") def process_data_fn( - self, inputs: Dict[str, common_types.ConsistentTensorType] - ) -> Dict[str, common_types.ConsistentTensorType]: + self, inputs: dict[str, common_types.ConsistentTensorType] + ) -> dict[str, common_types.ConsistentTensorType]: """ This method is used in the AnalyzeAndTransformDataset step. It applies the transforms to the `inputs` in sequential order on the columns @@ -335,11 +334,11 @@ def _get_transformed_data_schema( name = feature.name feature_type = feature.type if feature_type == schema_pb2.FeatureType.FLOAT: - transformed_types[name] = typing.Sequence[np.float32] + transformed_types[name] = Sequence[np.float32] elif feature_type == schema_pb2.FeatureType.INT: - transformed_types[name] = typing.Sequence[np.int64] # type: ignore[assignment] + transformed_types[name] = Sequence[np.int64] # type: ignore[assignment] else: - transformed_types[name] = typing.Sequence[bytes] # type: ignore[assignment] + transformed_types[name] = Sequence[bytes] # type: ignore[assignment] return transformed_types def expand( @@ -372,7 +371,7 @@ def expand( raw_data = ( raw_data | _ConvertNamedTupleToDict().with_output_types( - Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore + dict[str, Union[tuple(column_type_mapping.values())]])) # type: ignore # AnalyzeAndTransformDataset raise type hint since this is # schema'd PCollection and the current output type would be a # custom type(NamedTuple) or a beam.Row type. 
@@ -408,7 +407,7 @@ def expand( raw_data = ( raw_data | _ConvertNamedTupleToDict().with_output_types( - Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore + dict[str, Union[tuple(column_type_mapping.values())]])) # type: ignore feature_set = [feature.name for feature in raw_data_metadata.schema.feature] diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index 1331f1308087..4b53026c36a4 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -23,7 +23,6 @@ import typing import unittest import uuid -from typing import List from typing import NamedTuple from typing import Union @@ -65,7 +64,7 @@ class IntType(NamedTuple): class ListIntType(NamedTuple): - x: List[int] + x: list[int] class NumpyType(NamedTuple): @@ -111,7 +110,7 @@ def test_input_type_from_schema_named_tuple_pcoll(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) @@ -126,7 +125,7 @@ def test_input_type_from_schema_named_tuple_pcoll_list(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_row_type_pcoll(self): @@ -140,7 +139,7 @@ def test_input_type_from_row_type_pcoll(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_row_type_pcoll_list(self): @@ -149,14 +148,14 @@ def test_input_type_from_row_type_pcoll_list(self): data = ( p | beam.Create(data) | beam.Map(lambda ele: beam.Row(x=list(ele['x']))).with_output_types( - beam.row_type.RowTypeConstraint.from_fields([('x', List[int])]))) + beam.row_type.RowTypeConstraint.from_fields([('x', list[int])]))) element_type = data.element_type process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_named_tuple_pcoll_numpy(self): @@ -190,8 +189,8 @@ def test_tensorflow_raw_data_metadata_primitive_types(self): self.assertIsInstance(feature_spec, tf.io.VarLenFeature) def test_tensorflow_raw_data_metadata_primitive_types_in_containers(self): - input_types = dict([("x", List[int]), ("y", List[float]), - ("k", List[bytes]), ("l", List[str])]) + input_types = dict([("x", list[int]), ("y", list[float]), + ("k", list[bytes]), ("l", list[str])]) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) for col_name, typ in input_types.items(): @@ -211,7 +210,7 @@ def test_tensorflow_raw_data_metadata_primitive_native_container_types(self): self.assertIsInstance(feature_spec, tf.io.VarLenFeature) def test_tensorflow_raw_data_metadata_numpy_types(self): - input_types = dict(x=np.int64, y=np.float32, z=List[np.int64]) + 
input_types = dict(x=np.int64, y=np.float32, z=list[np.int64]) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) for col_name, typ in input_types.items(): diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index 6903cca89419..bfe23757642b 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -34,12 +34,9 @@ # pytype: skip-file import logging +from collections.abc import Iterable from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional -from typing import Tuple from typing import Union import apache_beam as beam @@ -67,7 +64,7 @@ # Register the expected input types for each operation # this will be used to determine schema for the tft.AnalyzeDataset -_EXPECTED_TYPES: Dict[str, Union[int, str, float]] = {} +_EXPECTED_TYPES: dict[str, Union[int, str, float]] = {} _LOGGER = logging.getLogger(__name__) @@ -84,7 +81,7 @@ def wrapper(fn): # Add support for outputting artifacts to a text file in human readable form. class TFTOperation(BaseOperation[common_types.TensorType, common_types.TensorType]): - def __init__(self, columns: List[str]) -> None: + def __init__(self, columns: list[str]) -> None: """ Base Operation class for TFT data processing transformations. Processing logic for the transformation is defined in the @@ -150,7 +147,7 @@ def _split_string_with_delimiter(self, data, delimiter): class ComputeAndApplyVocabulary(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, default_value: Any = -1, @@ -193,7 +190,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: if self.split_string_by_delimiter: data = self._split_string_with_delimiter( @@ -218,7 +215,7 @@ def apply_transform( class ScaleToZScore(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], *, elementwise: bool = False, name: Optional[str] = None): @@ -247,7 +244,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_column_name: tft.scale_to_z_score( x=data, elementwise=self.elementwise, name=self.name) @@ -259,7 +256,7 @@ def apply_transform( class ScaleTo01(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], elementwise: bool = False, name: Optional[str] = None): """ @@ -287,7 +284,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = tft.scale_to_0_1( x=data, elementwise=self.elementwise, name=self.name) @@ -299,7 +296,7 @@ def apply_transform( class ScaleToGaussian(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], elementwise: bool = False, name: Optional[str] = None): """ @@ -324,7 +321,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_column_name: tft.scale_to_gaussian( x=data, 
elementwise=self.elementwise, name=self.name) @@ -336,7 +333,7 @@ def apply_transform( class ApplyBuckets(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], bucket_boundaries: Iterable[Union[int, float]], name: Optional[str] = None): """ @@ -359,7 +356,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.apply_buckets( x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) @@ -371,7 +368,7 @@ def apply_transform( class ApplyBucketsWithInterpolation(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], bucket_boundaries: Iterable[Union[int, float]], name: Optional[str] = None): """Interpolates values within the provided buckets and then normalizes to @@ -398,7 +395,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.apply_buckets_with_interpolation( x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) @@ -410,7 +407,7 @@ def apply_transform( class Bucketize(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], num_buckets: int, *, epsilon: Optional[float] = None, @@ -443,7 +440,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.bucketize( x=data, @@ -459,7 +456,7 @@ def apply_transform( class TFIDF(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], vocab_size: Optional[int] = None, smooth: bool = True, name: Optional[str] = None, @@ -530,7 +527,7 @@ def apply_transform( class ScaleByMinMax(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], min_value: float = 0.0, max_value: float = 1.0, name: Optional[str] = None): @@ -566,10 +563,10 @@ def apply_transform( class NGrams(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, - ngram_range: Tuple[int, int] = (1, 1), + ngram_range: tuple[int, int] = (1, 1), ngrams_separator: Optional[str] = None, name: Optional[str] = None): """ @@ -599,7 +596,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: if self.split_string_by_delimiter: data = self._split_string_with_delimiter( data, self.split_string_by_delimiter) @@ -611,10 +608,10 @@ def apply_transform( class BagOfWords(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, - ngram_range: Tuple[int, int] = (1, 1), + ngram_range: tuple[int, int] = (1, 1), ngrams_separator: Optional[str] = None, compute_word_count: bool = False, key_vocab_filename: Optional[str] = None, @@ -686,9 +683,9 @@ def count_unique_words( class HashStrings(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], hash_buckets: int, - key: Optional[Tuple[int, int]] = None, + key: Optional[tuple[int, int]] = None, name: Optional[str] = None): '''Hashes strings into the provided 
number of buckets. @@ -715,7 +712,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_col_name: str) -> Dict[str, common_types.TensorType]: + output_col_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_col_name: tft.hash_strings( strings=data, @@ -728,7 +725,7 @@ def apply_transform( @register_input_dtype(str) class DeduplicateTensorPerRow(TFTOperation): - def __init__(self, columns: List[str], name: Optional[str] = None): + def __init__(self, columns: list[str], name: Optional[str] = None): """ Deduplicates each row (0th dimension) of the provided tensor. Args: @@ -740,7 +737,7 @@ def __init__(self, columns: List[str], name: Optional[str] = None): def apply_transform( self, data: common_types.TensorType, - output_col_name: str) -> Dict[str, common_types.TensorType]: + output_col_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_col_name: tft.deduplicate_tensor_per_row( input_tensor=data, name=self.name) diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py index abf4c48fc642..023657895686 100644 --- a/sdks/python/apache_beam/ml/transforms/utils.py +++ b/sdks/python/apache_beam/ml/transforms/utils.py @@ -19,7 +19,6 @@ import os import tempfile -import typing from google.cloud.storage import Client from google.cloud.storage import transfer_manager @@ -72,7 +71,7 @@ def __init__(self, artifact_location: str): self._artifact_location = os.path.join(artifact_location, files[0]) self.transform_output = tft.TFTransformOutput(self._artifact_location) - def get_vocab_list(self, vocab_filename: str) -> typing.List[bytes]: + def get_vocab_list(self, vocab_filename: str) -> list[bytes]: """ Returns list of vocabulary terms created during MLTransform. """ diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 4497ab0993a4..af0c5e3de66f 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1674,12 +1674,16 @@ def _add_argparse_args(cls, parser): action='append', default=[], help='JVM properties to pass to a Java job server.') + parser.add_argument( + '--jar_cache_dir', + default=None, + help='The location to store jar cache for job server.') class FlinkRunnerOptions(PipelineOptions): # These should stay in sync with gradle.properties. 
- PUBLISHED_FLINK_VERSIONS = ['1.15', '1.16', '1.17', '1.18'] + PUBLISHED_FLINK_VERSIONS = ['1.17', '1.18', '1.19'] @classmethod def _add_argparse_args(cls, parser): diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index c0616bc6451c..66acfe654791 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -31,6 +31,7 @@ from apache_beam.options.pipeline_options import CrossLanguageOptions from apache_beam.options.pipeline_options import DebugOptions from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.options.pipeline_options import JobServerOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import ProfilingOptions from apache_beam.options.pipeline_options import TypeOptions @@ -645,6 +646,11 @@ def test_transform_name_mapping(self): mapping = options.view_as(GoogleCloudOptions).transform_name_mapping self.assertEqual(mapping['from'], 'to') + def test_jar_cache_dir(self): + options = PipelineOptions(['--jar_cache_dir=/path/to/jar_cache_dir']) + jar_cache_dir = options.view_as(JobServerOptions).jar_cache_dir + self.assertEqual(jar_cache_dir, '/path/to/jar_cache_dir') + def test_dataflow_service_options(self): options = PipelineOptions([ '--dataflow_service_option', diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 8a4f26c18e88..c43870d55ebb 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -1504,8 +1504,7 @@ def process(self, windowed_value): return [] def _maybe_sample_exception( - self, exn: BaseException, - windowed_value: Optional[WindowedValue]) -> None: + self, exc_info: Tuple, windowed_value: Optional[WindowedValue]) -> None: if self.execution_context is None: return @@ -1516,7 +1515,7 @@ def _maybe_sample_exception( output_sampler.sample_exception( windowed_value, - exn, + exc_info, self.transform_id, self.execution_context.instruction_id) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index 20cae582f320..97996bd6cbb2 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -82,7 +82,7 @@ _LOGGER = logging.getLogger(__name__) -_PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW = ['3.8', '3.9', '3.10', '3.11', '3.12'] +_PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW = ['3.9', '3.10', '3.11', '3.12'] class Environment(object): diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 022136aae9a2..6587e619a500 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -1003,7 +1003,21 @@ def test_interpreter_version_check_passes_with_experiment(self): 'apache_beam.runners.dataflow.internal.apiclient.' 
'beam_version.__version__', '2.2.0') - def test_interpreter_version_check_passes_py38(self): + def test_interpreter_version_check_fails_py38(self): + pipeline_options = PipelineOptions([]) + self.assertRaises( + Exception, + apiclient._verify_interpreter_version_is_supported, + pipeline_options) + + @mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', + (3, 9, 6)) + @mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.' + 'beam_version.__version__', + '2.2.0') + def test_interpreter_version_check_passes_py39(self): pipeline_options = PipelineOptions([]) apiclient._verify_interpreter_version_is_supported(pipeline_options) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py index 239ee8c700a2..c0d20c3ec8f9 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py @@ -30,4 +30,4 @@ pass # pylint: enable=wrong-import-order, wrong-import-position -__path__ = pkgutil.extend_path(__path__, __name__) # type: ignore +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 49b6622816ce..8b8937653688 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -110,6 +110,38 @@ def visit_transform(self, applied_ptransform): if timer.time_domain == TimeDomain.REAL_TIME: self.supported_by_fnapi_runner = False + class _PrismRunnerSupportVisitor(PipelineVisitor): + """Visitor determining if a Pipeline can be run on the PrismRunner.""" + def accept(self, pipeline): + self.supported_by_prism_runner = True + pipeline.visit(self) + return self.supported_by_prism_runner + + def visit_transform(self, applied_ptransform): + transform = applied_ptransform.transform + # Python SDK assumes the direct runner TestStream implementation is + # being used. + if isinstance(transform, TestStream): + self.supported_by_prism_runner = False + if isinstance(transform, beam.ParDo): + dofn = transform.dofn + # It's uncertain if the Prism Runner supports execution of CombineFns + # with deferred side inputs. + if isinstance(dofn, CombineValuesDoFn): + args, kwargs = transform.raw_side_inputs + args_to_check = itertools.chain(args, kwargs.values()) + if any(isinstance(arg, ArgumentPlaceholder) + for arg in args_to_check): + self.supported_by_prism_runner = False + if userstate.is_stateful_dofn(dofn): + # https://github.com/apache/beam/issues/32786 - + # Remove once Real time clock is used. + _, timer_specs = userstate.get_dofn_specs(dofn) + for timer in timer_specs: + if timer.time_domain == TimeDomain.REAL_TIME: + self.supported_by_prism_runner = False + + tryingPrism = False # Check whether all transforms used in the pipeline are supported by the # FnApiRunner, and the pipeline was not meant to be run as streaming. 
if _FnApiRunnerSupportVisitor().accept(pipeline): @@ -122,9 +154,33 @@ def visit_transform(self, applied_ptransform): beam_provision_api_pb2.ProvisionInfo( pipeline_options=encoded_options)) runner = fn_runner.FnApiRunner(provision_info=provision_info) + elif _PrismRunnerSupportVisitor().accept(pipeline): + _LOGGER.info('Running pipeline with PrismRunner.') + from apache_beam.runners.portability import prism_runner + runner = prism_runner.PrismRunner() + tryingPrism = True else: runner = BundleBasedDirectRunner() + if tryingPrism: + try: + pr = runner.run_pipeline(pipeline, options) + # This is non-blocking, so if the state is *already* finished, something + # probably failed on job submission. + if pr.state.is_terminal() and pr.state != PipelineState.DONE: + _LOGGER.info( + 'Pipeline failed on PrismRunner, falling back toDirectRunner.') + runner = BundleBasedDirectRunner() + else: + return pr + except Exception as e: + # If prism fails in Preparing the portable job, then the PortableRunner + # code raises an exception. Catch it, log it, and use the Direct runner + # instead. + _LOGGER.info('Exception with PrismRunner:\n %s\n' % (e)) + _LOGGER.info('Falling back to DirectRunner') + runner = BundleBasedDirectRunner() + return runner.run_pipeline(pipeline, options) diff --git a/sdks/python/apache_beam/runners/interactive/caching/reify.py b/sdks/python/apache_beam/runners/interactive/caching/reify.py index ce82785b2585..c82033dc1b9b 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/reify.py +++ b/sdks/python/apache_beam/runners/interactive/caching/reify.py @@ -28,7 +28,6 @@ import apache_beam as beam from apache_beam.runners.interactive import cache_manager as cache from apache_beam.testing import test_stream -from apache_beam.transforms.window import WindowedValue READ_CACHE = 'ReadCache_' WRITE_CACHE = 'WriteCache_' @@ -40,13 +39,8 @@ class Reify(beam.DoFn): Internally used to capture window info with each element into cache for replayability. """ - def process( - self, - e, - w=beam.DoFn.WindowParam, - p=beam.DoFn.PaneInfoParam, - t=beam.DoFn.TimestampParam): - yield test_stream.WindowedValueHolder(WindowedValue(e, t, [w], p)) + def process(self, e, wv=beam.DoFn.WindowedValueParam): + yield test_stream.WindowedValueHolder(wv) class Unreify(beam.DoFn): diff --git a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py index 693abb2aeeee..d767a15a345d 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py +++ b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py @@ -26,6 +26,7 @@ import datetime import html import logging +import warnings from datetime import timedelta from typing import Optional @@ -350,7 +351,12 @@ def display(self, updating_pv=None): ] # String-ify the dictionaries for display because elements of type dict # cannot be ordered. - data = data.applymap(lambda x: str(x) if isinstance(x, dict) else x) + with warnings.catch_warnings(): + # TODO(yathu) switch to use DataFrame.map when dropped pandas<2.1 support + warnings.filterwarnings( + "ignore", message="DataFrame.applymap has been deprecated") + data = data.applymap(lambda x: str(x) if isinstance(x, dict) else x) + if updating_pv: # Only updates when data is not empty. Otherwise, consider it a bad # iteration and noop since there is nothing to be updated. 
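
The pcoll_visualization.py hunk just above keeps `DataFrame.applymap` (with its deprecation warning suppressed) so that pandas releases older than 2.1 continue to work. A version-gated sketch of the same workaround, assuming an illustrative helper name rather than the module's actual code:

```python
import warnings

import pandas as pd


def stringify_dict_cells(df: pd.DataFrame) -> pd.DataFrame:
  """String-ify dict-valued cells so the frame can be ordered for display."""
  def to_str(x):
    return str(x) if isinstance(x, dict) else x

  if hasattr(pd.DataFrame, 'map'):
    # pandas >= 2.1: DataFrame.map replaces the deprecated applymap.
    return df.map(to_str)
  with warnings.catch_warnings():
    warnings.filterwarnings(
        'ignore', message='DataFrame.applymap has been deprecated')
    return df.applymap(to_str)
```
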
diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 0a03c96bc19b..13ab665c1eb1 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -306,7 +306,7 @@ def get_or_create_environment_with_resource_hints( # "Message"; expected "Environment" [arg-type] # Here, Environment is a subclass of Message but mypy still # throws an error. - cloned_env.CopyFrom(template_env) # type: ignore[arg-type] + cloned_env.CopyFrom(template_env) cloned_env.resource_hints.clear() cloned_env.resource_hints.update(resource_hints) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py index 07af1c958cfd..c1c7f649f77a 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py @@ -256,7 +256,7 @@ def has_as_main_input(self, pcoll): transform.spec.payload, beam_runner_api_pb2.ParDoPayload) local_side_inputs = payload.side_inputs else: - local_side_inputs = {} # type: ignore[assignment] + local_side_inputs = {} for local_id, pipeline_id in transform.inputs.items(): if pcoll == pipeline_id and local_id not in local_side_inputs: return True diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py index c5423e167026..d798e96d3aa3 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py @@ -1071,7 +1071,7 @@ def get_raw(self, if state_key.WhichOneof('type') not in self._SUPPORTED_STATE_TYPES: raise NotImplementedError( - 'Unknown state type: ' + state_key.WhichOneof('type')) + 'Unknown state type: ' + state_key.WhichOneof('type')) # type: ignore[operator] with self._lock: if not continuation_token: diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index e44d8ab0ae93..eee75f66a277 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -127,6 +127,7 @@ def __init__(self, options): self._artifacts_dir = options.artifacts_dir self._java_launcher = options.job_server_java_launcher self._jvm_properties = options.job_server_jvm_properties + self._jar_cache_dir = options.jar_cache_dir def java_arguments( self, job_port, artifact_port, expansion_port, artifacts_dir): @@ -141,11 +142,11 @@ def path_to_beam_jar(gradle_target, artifact_id=None): gradle_target, artifact_id=artifact_id) @staticmethod - def local_jar(url): - return subprocess_server.JavaJarServer.local_jar(url) + def local_jar(url, jar_cache_dir=None): + return subprocess_server.JavaJarServer.local_jar(url, jar_cache_dir) def subprocess_cmd_and_endpoint(self): - jar_path = self.local_jar(self.path_to_jar()) + jar_path = self.local_jar(self.path_to_jar(), self._jar_cache_dir) artifacts_dir = ( self._artifacts_dir if self._artifacts_dir else self.local_temp_dir( prefix='artifacts')) diff --git a/sdks/python/apache_beam/runners/portability/job_server_test.py b/sdks/python/apache_beam/runners/portability/job_server_test.py index 1e2ede281c9d..13b3629b24bf 100644 --- a/sdks/python/apache_beam/runners/portability/job_server_test.py +++ 
b/sdks/python/apache_beam/runners/portability/job_server_test.py @@ -41,7 +41,8 @@ def path_to_jar(self): return '/path/to/jar' @staticmethod - def local_jar(url): + def local_jar(url, jar_cache_dir=None): + logging.debug("url({%s}), jar_cache_dir({%s})", url, jar_cache_dir) return url diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 869f013d0d26..a2b4e5e7f939 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -35,7 +35,7 @@ import grpc from google.protobuf import json_format from google.protobuf import struct_pb2 -from google.protobuf import text_format # type: ignore # not in typeshed +from google.protobuf import text_format from apache_beam import pipeline from apache_beam.metrics import monitoring_infos diff --git a/sdks/python/apache_beam/runners/portability/prism_runner.py b/sdks/python/apache_beam/runners/portability/prism_runner.py index eeccaf5748ce..77dc8a214e8e 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner.py @@ -27,6 +27,7 @@ import platform import shutil import stat +import subprocess import typing import urllib import zipfile @@ -167,38 +168,93 @@ def construct_download_url(self, root_tag: str, sys: str, mach: str) -> str: def path_to_binary(self) -> str: if self._path is not None: - if not os.path.exists(self._path): - url = urllib.parse.urlparse(self._path) - if not url.scheme: - raise ValueError( - 'Unable to parse binary URL "%s". If using a full URL, make ' - 'sure the scheme is specified. If using a local file xpath, ' - 'make sure the file exists; you may have to first build prism ' - 'using `go build `.' % (self._path)) - - # We have a URL, see if we need to construct a valid file name. - if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): - # If this URL starts with the download prefix, let it through. - return self._path - # The only other valid option is a github release page. - if not self._path.startswith(GITHUB_TAG_PREFIX): - raise ValueError( - 'Provided --prism_location URL is not an Apache Beam Github ' - 'Release page URL or download URL: %s' % (self._path)) - # Get the root tag for this URL - root_tag = os.path.basename(os.path.normpath(self._path)) - return self.construct_download_url( - root_tag, platform.system(), platform.machine()) - return self._path - else: - if '.dev' in self._version: + # The path is overidden, check various cases. + if os.path.exists(self._path): + # The path is local and exists, use directly. + return self._path + + # Check if the path is a URL. + url = urllib.parse.urlparse(self._path) + if not url.scheme: + raise ValueError( + 'Unable to parse binary URL "%s". If using a full URL, make ' + 'sure the scheme is specified. If using a local file xpath, ' + 'make sure the file exists; you may have to first build prism ' + 'using `go build `.' % (self._path)) + + # We have a URL, see if we need to construct a valid file name. + if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): + # If this URL starts with the download prefix, let it through. + return self._path + # The only other valid option is a github release page. + if not self._path.startswith(GITHUB_TAG_PREFIX): raise ValueError( - 'Unable to derive URL for dev versions "%s". Please provide an ' - 'alternate version to derive the release URL with the ' - '--prism_beam_version_override flag.' 
% (self._version)) + 'Provided --prism_location URL is not an Apache Beam Github ' + 'Release page URL or download URL: %s' % (self._path)) + # Get the root tag for this URL + root_tag = os.path.basename(os.path.normpath(self._path)) + return self.construct_download_url( + root_tag, platform.system(), platform.machine()) + + if '.dev' not in self._version: + # Not a development version, so construct the production download URL return self.construct_download_url( self._version, platform.system(), platform.machine()) + # This is a development version! Assume Go is installed. + # Set the install directory to the cache location. + envdict = {**os.environ, "GOBIN": self.BIN_CACHE} + PRISMPKG = "github.com/apache/beam/sdks/v2/go/cmd/prism" + + process = subprocess.run(["go", "install", PRISMPKG], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + if process.returncode == 0: + # Successfully installed + return '%s/prism' % (self.BIN_CACHE) + + # We failed to build for some reason. + output = process.stdout.decode("utf-8") + if ("not in a module" not in output) and ( + "no required module provides" not in output): + # This branch handles two classes of failures: + # 1. Go isn't installed, so it needs to be installed by the Beam SDK + # developer. + # 2. Go is installed, and they are building in a local version of Prism, + # but there was a compile error that the developer should address. + # Either way, the @latest fallback either would fail, or hide the error, + # so fail now. + _LOGGER.info(output) + raise ValueError( + 'Unable to install a local of Prism: "%s";\n' + 'Likely Go is not installed, or a local change to Prism did not ' + 'compile.\nPlease install Go (see https://go.dev/doc/install) to ' + 'enable automatic local builds.\n' + 'Alternatively provide a binary with the --prism_location flag.' + '\nCaptured output:\n %s' % (self._version, output)) + + # Go is installed and claims we're not in a Go module that has access to + # the Prism package. + + # Fallback to using the @latest version of prism, which works everywhere. + process = subprocess.run(["go", "install", PRISMPKG + "@latest"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + + if process.returncode == 0: + return '%s/prism' % (self.BIN_CACHE) + + output = process.stdout.decode("utf-8") + raise ValueError( + 'We were unable to execute the subprocess "%s" to automatically ' + 'build prism.\nAlternatively provide an alternate binary with the ' + '--prism_location flag.' 
+ '\nCaptured output:\n %s' % (process.args, output)) + def subprocess_cmd_and_endpoint( self) -> typing.Tuple[typing.List[typing.Any], str]: bin_path = self.local_bin( diff --git a/sdks/python/apache_beam/runners/render.py b/sdks/python/apache_beam/runners/render.py index fccfa8aacd61..45e66e1ba06a 100644 --- a/sdks/python/apache_beam/runners/render.py +++ b/sdks/python/apache_beam/runners/render.py @@ -64,7 +64,7 @@ import urllib.parse from google.protobuf import json_format -from google.protobuf import text_format # type: ignore +from google.protobuf import text_format from apache_beam.options import pipeline_options from apache_beam.portability.api import beam_runner_api_pb2 diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 0f1700f52486..89c137fe4366 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -45,6 +45,7 @@ from typing import Iterator from typing import List from typing import Mapping +from typing import MutableMapping from typing import Optional from typing import Set from typing import Tuple @@ -130,18 +131,16 @@ class RunnerIOOperation(operations.Operation): """Common baseclass for runner harness IO operations.""" - - def __init__(self, - name_context, # type: common.NameContext - step_name, # type: Any - consumers, # type: Mapping[Any, Iterable[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, # type: str - data_channel # type: data_plane.DataChannel - ): - # type: (...) -> None + def __init__( + self, + name_context: common.NameContext, + step_name: Any, + consumers: Mapping[Any, Iterable[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id: str, + data_channel: data_plane.DataChannel) -> None: super().__init__(name_context, None, counter_factory, state_sampler) self.windowed_coder = windowed_coder self.windowed_coder_impl = windowed_coder.get_impl() @@ -157,36 +156,32 @@ def __init__(self, class DataOutputOperation(RunnerIOOperation): """A sink-like operation that gathers outputs to be sent back to the runner. """ - def set_output_stream(self, output_stream): - # type: (data_plane.ClosableOutputStream) -> None + def set_output_stream( + self, output_stream: data_plane.ClosableOutputStream) -> None: self.output_stream = output_stream - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.windowed_coder_impl.encode_to_stream( windowed_value, self.output_stream, True) self.output_stream.maybe_flush() - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() self.output_stream.close() class DataInputOperation(RunnerIOOperation): """A source-like operation that gathers input from the runner.""" - - def __init__(self, - operation_name, # type: common.NameContext - step_name, - consumers, # type: Mapping[Any, List[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, - data_channel # type: data_plane.GrpcClientDataChannel - ): - # type: (...) 
-> None + def __init__( + self, + operation_name: common.NameContext, + step_name, + consumers: Mapping[Any, List[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id, + data_channel: data_plane.GrpcClientDataChannel) -> None: super().__init__( operation_name, step_name, @@ -217,18 +212,15 @@ def setup(self, data_sampler=None): producer_batch_converter=self.get_output_batch_converter()) ] - def start(self): - # type: () -> None + def start(self) -> None: super().start() with self.splitting_lock: self.started = True - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.output(windowed_value) - def process_encoded(self, encoded_windowed_values): - # type: (bytes) -> None + def process_encoded(self, encoded_windowed_values: bytes) -> None: input_stream = coder_impl.create_InputStream(encoded_windowed_values) while input_stream.size() > 0: with self.splitting_lock: @@ -244,8 +236,9 @@ def process_encoded(self, encoded_windowed_values): str(self.windowed_coder)) from exn self.output(decoded_value) - def monitoring_infos(self, transform_id, tag_to_pcollection_id): - # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo] + def monitoring_infos( + self, transform_id: str, tag_to_pcollection_id: Dict[str, str] + ) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]: all_monitoring_infos = super().monitoring_infos( transform_id, tag_to_pcollection_id) read_progress_info = monitoring_infos.int64_counter( @@ -259,8 +252,13 @@ def monitoring_infos(self, transform_id, tag_to_pcollection_id): # TODO(https://github.com/apache/beam/issues/19737): typing not compatible # with super type def try_split( # type: ignore[override] - self, fraction_of_remainder, total_buffer_size, allowed_split_points): - # type: (...) -> Optional[Tuple[int, Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual], int]] + self, fraction_of_remainder, total_buffer_size, allowed_split_points + ) -> Optional[ + Tuple[ + int, + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual], + int]]: with self.splitting_lock: if not self.started: return None @@ -314,9 +312,10 @@ def is_valid_split_point(index): # try splitting at the current element. 
if (keep_of_element_remainder < 1 and is_valid_split_point(index) and is_valid_split_point(index + 1)): - split = try_split( - keep_of_element_remainder - ) # type: Optional[Tuple[Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual]]] + split: Optional[Tuple[ + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual]]] = try_split( + keep_of_element_remainder) if split: element_primaries, element_residuals = split return index - 1, element_primaries, element_residuals, index + 1 @@ -343,15 +342,13 @@ def is_valid_split_point(index): else: return None - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() with self.splitting_lock: self.index += 1 self.started = False - def reset(self): - # type: () -> None + def reset(self) -> None: with self.splitting_lock: self.index = -1 self.stop = float('inf') @@ -359,12 +356,12 @@ def reset(self): class _StateBackedIterable(object): - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - coder_or_impl, # type: Union[coders.Coder, coder_impl.CoderImpl] - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + coder_or_impl: Union[coders.Coder, coder_impl.CoderImpl], + ) -> None: self._state_handler = state_handler self._state_key = state_key if isinstance(coder_or_impl, coders.Coder): @@ -372,8 +369,7 @@ def __init__(self, else: self._coder_impl = coder_or_impl - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: return iter( self._state_handler.blocking_get(self._state_key, self._coder_impl)) @@ -391,15 +387,15 @@ class StateBackedSideInputMap(object): _BULK_READ_FULLY = "fully" _BULK_READ_PARTIALLY = "partially" - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - tag, # type: Optional[str] - side_input_data, # type: pvalue.SideInputData - coder, # type: WindowedValueCoder - use_bulk_read = False, # type: bool - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + tag: Optional[str], + side_input_data: pvalue.SideInputData, + coder: WindowedValueCoder, + use_bulk_read: bool = False, + ) -> None: self._state_handler = state_handler self._transform_id = transform_id self._tag = tag @@ -407,7 +403,7 @@ def __init__(self, self._element_coder = coder.wrapped_value_coder self._target_window_coder = coder.window_coder # TODO(robertwb): Limit the cache size. - self._cache = {} # type: Dict[BoundedWindow, Any] + self._cache: Dict[BoundedWindow, Any] = {} self._use_bulk_read = use_bulk_read def __getitem__(self, window): @@ -503,14 +499,12 @@ def __reduce__(self): self._cache[target_window] = self._side_input_data.view_fn(raw_view) return self._cache[target_window] - def is_globally_windowed(self): - # type: () -> bool + def is_globally_windowed(self) -> bool: return ( self._side_input_data.window_mapping_fn == sideinputs._global_window_mapping_fn) - def reset(self): - # type: () -> None + def reset(self) -> None: # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens. 
self._cache = {} @@ -519,26 +513,28 @@ class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState): def __init__(self, underlying_bag_state): self._underlying_bag_state = underlying_bag_state - def read(self): # type: () -> Any + def read(self) -> Any: values = list(self._underlying_bag_state.read()) if not values: return None return values[0] - def write(self, value): # type: (Any) -> None + def write(self, value: Any) -> None: self.clear() self._underlying_bag_state.add(value) - def clear(self): # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() - def commit(self): # type: () -> None + def commit(self) -> None: self._underlying_bag_state.commit() class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState): - def __init__(self, underlying_bag_state, combinefn): - # type: (userstate.AccumulatingRuntimeState, core.CombineFn) -> None + def __init__( + self, + underlying_bag_state: userstate.AccumulatingRuntimeState, + combinefn: core.CombineFn) -> None: self._combinefn = combinefn self._combinefn.setup() self._underlying_bag_state = underlying_bag_state @@ -552,12 +548,10 @@ def _read_accumulator(self, rewrite=True): self._underlying_bag_state.add(merged_accumulator) return merged_accumulator - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return self._combinefn.extract_output(self._read_accumulator()) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: # Prefer blind writes, but don't let them grow unboundedly. # This should be tuned to be much lower, but for now exercise # both paths well. @@ -569,8 +563,7 @@ def add(self, value): self._underlying_bag_state.add( self._combinefn.add_input(accumulator, value)) - def clear(self): - # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() def commit(self): @@ -587,13 +580,11 @@ class _ConcatIterable(object): Unlike itertools.chain, this allows reiteration. """ - def __init__(self, first, second): - # type: (Iterable[Any], Iterable[Any]) -> None + def __init__(self, first: Iterable[Any], second: Iterable[Any]) -> None: self.first = first self.second = second - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: for elem in self.first: yield elem for elem in self.second: @@ -604,38 +595,32 @@ def __iter__(self): class SynchronousBagRuntimeState(userstate.BagRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) 
-> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = [] # type: List[Any] + self._added_elements: List[Any] = [] - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return _ConcatIterable([] if self._cleared else cast( 'Iterable[Any]', _StateBackedIterable( self._state_handler, self._state_key, self._value_coder)), self._added_elements) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: self._added_elements.append(value) - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = [] - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -648,18 +633,16 @@ def commit(self): class SynchronousSetRuntimeState(userstate.SetRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = set() # type: Set[Any] + self._added_elements: Set[Any] = set() def _compact_data(self, rewrite=True): accumulator = set( @@ -679,12 +662,10 @@ def _compact_data(self, rewrite=True): return accumulator - def read(self): - # type: () -> Set[Any] + def read(self) -> Set[Any]: return self._compact_data(rewrite=False) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: if self._cleared: # This is a good time explicitly clear. 
self._state_handler.clear(self._state_key) @@ -694,13 +675,11 @@ def add(self, value): if random.random() > 0.5: self._compact_data() - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = set() - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -887,16 +866,16 @@ def commit(self) -> None: class OutputTimer(userstate.BaseTimer): - def __init__(self, - key, - window, # type: BoundedWindow - timestamp, # type: timestamp.Timestamp - paneinfo, # type: windowed_value.PaneInfo - time_domain, # type: str - timer_family_id, # type: str - timer_coder_impl, # type: coder_impl.TimerCoderImpl - output_stream # type: data_plane.ClosableOutputStream - ): + def __init__( + self, + key, + window: BoundedWindow, + timestamp: timestamp.Timestamp, + paneinfo: windowed_value.PaneInfo, + time_domain: str, + timer_family_id: str, + timer_coder_impl: coder_impl.TimerCoderImpl, + output_stream: data_plane.ClosableOutputStream): self._key = key self._window = window self._input_timestamp = timestamp @@ -942,15 +921,13 @@ def __init__(self, timer_coder_impl, output_stream=None): class FnApiUserStateContext(userstate.UserStateContext): """Interface for state and timers from SDK to Fn API servicer of state..""" - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - key_coder, # type: coders.Coder - window_coder, # type: coders.Coder - ): - # type: (...) -> None - + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + key_coder: coders.Coder, + window_coder: coders.Coder, + ) -> None: """Initialize a ``FnApiUserStateContext``. Args: @@ -964,11 +941,10 @@ def __init__(self, self._key_coder = key_coder self._window_coder = window_coder # A mapping of {timer_family_id: TimerInfo} - self._timers_info = {} # type: Dict[str, TimerInfo] - self._all_states = {} # type: Dict[tuple, FnApiUserRuntimeStateTypes] + self._timers_info: Dict[str, TimerInfo] = {} + self._all_states: Dict[tuple, FnApiUserRuntimeStateTypes] = {} - def add_timer_info(self, timer_family_id, timer_info): - # type: (str, TimerInfo) -> None + def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo) -> None: self._timers_info[timer_family_id] = timer_info def get_timer( @@ -987,19 +963,15 @@ def get_timer( timer_coder_impl, output_stream) - def get_state(self, *args): - # type: (*Any) -> FnApiUserRuntimeStateTypes + def get_state(self, *args: Any) -> FnApiUserRuntimeStateTypes: state_handle = self._all_states.get(args) if state_handle is None: state_handle = self._all_states[args] = self._create_state(*args) return state_handle - def _create_state(self, - state_spec, # type: userstate.StateSpec - key, - window # type: BoundedWindow - ): - # type: (...) 
-> FnApiUserRuntimeStateTypes + def _create_state( + self, state_spec: userstate.StateSpec, key, + window: BoundedWindow) -> FnApiUserRuntimeStateTypes: if isinstance(state_spec, (userstate.BagStateSpec, userstate.CombiningValueStateSpec, @@ -1046,13 +1018,11 @@ def _create_state(self, else: raise NotImplementedError(state_spec) - def commit(self): - # type: () -> None + def commit(self) -> None: for state in self._all_states.values(): state.commit() - def reset(self): - # type: () -> None + def reset(self) -> None: for state in self._all_states.values(): state.finalize() self._all_states = {} @@ -1071,14 +1041,12 @@ def wrapper(*args): return wrapper -def only_element(iterable): - # type: (Iterable[T]) -> T +def only_element(iterable: Iterable[T]) -> T: element, = iterable return element -def _environments_compatible(submission, runtime): - # type: (str, str) -> bool +def _environments_compatible(submission: str, runtime: str) -> bool: if submission == runtime: return True if 'rc' in submission and runtime in submission: @@ -1088,8 +1056,8 @@ def _environments_compatible(submission, runtime): return False -def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): - # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None +def _verify_descriptor_created_in_a_compatible_env( + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor) -> None: runtime_sdk = environments.sdk_base_version_capability() for t in process_bundle_descriptor.transforms.values(): @@ -1111,16 +1079,14 @@ def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): class BundleProcessor(object): """ A class for processing bundles of elements. """ - - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - state_handler, # type: sdk_worker.CachingStateHandler - data_channel_factory, # type: data_plane.DataChannelFactory - data_sampler=None, # type: Optional[data_sampler.DataSampler] - ): - # type: (...) -> None - + def __init__( + self, + runner_capabilities: FrozenSet[str], + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + state_handler: sdk_worker.CachingStateHandler, + data_channel_factory: data_plane.DataChannelFactory, + data_sampler: Optional[data_sampler.DataSampler] = None, + ) -> None: """Initialize a bundle processor. Args: @@ -1136,7 +1102,7 @@ def __init__(self, self.state_handler = state_handler self.data_channel_factory = data_channel_factory self.data_sampler = data_sampler - self.current_instruction_id = None # type: Optional[str] + self.current_instruction_id: Optional[str] = None # Represents whether the SDK is consuming received data. self.consuming_received_data = False @@ -1155,7 +1121,7 @@ def __init__(self, # {(transform_id, timer_family_id): TimerInfo} # The mapping is empty when there is no timer_family_specs in the # ProcessBundleDescriptor. - self.timers_info = {} # type: Dict[Tuple[str, str], TimerInfo] + self.timers_info: Dict[Tuple[str, str], TimerInfo] = {} # TODO(robertwb): Figure out the correct prefix to use for output counters # from StateSampler. @@ -1170,10 +1136,8 @@ def __init__(self, self.splitting_lock = threading.Lock() def create_execution_tree( - self, - descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor - ): - # type: (...) 
-> collections.OrderedDict[str, operations.DoOperation] + self, descriptor: beam_fn_api_pb2.ProcessBundleDescriptor + ) -> collections.OrderedDict[str, operations.DoOperation]: transform_factory = BeamTransformFactory( self.runner_capabilities, descriptor, @@ -1192,16 +1156,14 @@ def is_side_input(transform_proto, tag): transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload).side_inputs - pcoll_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[str]] + pcoll_consumers: DefaultDict[str, List[str]] = collections.defaultdict(list) for transform_id, transform_proto in descriptor.transforms.items(): for tag, pcoll_id in transform_proto.inputs.items(): if not is_side_input(transform_proto, tag): pcoll_consumers[pcoll_id].append(transform_id) @memoize - def get_operation(transform_id): - # type: (str) -> operations.Operation + def get_operation(transform_id: str) -> operations.Operation: transform_consumers = { tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] for tag, @@ -1218,8 +1180,7 @@ def get_operation(transform_id): # Operations must be started (hence returned) in order. @memoize - def topological_height(transform_id): - # type: (str) -> int + def topological_height(transform_id: str) -> int: return 1 + max([0] + [ topological_height(consumer) for pcoll in descriptor.transforms[transform_id].outputs.values() @@ -1232,18 +1193,18 @@ def topological_height(transform_id): get_operation(transform_id))) for transform_id in sorted( descriptor.transforms, key=topological_height, reverse=True)]) - def reset(self): - # type: () -> None + def reset(self) -> None: self.counter_factory.reset() self.state_sampler.reset() # Side input caches. for op in self.ops.values(): op.reset() - def process_bundle(self, instruction_id): - # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + def process_bundle( + self, instruction_id: str + ) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]: - expected_input_ops = [] # type: List[DataInputOperation] + expected_input_ops: List[DataInputOperation] = [] for op in self.ops.values(): if isinstance(op, DataOutputOperation): @@ -1269,9 +1230,10 @@ def process_bundle(self, instruction_id): # both data input and timer input. The data input is identied by # transform_id. The data input is identified by # (transform_id, timer_family_id). - data_channels = collections.defaultdict( - list - ) # type: DefaultDict[data_plane.DataChannel, List[Union[str, Tuple[str, str]]]] + data_channels: DefaultDict[data_plane.DataChannel, + List[Union[str, Tuple[ + str, + str]]]] = collections.defaultdict(list) # Add expected data inputs for each data channel. 
input_op_by_transform_id = {} @@ -1337,18 +1299,17 @@ def process_bundle(self, instruction_id): self.current_instruction_id = None self.state_sampler.stop_if_still_running() - def finalize_bundle(self): - # type: () -> beam_fn_api_pb2.FinalizeBundleResponse + def finalize_bundle(self) -> beam_fn_api_pb2.FinalizeBundleResponse: for op in self.ops.values(): op.finalize_bundle() return beam_fn_api_pb2.FinalizeBundleResponse() - def requires_finalization(self): - # type: () -> bool + def requires_finalization(self) -> bool: return any(op.needs_finalization() for op in self.ops.values()) - def try_split(self, bundle_split_request): - # type: (beam_fn_api_pb2.ProcessBundleSplitRequest) -> beam_fn_api_pb2.ProcessBundleSplitResponse + def try_split( + self, bundle_split_request: beam_fn_api_pb2.ProcessBundleSplitRequest + ) -> beam_fn_api_pb2.ProcessBundleSplitResponse: split_response = beam_fn_api_pb2.ProcessBundleSplitResponse() with self.splitting_lock: if bundle_split_request.instruction_id != self.current_instruction_id: @@ -1386,20 +1347,18 @@ def try_split(self, bundle_split_request): return split_response - def delayed_bundle_application(self, - op, # type: operations.DoOperation - deferred_remainder # type: SplitResultResidual - ): - # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication + def delayed_bundle_application( + self, op: operations.DoOperation, deferred_remainder: SplitResultResidual + ) -> beam_fn_api_pb2.DelayedBundleApplication: assert op.input_info is not None # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. (element_and_restriction, current_watermark, deferred_timestamp) = ( deferred_remainder) if deferred_timestamp: assert isinstance(deferred_timestamp, timestamp.Duration) - proto_deferred_watermark = proto_utils.from_micros( - duration_pb2.Duration, - deferred_timestamp.micros) # type: Optional[duration_pb2.Duration] + proto_deferred_watermark: Optional[ + duration_pb2.Duration] = proto_utils.from_micros( + duration_pb2.Duration, deferred_timestamp.micros) else: proto_deferred_watermark = None return beam_fn_api_pb2.DelayedBundleApplication( @@ -1407,29 +1366,26 @@ def delayed_bundle_application(self, application=self.construct_bundle_application( op.input_info, current_watermark, element_and_restriction)) - def bundle_application(self, - op, # type: operations.DoOperation - primary # type: SplitResultPrimary - ): - # type: (...) -> beam_fn_api_pb2.BundleApplication + def bundle_application( + self, op: operations.DoOperation, + primary: SplitResultPrimary) -> beam_fn_api_pb2.BundleApplication: assert op.input_info is not None return self.construct_bundle_application( op.input_info, None, primary.primary_value) - def construct_bundle_application(self, - op_input_info, # type: operations.OpInputInfo - output_watermark, # type: Optional[timestamp.Timestamp] - element - ): - # type: (...) 
-> beam_fn_api_pb2.BundleApplication + def construct_bundle_application( + self, + op_input_info: operations.OpInputInfo, + output_watermark: Optional[timestamp.Timestamp], + element) -> beam_fn_api_pb2.BundleApplication: transform_id, main_input_tag, main_input_coder, outputs = op_input_info if output_watermark: proto_output_watermark = proto_utils.from_micros( timestamp_pb2.Timestamp, output_watermark.micros) - output_watermarks = { + output_watermarks: Optional[Dict[str, timestamp_pb2.Timestamp]] = { output: proto_output_watermark for output in outputs - } # type: Optional[Dict[str, timestamp_pb2.Timestamp]] + } else: output_watermarks = None return beam_fn_api_pb2.BundleApplication( @@ -1438,9 +1394,7 @@ def construct_bundle_application(self, output_watermarks=output_watermarks, element=main_input_coder.get_impl().encode_nested(element)) - def monitoring_infos(self): - # type: () -> List[metrics_pb2.MonitoringInfo] - + def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: """Returns the list of MonitoringInfos collected processing this bundle.""" # Construct a new dict first to remove duplicates. all_monitoring_infos_dict = {} @@ -1452,8 +1406,7 @@ def monitoring_infos(self): return list(all_monitoring_infos_dict.values()) - def shutdown(self): - # type: () -> None + def shutdown(self) -> None: for op in self.ops.values(): op.teardown() @@ -1474,15 +1427,16 @@ class ExecutionContext: class BeamTransformFactory(object): """Factory for turning transform_protos into executable operations.""" - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - data_channel_factory, # type: data_plane.DataChannelFactory - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - state_handler, # type: sdk_worker.CachingStateHandler - data_sampler, # type: Optional[data_sampler.DataSampler] - ): + def __init__( + self, + runner_capabilities: FrozenSet[str], + descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + data_channel_factory: data_plane.DataChannelFactory, + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + state_handler: sdk_worker.CachingStateHandler, + data_sampler: Optional[data_sampler.DataSampler], + ): self.runner_capabilities = runner_capabilities self.descriptor = descriptor self.data_channel_factory = data_channel_factory @@ -1499,27 +1453,41 @@ def __init__(self, element_coder_impl)) self.data_sampler = data_sampler - _known_urns = { - } # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]] + _known_urns: Dict[str, + Tuple[ConstructorFn, + Union[Type[message.Message], Type[bytes], + None]]] = {} @classmethod def register_urn( - cls, - urn, # type: str - parameter_type # type: Optional[Type[T]] - ): - # type: (...) 
-> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]] + cls, urn: str, parameter_type: Optional[Type[T]] + ) -> Callable[[ + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation] + ], + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation]]: def wrapper(func): cls._known_urns[urn] = func, parameter_type return func return wrapper - def create_operation(self, - transform_id, # type: str - consumers # type: Dict[str, List[operations.Operation]] - ): - # type: (...) -> operations.Operation + def create_operation( + self, transform_id: str, + consumers: Dict[str, List[operations.Operation]]) -> operations.Operation: transform_proto = self.descriptor.transforms[transform_id] if not transform_proto.unique_name: _LOGGER.debug("No unique name set for transform %s" % transform_id) @@ -1529,8 +1497,7 @@ def create_operation(self, transform_proto.spec.payload, parameter_type) return creator(self, transform_id, transform_proto, payload, consumers) - def extract_timers_info(self): - # type: () -> Dict[Tuple[str, str], TimerInfo] + def extract_timers_info(self) -> Dict[Tuple[str, str], TimerInfo]: timers_info = {} for transform_id, transform_proto in self.descriptor.transforms.items(): if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: @@ -1545,8 +1512,7 @@ def extract_timers_info(self): timer_coder_impl=timer_coder_impl) return timers_info - def get_coder(self, coder_id): - # type: (str) -> coders.Coder + def get_coder(self, coder_id: str) -> coders.Coder: if coder_id not in self.descriptor.coders: raise KeyError("No such coder: %s" % coder_id) coder_proto = self.descriptor.coders[coder_id] @@ -1557,8 +1523,7 @@ def get_coder(self, coder_id): return operation_specs.get_coder_from_spec( json.loads(coder_proto.spec.payload.decode('utf-8'))) - def get_windowed_coder(self, pcoll_id): - # type: (str) -> WindowedValueCoder + def get_windowed_coder(self, pcoll_id: str) -> WindowedValueCoder: coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) # TODO(robertwb): Remove this condition once all runners are consistent. 
if not isinstance(coder, WindowedValueCoder): @@ -1569,32 +1534,34 @@ def get_windowed_coder(self, pcoll_id): else: return coder - def get_output_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.Coder] + def get_output_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.Coder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.outputs.items() } - def get_only_output_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_output_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(self.get_output_coders(transform_proto).values()) - def get_input_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.WindowedValueCoder] + def get_input_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.WindowedValueCoder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.inputs.items() } - def get_only_input_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_input_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(list(self.get_input_coders(transform_proto).values())) - def get_input_windowing(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Windowing + def get_input_windowing( + self, transform_proto: beam_runner_api_pb2.PTransform) -> Windowing: pcoll_id = only_element(transform_proto.inputs.values()) windowing_strategy_id = self.descriptor.pcollections[ pcoll_id].windowing_strategy_id @@ -1603,12 +1570,10 @@ def get_input_windowing(self, transform_proto): # TODO(robertwb): Update all operations to take these in the constructor. @staticmethod def augment_oldstyle_op( - op, # type: OperationT - step_name, # type: str - consumers, # type: Mapping[str, Iterable[operations.Operation]] - tag_list=None # type: Optional[List[str]] - ): - # type: (...) -> OperationT + op: OperationT, + step_name: str, + consumers: Mapping[str, Iterable[operations.Operation]], + tag_list: Optional[List[str]] = None) -> OperationT: op.step_name = step_name for tag, op_consumers in consumers.items(): for consumer in op_consumers: @@ -1619,13 +1584,11 @@ def augment_oldstyle_op( @BeamTransformFactory.register_urn( DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_source_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> DataInputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataInputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataInputOperation( @@ -1642,13 +1605,11 @@ def create_source_runner( @BeamTransformFactory.register_urn( DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_sink_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> DataOutputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataOutputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataOutputOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -1663,13 +1624,12 @@ def create_sink_runner( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None) def create_source_java( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: # The Dataflow runner harness strips the base64 encoding. source = pickler.loads(base64.b64encode(parameter)) spec = operation_specs.WorkerRead( @@ -1688,13 +1648,12 @@ def create_source_java( @BeamTransformFactory.register_urn( common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload) def create_deprecated_read( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: source = iobase.BoundedSource.from_runner_api( parameter.source, factory.context) spec = operation_specs.WorkerRead( @@ -1713,13 +1672,12 @@ def create_deprecated_read( @BeamTransformFactory.register_urn( python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) def create_read_from_impulse_python( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> operations.ImpulseReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, List[operations.Operation]] +) -> operations.ImpulseReadOperation: return operations.ImpulseReadOperation( common.NameContext(transform_proto.unique_name, transform_id), factory.counter_factory, @@ -1731,12 +1689,11 @@ def create_read_from_impulse_python( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None) def create_dofn_javasdk( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, serialized_fn, - consumers # type: Dict[str, List[operations.Operation]] -): + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -1820,12 +1777,11 @@ def process(self, element_restriction, *args, **kwargs): common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn, beam_runner_api_pb2.ParDoPayload) def create_process_sized_elements_and_restrictions( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, @@ -1867,13 +1823,11 @@ def _create_sdf_operation( @BeamTransformFactory.register_urn( common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) def create_par_do( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.DoOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]) -> operations.DoOperation: return _create_pardo_operation( factory, transform_id, @@ -1885,14 +1839,13 @@ def create_par_do( def _create_pardo_operation( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, consumers, serialized_fn, - pardo_proto=None, # type: Optional[beam_runner_api_pb2.ParDoPayload] - operation_cls=operations.DoOperation -): + pardo_proto: Optional[beam_runner_api_pb2.ParDoPayload] = None, + operation_cls=operations.DoOperation): if pardo_proto and pardo_proto.side_inputs: input_tags_to_coders = factory.get_input_coders(transform_proto) @@ -1924,9 +1877,8 @@ def _create_pardo_operation( if not dofn_data[-1]: # Windowing not set. 
if pardo_proto: - other_input_tags = set.union( - set(pardo_proto.side_inputs), - set(pardo_proto.timer_family_specs)) # type: Container[str] + other_input_tags: Container[str] = set.union( + set(pardo_proto.side_inputs), set(pardo_proto.timer_family_specs)) else: other_input_tags = () pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items() @@ -1950,12 +1902,12 @@ def _create_pardo_operation( main_input_coder = found_input_coder if pardo_proto.timer_family_specs or pardo_proto.state_specs: - user_state_context = FnApiUserStateContext( - factory.state_handler, - transform_id, - main_input_coder.key_coder(), - main_input_coder.window_coder - ) # type: Optional[FnApiUserStateContext] + user_state_context: Optional[ + FnApiUserStateContext] = FnApiUserStateContext( + factory.state_handler, + transform_id, + main_input_coder.key_coder(), + main_input_coder.window_coder) else: user_state_context = None else: @@ -1989,12 +1941,13 @@ def _create_pardo_operation( return result -def _create_simple_pardo_operation(factory, # type: BeamTransformFactory - transform_id, - transform_proto, - consumers, - dofn, # type: beam.DoFn - ): +def _create_simple_pardo_operation( + factory: BeamTransformFactory, + transform_id, + transform_proto, + consumers, + dofn: beam.DoFn, +): serialized_fn = pickler.dumps((dofn, (), {}, [], None)) return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -2004,12 +1957,11 @@ def _create_simple_pardo_operation(factory, # type: BeamTransformFactory common_urns.primitives.ASSIGN_WINDOWS.urn, beam_runner_api_pb2.WindowingStrategy) def create_assign_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.WindowingStrategy - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.WindowingStrategy, + consumers: Dict[str, List[operations.Operation]]): class WindowIntoDoFn(beam.DoFn): def __init__(self, windowing): self.windowing = windowing @@ -2036,13 +1988,12 @@ def process( @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) def create_identity_dofn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -2058,13 +2009,12 @@ def create_identity_dofn( common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_precombine( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> operations.PGBKCVOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.PGBKCVOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -2085,12 +2035,11 @@ def create_combine_per_key_precombine( common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combbine_per_key_merge_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'merge') @@ -2099,12 +2048,11 @@ def create_combbine_per_key_merge_accumulators( common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_extract_outputs( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'extract') @@ -2113,12 +2061,11 @@ def create_combine_per_key_extract_outputs( common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_convert_to_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'convert') @@ -2127,19 +2074,18 @@ def create_combine_per_key_convert_to_accumulators( common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, beam_runner_api_pb2.CombinePayload) def create_combine_grouped_values( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'all') def _create_combine_phase_operation( - factory, transform_id, 
transform_proto, payload, consumers, phase): - # type: (...) -> operations.CombineOperation + factory, transform_id, transform_proto, payload, consumers, + phase) -> operations.CombineOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -2158,13 +2104,12 @@ def _create_combine_phase_operation( @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None) def create_flatten( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, payload, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -2179,12 +2124,11 @@ def create_flatten( @BeamTransformFactory.register_urn( common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_map_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN window_mapping_fn = pickler.loads(mapping_fn_spec.payload) @@ -2200,12 +2144,11 @@ def process(self, element): @BeamTransformFactory.register_urn( common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_merge_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN window_fn = pickler.loads(mapping_fn_spec.payload) @@ -2213,24 +2156,25 @@ class MergeWindows(beam.DoFn): def process(self, element): nonce, windows = element - original_windows = set(windows) # type: Set[window.BoundedWindow] - merged_windows = collections.defaultdict( - set - ) # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]] # noqa: F821 + original_windows: Set[window.BoundedWindow] = set(windows) + merged_windows: MutableMapping[ + window.BoundedWindow, + Set[window.BoundedWindow]] = collections.defaultdict( + set) # noqa: F821 class RecordingMergeContext(window.WindowFn.MergeContext): def merge( self, - to_be_merged, # type: Iterable[window.BoundedWindow] - merge_result, # type: window.BoundedWindow - ): + to_be_merged: Iterable[window.BoundedWindow], + merge_result: window.BoundedWindow, + ): originals = merged_windows[merge_result] - for window in to_be_merged: - if window in original_windows: - originals.add(window) - original_windows.remove(window) + for w in to_be_merged: + if w in original_windows: 
+ originals.add(w) + original_windows.remove(w) else: - originals.update(merged_windows.pop(window)) + originals.update(merged_windows.pop(w)) window_fn.merge(RecordingMergeContext(windows)) yield nonce, (original_windows, merged_windows.items()) @@ -2241,12 +2185,11 @@ def merge( @BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None) def create_to_string_fn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): class ToString(beam.DoFn): def process(self, element): key, value = element diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index 88cc3c9791d5..979c7cdb53be 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -125,7 +125,7 @@ def emit(self, record: logging.LogRecord) -> None: log_entry.message = ( "Failed to format '%s' with args '%s' during logging." % (str(record.msg), record.args)) - log_entry.thread = record.threadName + log_entry.thread = record.threadName # type: ignore[assignment] log_entry.log_location = '%s:%s' % ( record.pathname or record.module, record.lineno or record.funcName) (fraction, seconds) = math.modf(record.created) diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py index 51ca4cf0545b..5496bccd014e 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters.py +++ b/sdks/python/apache_beam/runners/worker/opcounters.py @@ -259,7 +259,7 @@ def do_sample(self, windowed_value): self.type_check(windowed_value.value) size, observables = ( - self.coder_impl.get_estimated_size_and_observables(windowed_value)) + self.coder_impl.get_estimated_size_and_observables(windowed_value)) # type: ignore[union-attr] if not observables: self.current_size = size else: diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 2a1423fccba9..b091220a06b5 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -1335,8 +1335,8 @@ def _get_cache_token(self, state_key): return self._context.user_state_cache_token else: return self._context.bundle_cache_token - elif state_key.WhichOneof('type').endswith('_side_input'): - side_input = getattr(state_key, state_key.WhichOneof('type')) + elif state_key.WhichOneof('type').endswith('_side_input'): # type: ignore[union-attr] + side_input = getattr(state_key, state_key.WhichOneof('type')) # type: ignore[arg-type] return self._context.side_input_cache_tokens.get( (side_input.transform_id, side_input.side_input_id), self._context.bundle_cache_token) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index cd49c69a80aa..3389f0c7afb1 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -27,7 +27,7 @@ import sys import traceback -from google.protobuf import text_format # type: ignore # not in typeshed +from 
google.protobuf import text_format from apache_beam.internal import pickler from apache_beam.io import filesystems diff --git a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py index 87c631762287..ba3fae6819bd 100644 --- a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py +++ b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py @@ -29,7 +29,7 @@ from apache_beam.testing.load_tests.load_test_metrics_utils import MeasureTime from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader -from google.protobuf import text_format # type: ignore # typeshed out of date +from google.protobuf import text_format from trainer import taxi diff --git a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py index 6d84c995bfca..88ed53e11fc4 100644 --- a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py +++ b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py @@ -18,7 +18,7 @@ from tensorflow_transform import coders as tft_coders from tensorflow_transform.tf_metadata import schema_utils -from google.protobuf import text_format # type: ignore # typeshed out of date +from google.protobuf import text_format from tensorflow.python.lib.io import file_io from tensorflow_metadata.proto.v0 import schema_pb2 diff --git a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py index 745d961d2aac..3b5dfdf38cd9 100644 --- a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py +++ b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py @@ -111,7 +111,7 @@ class SequenceSideInputTestDoFn(beam.DoFn): def __init__(self, first_n: int): self._first_n = first_n - def process( # type: ignore[override] + def process( self, element: Any, side_input: Iterable[Tuple[bytes, bytes]]) -> None: i = 0 @@ -129,7 +129,7 @@ class MappingSideInputTestDoFn(beam.DoFn): def __init__(self, first_n: int): self._first_n = first_n - def process( # type: ignore[override] + def process( self, element: Any, dict_side_input: Dict[bytes, bytes]) -> None: i = 0 for key in dict_side_input: @@ -146,7 +146,7 @@ def __init__(self): # Avoid having to use save_main_session self.window = window - def process(self, element: int) -> Iterable[window.TimestampedValue]: # type: ignore[override] + def process(self, element: int) -> Iterable[window.TimestampedValue]: yield self.window.TimestampedValue(element, element) class GetSyntheticSDFOptions(beam.DoFn): @@ -156,7 +156,7 @@ def __init__( self.key_size = key_size self.value_size = value_size - def process(self, element: Any) -> Iterable[Dict[str, Union[int, str]]]: # type: ignore[override] + def process(self, element: Any) -> Iterable[Dict[str, Union[int, str]]]: yield { 'num_records': self.elements_per_record, 'key_size': self.key_size, diff --git a/sdks/python/apache_beam/testing/test_stream_service_test.py b/sdks/python/apache_beam/testing/test_stream_service_test.py index 5bfd0c104ba0..a04fa2303d08 100644 --- a/sdks/python/apache_beam/testing/test_stream_service_test.py +++ b/sdks/python/apache_beam/testing/test_stream_service_test.py @@ -30,9 +30,9 @@ # Nose automatically detects tests if they match a regex. Here, it mistakens # these protos as tests. 
For more info see the Nose docs at: # https://nose.readthedocs.io/en/latest/writing_tests.html -beam_runner_api_pb2.TestStreamPayload.__test__ = False # type: ignore[attr-defined] -beam_interactive_api_pb2.TestStreamFileHeader.__test__ = False # type: ignore[attr-defined] -beam_interactive_api_pb2.TestStreamFileRecord.__test__ = False # type: ignore[attr-defined] +beam_runner_api_pb2.TestStreamPayload.__test__ = False +beam_interactive_api_pb2.TestStreamFileHeader.__test__ = False +beam_interactive_api_pb2.TestStreamFileRecord.__test__ = False class EventsReader: diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index be3cec6304f4..9c798d3ce6dc 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -103,6 +103,7 @@ 'Windowing', 'WindowInto', 'Flatten', + 'FlattenWith', 'Create', 'Impulse', 'RestrictionProvider', @@ -1414,7 +1415,7 @@ class PartitionFn(WithTypeHints): def default_label(self): return self.__class__.__name__ - def partition_for(self, element, num_partitions, *args, **kwargs): + def partition_for(self, element, num_partitions, *args, **kwargs): # type: ignore[empty-body] # type: (T, int, *typing.Any, **typing.Any) -> int """Specify which partition will receive this element. @@ -3450,7 +3451,7 @@ def _dynamic_named_tuple(type_name, field_names): type_name, field_names) # typing: can't override a method. also, self type is unknown and can't # be cast to tuple - result.__reduce__ = lambda self: ( # type: ignore[assignment] + result.__reduce__ = lambda self: ( # type: ignore[method-assign] _unpickle_dynamic_named_tuple, (type_name, field_names, tuple(self))) # type: ignore[arg-type] return result @@ -3881,6 +3882,33 @@ def from_runner_api_parameter( common_urns.primitives.FLATTEN.urn, None, Flatten.from_runner_api_parameter) +class FlattenWith(PTransform): + """A PTransform that flattens its input with other PCollections. + + This is equivalent to creating a tuple containing both the input and the + other PCollection(s), but has the advantage that it can be more easily used + inline. + + Root PTransforms can be passed as well as PCollections, in which case their + outputs will be flattened. 
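+
+  A minimal usage sketch (illustrative; the element values are arbitrary)::
+
+    with beam.Pipeline() as p:
+      result = (
+          p
+          | beam.Create([1, 2, 3])
+          | beam.FlattenWith(beam.Create([4, 5, 6])))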
+ """ + def __init__(self, *others): + self._others = others + + def expand(self, pcoll): + pcolls = [pcoll] + for other in self._others: + if isinstance(other, pvalue.PCollection): + pcolls.append(other) + elif isinstance(other, PTransform): + pcolls.append(pcoll.pipeline | other) + else: + raise TypeError( + 'FlattenWith only takes other PCollections and PTransforms, ' + f'got {other}') + return tuple(pcolls) | Flatten() + + class Create(PTransform): """A transform that creates a PCollection from an iterable.""" def __init__(self, values, reshuffle=True): diff --git a/sdks/python/apache_beam/transforms/display.py b/sdks/python/apache_beam/transforms/display.py index 86bbf101f567..14cd485d1f8e 100644 --- a/sdks/python/apache_beam/transforms/display.py +++ b/sdks/python/apache_beam/transforms/display.py @@ -173,7 +173,7 @@ def create_payload(dd) -> Optional[beam_runner_api_pb2.LabelledPayload]: elif isinstance(value, (float, complex)): return beam_runner_api_pb2.LabelledPayload( label=label, - double_value=value, + double_value=value, # type: ignore[arg-type] key=display_data_dict['key'], namespace=display_data_dict.get('namespace', '')) else: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index ea98fb6b0bbd..06b40bf38cc1 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -14,11 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from collections.abc import Callable +from collections.abc import Mapping from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping from typing import Optional from typing import Union @@ -30,7 +28,7 @@ from apache_beam.transforms.enrichment import EnrichmentSourceHandler QueryFn = Callable[[beam.Row], str] -ConditionValueFn = Callable[[beam.Row], List[Any]] +ConditionValueFn = Callable[[beam.Row], list[Any]] def _validate_bigquery_metadata( @@ -54,8 +52,8 @@ def _validate_bigquery_metadata( "`condition_value_fn`") -class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], - Union[Row, List[Row]]]): +class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]], + Union[Row, list[Row]]]): """Enrichment handler for Google Cloud BigQuery. Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment` @@ -83,8 +81,8 @@ def __init__( *, table_name: str = "", row_restriction_template: str = "", - fields: Optional[List[str]] = None, - column_names: Optional[List[str]] = None, + fields: Optional[list[str]] = None, + column_names: Optional[list[str]] = None, condition_value_fn: Optional[ConditionValueFn] = None, query_fn: Optional[QueryFn] = None, min_batch_size: int = 1, @@ -107,10 +105,10 @@ def __init__( row_restriction_template (str): A template string for the `WHERE` clause in the BigQuery query with placeholders (`{}`) to dynamically filter rows based on input data. - fields: (Optional[List[str]]) List of field names present in the input + fields: (Optional[list[str]]) List of field names present in the input `beam.Row`. These are used to construct the WHERE clause (if `condition_value_fn` is not provided). - column_names: (Optional[List[str]]) Names of columns to select from the + column_names: (Optional[list[str]]) Names of columns to select from the BigQuery table. 
If not provided, all columns (`*`) are selected. condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function that takes a `beam.Row` and returns a list of value to populate in the @@ -179,11 +177,11 @@ def create_row_key(self, row: beam.Row): return (tuple(row_dict[field] for field in self.fields)) raise ValueError("Either fields or condition_value_fn must be specified") - def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): - if isinstance(request, List): + def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): + if isinstance(request, list): values = [] responses = [] - requests_map: Dict[Any, Any] = {} + requests_map: dict[Any, Any] = {} batch_size = len(request) raw_query = self.query_template if batch_size > 1: @@ -230,8 +228,8 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): def __exit__(self, exc_type, exc_val, exc_tb): self.client.close() - def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]): - if isinstance(request, List): + def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]): + if isinstance(request, list): cache_keys = [] for req in request: req_dict = req._asdict() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index ddb62c2f60d5..c251ab05ecab 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -15,9 +15,8 @@ # limitations under the License. # import logging +from collections.abc import Callable from typing import Any -from typing import Callable -from typing import Dict from typing import Optional from google.api_core.exceptions import NotFound @@ -115,7 +114,7 @@ def __call__(self, request: beam.Row, *args, **kwargs): Args: request: the input `beam.Row` to enrich. """ - response_dict: Dict[str, Any] = {} + response_dict: dict[str, Any] = {} row_key_str: str = "" try: if self._row_key_fn: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index 79d73178e94e..6bf57cefacbe 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -18,10 +18,7 @@ import datetime import logging import unittest -from typing import Dict -from typing import List from typing import NamedTuple -from typing import Tuple from unittest.mock import MagicMock import pytest @@ -57,8 +54,8 @@ class ValidateResponse(beam.DoFn): def __init__( self, n_fields: int, - fields: List[str], - enriched_fields: Dict[str, List[str]], + fields: list[str], + enriched_fields: dict[str, list[str]], include_timestamp: bool = False, ): self.n_fields = n_fields @@ -88,7 +85,7 @@ def process(self, element: beam.Row, *args, **kwargs): "Response from bigtable should contain a %s column_family with " "%s columns." 
% (column_family, columns)) if (self._include_timestamp and - not isinstance(element_dict[column_family][key][0], Tuple)): # type: ignore[arg-type] + not isinstance(element_dict[column_family][key][0], tuple)): raise BeamAssertException( "Response from bigtable should contain timestamp associated with " "its value.") diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py index dc2a71786f65..f8e8b4db1d7f 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py @@ -16,11 +16,10 @@ # import logging import tempfile +from collections.abc import Callable +from collections.abc import Mapping from pathlib import Path from typing import Any -from typing import Callable -from typing import List -from typing import Mapping from typing import Optional import apache_beam as beam @@ -95,7 +94,7 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row, def __init__( self, feature_store_yaml_path: str, - feature_names: Optional[List[str]] = None, + feature_names: Optional[list[str]] = None, feature_service_name: Optional[str] = "", full_feature_names: Optional[bool] = False, entity_id: str = "", diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py index 89cb39c2c19c..9c4dab3d68b8 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py @@ -22,8 +22,8 @@ """ import unittest +from collections.abc import Mapping from typing import Any -from typing import Mapping import pytest diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py index 753b04e1793d..b6de3aa1c826 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py @@ -15,7 +15,6 @@ # limitations under the License. # import logging -from typing import List import proto from google.api_core.exceptions import NotFound @@ -209,7 +208,7 @@ def __init__( api_endpoint: str, feature_store_id: str, entity_type_id: str, - feature_ids: List[str], + feature_ids: list[str], row_key: str, *, exception_level: ExceptionLevel = ExceptionLevel.WARN, @@ -224,7 +223,7 @@ def __init__( Vertex AI Feature Store (Legacy). feature_store_id (str): The id of the Vertex AI Feature Store (Legacy). entity_type_id (str): The entity type of the feature store. - feature_ids (List[str]): A list of feature-ids to fetch + feature_ids (list[str]): A list of feature-ids to fetch from the Feature Store. row_key (str): The row key field name containing the entity id for the feature values. 
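For orientation, here is a minimal, hedged usage sketch of the BigQuery handler whose hints are modernized above. The project and table names are hypothetical; the keyword arguments mirror the handler docstring in this patch (`fields` feeds the `{}` placeholders of `row_restriction_template`, and `column_names` defaults to selecting `*`).

```python
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigquery import (
    BigQueryEnrichmentHandler)

# Hypothetical project and table; the handler issues a query roughly like
# "SELECT * FROM `my-project.my_dataset.products` WHERE label = '<row.label>'"
# for each (possibly batched) request.
handler = BigQueryEnrichmentHandler(
    project='my-project',
    table_name='my-project.my_dataset.products',
    row_restriction_template="label = '{}'",
    fields=['label'],
    min_batch_size=1)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([beam.Row(label='11a'), beam.Row(label='37a')])
      | Enrichment(source_handler=handler)  # merges matching BigQuery columns
      | beam.Map(print))
```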
diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 8a04e7efb195..83c439ca8ddd 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -740,7 +740,7 @@ def expand(self, pvalueish): components = context.to_runner_api() request = beam_expansion_api_pb2.ExpansionRequest( components=components, - namespace=self._external_namespace, + namespace=self._external_namespace, # type: ignore[arg-type] transform=transform_proto, output_coder_requests=output_coders, pipeline_options=pipeline._options.to_runner_api()) diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index fe2914a08699..c95a5d19f0cd 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -52,6 +52,7 @@ from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.utils import proto_utils from apache_beam.utils.subprocess_server import JavaJarServer +from apache_beam.utils.subprocess_server import SubprocessServer # Protect against environments where apitools library is not available. # pylint: disable=wrong-import-order, wrong-import-position @@ -718,6 +719,9 @@ def test_implicit_builder_with_constructor_method(self): class JavaJarExpansionServiceTest(unittest.TestCase): + def setUp(self): + SubprocessServer._cache._live_owners = set() + def test_classpath(self): with tempfile.TemporaryDirectory() as temp_dir: try: diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py index 22ee15b1de1c..cbcb6de56ed7 100644 --- a/sdks/python/apache_beam/transforms/managed.py +++ b/sdks/python/apache_beam/transforms/managed.py @@ -77,12 +77,16 @@ ICEBERG = "iceberg" KAFKA = "kafka" +BIGQUERY = "bigquery" _MANAGED_IDENTIFIER = "beam:transform:managed:v1" _EXPANSION_SERVICE_JAR_TARGETS = { "sdks:java:io:expansion-service:shadowJar": [KAFKA, ICEBERG], + "sdks:java:io:google-cloud-platform:expansion-service:shadowJar": [ + BIGQUERY + ] } -__all__ = ["ICEBERG", "KAFKA", "Read", "Write"] +__all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"] class Read(PTransform): @@ -90,6 +94,7 @@ class Read(PTransform): _READ_TRANSFORMS = { ICEBERG: ManagedTransforms.Urns.ICEBERG_READ.urn, KAFKA: ManagedTransforms.Urns.KAFKA_READ.urn, + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_READ.urn } def __init__( @@ -130,6 +135,7 @@ class Write(PTransform): _WRITE_TRANSFORMS = { ICEBERG: ManagedTransforms.Urns.ICEBERG_WRITE.urn, KAFKA: ManagedTransforms.Urns.KAFKA_WRITE.urn, + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_WRITE.urn } def __init__( diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 8554ebce5dbd..4848dc4aade8 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -497,13 +497,12 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output): at_context = ' %s %s' % (input_or_output, context) if context else '' raise TypeCheckError( '{type} type hint violation at {label}{context}: expected {hint}, ' - 'got {actual_type}\nFull type hint:\n{debug_str}'.format( + 'got {actual_type}'.format( type=input_or_output.title(), label=self.label, context=at_context, hint=hint, - actual_type=pvalue_.element_type, - debug_str=type_hints.debug_str())) + actual_type=pvalue_.element_type)) def 
_infer_output_coder(self, input_type=None, input_coder=None): # type: (...) -> Optional[coders.Coder] @@ -748,7 +747,7 @@ def to_runner_api(self, context, has_parts=False, **extra_kwargs): # type: (PipelineContext, bool, Any) -> beam_runner_api_pb2.FunctionSpec from apache_beam.portability.api import beam_runner_api_pb2 # typing: only ParDo supports extra_kwargs - urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs) # type: ignore[call-arg] + urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs) if urn == python_urns.GENERIC_COMPOSITE_TRANSFORM and not has_parts: # TODO(https://github.com/apache/beam/issues/18713): Remove this fallback. urn, typed_param = self.to_runner_api_pickled(context) @@ -939,7 +938,25 @@ def element_type(side_input): bindings = getcallargs_forhints(argspec_fn, *arg_types, **kwargs_types) hints = getcallargs_forhints( argspec_fn, *input_types[0], **input_types[1]) - for arg, hint in hints.items(): + + # First check the main input. + arg_hints = iter(hints.items()) + element_arg, element_hint = next(arg_hints) + if not typehints.is_consistent_with( + bindings.get(element_arg, typehints.Any), element_hint): + transform_nest_level = self.label.count("/") + split_producer_label = pvalueish.producer.full_label.split("/") + producer_label = "/".join( + split_producer_label[:transform_nest_level + 1]) + raise TypeCheckError( + f"The transform '{self.label}' requires " + f"PCollections of type '{element_hint}' " + f"but was applied to a PCollection of type" + f" '{bindings[element_arg]}' " + f"(produced by the transform '{producer_label}'). ") + + # Now check the side inputs. + for arg, hint in arg_hints: if arg.startswith('__unknown__'): continue if hint is None: diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index d760ef74fb14..7db017a59158 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -787,6 +787,18 @@ def split_even_odd(element): assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even') assert_that(odd_length, equal_to(['BBB']), label='assert:odd') + def test_flatten_with(self): + with TestPipeline() as pipeline: + input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC']) + + result = ( + input + | 'WithPCollection' >> beam.FlattenWith(input | beam.Map(str.lower)) + | 'WithPTransform' >> beam.FlattenWith(beam.Create(['x', 'y']))) + + assert_that( + result, equal_to(['AA', 'BBB', 'CC', 'aa', 'bbb', 'cc', 'x', 'y'])) + def test_group_by_key_input_must_be_kv_pairs(self): with self.assertRaises(typehints.TypeCheckError) as e: with TestPipeline() as pipeline: @@ -1286,17 +1298,13 @@ class ToUpperCaseWithPrefix(beam.DoFn): def process(self, element, prefix): return [prefix + element.upper()] - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int) | 'Upper' >> beam.ParDo(ToUpperCaseWithPrefix(), 'hello')) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for element".format(str, int)) - def test_do_fn_pipeline_runtime_type_check_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -1323,18 +1331,14 @@ class AddWithNum(beam.DoFn): def process(self, element, num): return [element + num] - with 
self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Add.*requires.*int.*applied.*str'): ( self.p | 'T' >> beam.Create(['1', '2', '3']).with_output_types(str) | 'Add' >> beam.ParDo(AddWithNum(), 5)) self.p.run() - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Add': " - "requires {} but got {} for element".format(int, str)) - def test_pardo_does_not_type_check_using_type_hint_decorators(self): @with_input_types(a=int) @with_output_types(typing.List[str]) @@ -1343,17 +1347,13 @@ def int_to_str(a): # The function above is expecting an int for its only parameter. However, it # will receive a str instead, which should result in a raised exception. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'ToStr.*requires.*int.*applied.*str'): ( self.p | 'S' >> beam.Create(['b', 'a', 'r']).with_output_types(str) | 'ToStr' >> beam.FlatMap(int_to_str)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'ToStr': " - "requires {} but got {} for a".format(int, str)) - def test_pardo_properly_type_checks_using_type_hint_decorators(self): @with_input_types(a=str) @with_output_types(typing.List[str]) @@ -1375,7 +1375,8 @@ def to_all_upper_case(a): def test_pardo_does_not_type_check_using_type_hint_methods(self): # The first ParDo outputs pcoll's of type int, however the second ParDo is # expecting pcoll's of type str instead. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str) @@ -1386,11 +1387,6 @@ def test_pardo_does_not_type_check_using_type_hint_methods(self): 'Upper' >> beam.FlatMap(lambda x: [x.upper()]).with_input_types( str).with_output_types(str))) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for x".format(str, int)) - def test_pardo_properly_type_checks_using_type_hint_methods(self): # Pipeline should be created successfully without an error d = ( @@ -1407,18 +1403,14 @@ def test_pardo_properly_type_checks_using_type_hint_methods(self): def test_map_does_not_type_check_using_type_hints_methods(self): # The transform before 'Map' has indicated that it outputs PCollections with # int's, while Map is expecting one of str. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int) | 'Upper' >> beam.Map(lambda x: x.upper()).with_input_types( str).with_output_types(str)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for x".format(str, int)) - def test_map_properly_type_checks_using_type_hints_methods(self): # No error should be raised if this type-checks properly. d = ( @@ -1437,17 +1429,13 @@ def upper(s): # Hinted function above expects a str at pipeline construction. # However, 'Map' should detect that Create has hinted an int instead. 
- with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int) | 'Upper' >> beam.Map(upper)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for s".format(str, int)) - def test_map_properly_type_checks_using_type_hints_decorator(self): @with_input_types(a=bool) @with_output_types(int) @@ -1465,7 +1453,8 @@ def bool_to_int(a): def test_filter_does_not_type_check_using_type_hints_method(self): # Filter is expecting an int but instead looks to the 'left' and sees a str # incoming. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Below 3.*requires.*int.*applied.*str'): ( self.p | 'Strs' >> beam.Create(['1', '2', '3', '4', '5' @@ -1474,11 +1463,6 @@ def test_filter_does_not_type_check_using_type_hints_method(self): str).with_output_types(str) | 'Below 3' >> beam.Filter(lambda x: x < 3).with_input_types(int)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Below 3': " - "requires {} but got {} for x".format(int, str)) - def test_filter_type_checks_using_type_hints_method(self): # No error should be raised if this type-checks properly. d = ( @@ -1496,17 +1480,13 @@ def more_than_half(a): return a > 0.50 # Func above was hinted to only take a float, yet a str will be passed. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Half.*requires.*float.*applied.*str'): ( self.p | 'Ints' >> beam.Create(['1', '2', '3', '4']).with_output_types(str) | 'Half' >> beam.Filter(more_than_half)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Half': " - "requires {} but got {} for a".format(float, str)) - def test_filter_type_checks_using_type_hints_decorator(self): @with_input_types(b=int) def half(b): @@ -2116,14 +2096,10 @@ def test_mean_globally_pipeline_checking_violated(self): self.p | 'C' >> beam.Create(['test']).with_output_types(str) | 'Mean' >> combine.Mean.Globally()) - - expected_msg = \ - "Type hint violation for 'CombinePerKey': " \ - "requires Tuple[TypeVariable[K], Union[, , " \ - ", ]] " \ - "but got Tuple[None, ] for element" - - self.assertStartswith(e.exception.args[0], expected_msg) + err_msg = e.exception.args[0] + assert "CombinePerKey" in err_msg + assert "Tuple[TypeVariable[K]" in err_msg + assert "Tuple[None, " in err_msg def test_mean_globally_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -2183,14 +2159,12 @@ def test_mean_per_key_pipeline_checking_violated(self): typing.Tuple[str, str])) | 'EvenMean' >> combine.Mean.PerKey()) self.p.run() - - expected_msg = \ - "Type hint violation for 'CombinePerKey(MeanCombineFn)': " \ - "requires Tuple[TypeVariable[K], Union[, , " \ - ", ]] " \ - "but got Tuple[, ] for element" - - self.assertStartswith(e.exception.args[0], expected_msg) + err_msg = e.exception.args[0] + assert "CombinePerKey(MeanCombineFn)" in err_msg + assert "requires" in err_msg + assert "Tuple[TypeVariable[K]" in err_msg + assert "applied" in err_msg + assert "Tuple[, ]" in err_msg def test_mean_per_key_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True diff --git a/sdks/python/apache_beam/transforms/stats.py b/sdks/python/apache_beam/transforms/stats.py index 
d389463e55a2..0d56b60b050f 100644 --- a/sdks/python/apache_beam/transforms/stats.py +++ b/sdks/python/apache_beam/transforms/stats.py @@ -919,7 +919,7 @@ def _offset(self, new_weight): # TODO(https://github.com/apache/beam/issues/19737): Signature incompatible # with supertype - def create_accumulator(self): # type: ignore[override] + def create_accumulator(self): # type: () -> _QuantileState self._qs = _QuantileState( unbuffered_elements=[], diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index cad733538111..3b876bf9dbfb 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -299,7 +299,7 @@ def validate_stateful_dofn(dofn: 'DoFn') -> None: 'callback: %s.') % (dofn, timer_spec)) method_name = timer_spec._attached_callback.__name__ if (timer_spec._attached_callback != getattr(dofn, method_name, - None).__func__): + None).__func__): # type: ignore[union-attr] raise ValueError(( 'The on_timer callback for %s is not the specified .%s method ' 'for DoFn %r (perhaps it was overwritten?).') % @@ -314,7 +314,7 @@ def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: raise NotImplementedError -_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp')) +_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp')) # type: ignore[name-match] class RuntimeTimer(BaseTimer): diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index a27c7aca9e20..a03652de2496 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -30,6 +30,7 @@ import uuid from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import List from typing import Tuple @@ -44,6 +45,7 @@ from apache_beam.portability import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsSideInput +from apache_beam.pvalue import PCollection from apache_beam.transforms import window from apache_beam.transforms.combiners import CountCombineFn from apache_beam.transforms.core import CombinePerKey @@ -92,6 +94,7 @@ 'RemoveDuplicates', 'Reshuffle', 'ToString', + 'Tee', 'Values', 'WithKeys', 'GroupIntoBatches' @@ -1665,6 +1668,37 @@ def _process(element): return pcoll | FlatMap(_process) +@typehints.with_input_types(T) +@typehints.with_output_types(T) +class Tee(PTransform): + """A PTransform that returns its input, but also applies its input elsewhere. + + Similar to the shell {@code tee} command. This can be useful to write out or + otherwise process an intermediate transform without breaking the linear flow + of a chain of transforms, e.g.:: + + (input + | SomePTransform() + | ... + | Tee(SomeSideTransform()) + | ...) 
+ """ + def __init__( + self, + *consumers: Union[PTransform[PCollection[T], Any], + Callable[[PCollection[T]], Any]]): + self._consumers = consumers + + def expand(self, input): + for consumer in self._consumers: + print("apply", consumer) + if callable(consumer): + _ = input | ptransform_fn(consumer)() + else: + _ = input | consumer + return input + + @typehints.with_input_types(T) @typehints.with_output_types(T) class WaitOn(PTransform): diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 9c70be7900da..d86509c7dde3 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -19,6 +19,8 @@ # pytype: skip-file +import collections +import importlib import logging import math import random @@ -27,6 +29,7 @@ import unittest import warnings from datetime import datetime +from typing import Mapping import pytest import pytz @@ -1812,6 +1815,32 @@ def test_split_without_empty(self): assert_that(result, equal_to(expected_result)) +class TeeTest(unittest.TestCase): + _side_effects: Mapping[str, int] = collections.defaultdict(int) + + def test_tee(self): + # The imports here are to avoid issues with the class (and its attributes) + # possibly being pickled rather than referenced. + def cause_side_effect(element): + importlib.import_module(__name__).TeeTest._side_effects[element] += 1 + + def count_side_effects(element): + return importlib.import_module(__name__).TeeTest._side_effects[element] + + with TestPipeline() as p: + result = ( + p + | beam.Create(['a', 'b', 'c']) + | 'TeePTransform' >> beam.Tee(beam.Map(cause_side_effect)) + | 'TeeCallable' >> beam.Tee( + lambda pcoll: pcoll | beam.Map( + lambda element: cause_side_effect('X' + element)))) + assert_that(result, equal_to(['a', 'b', 'c'])) + + self.assertEqual(count_side_effects('a'), 1) + self.assertEqual(count_side_effects('Xa'), 1) + + class WaitOnTest(unittest.TestCase): def test_find(self): # We need shared reference that survives pickling. diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 3baf9fa8322f..71edc75f31a6 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -409,7 +409,7 @@ def fn(a: int) -> int: return a with self.assertRaisesRegex(TypeCheckError, - r'requires .*int.* but got .*str'): + r'requires .*int.* but was applied .*str'): _ = ['a', 'b', 'c'] | Map(fn) # Same pipeline doesn't raise without annotations on fn. @@ -423,7 +423,7 @@ def fn(a: int) -> int: _ = [1, 2, 3] | Map(fn) # Doesn't raise - correct types. with self.assertRaisesRegex(TypeCheckError, - r'requires .*int.* but got .*str'): + r'requires .*int.* but was applied .*str'): _ = ['a', 'b', 'c'] | Map(fn) @decorators.no_annotations diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index fd7885ad59c4..880a897bbbe8 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -49,7 +49,8 @@ def __init__( fields: Sequence[Tuple[str, type]], user_type, schema_options: Optional[Sequence[Tuple[str, Any]]] = None, - field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None): + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, + field_descriptions: Optional[Dict[str, str]] = None): """For internal use only, no backwards comatibility guaratees. 
See https://beam.apache.org/documentation/programming-guide/#schemas-for-pl-types for guidance on creating PCollections with inferred schemas. @@ -96,6 +97,7 @@ def __init__( self._schema_options = schema_options or [] self._field_options = field_options or {} + self._field_descriptions = field_descriptions or {} @staticmethod def from_user_type( @@ -107,12 +109,15 @@ def from_user_type( fields = [(name, user_type.__annotations__[name]) for name in user_type._fields] + field_descriptions = getattr(user_type, '_field_descriptions', None) + if _user_type_is_generated(user_type): return RowTypeConstraint.from_fields( fields, schema_id=getattr(user_type, _BEAM_SCHEMA_ID), schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) # TODO(https://github.com/apache/beam/issues/22125): Add user API for # specifying schema/field options @@ -120,7 +125,8 @@ def from_user_type( fields=fields, user_type=user_type, schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) return None @@ -131,13 +137,15 @@ def from_fields( schema_options: Optional[Sequence[Tuple[str, Any]]] = None, field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, schema_registry: Optional[SchemaTypeRegistry] = None, + field_descriptions: Optional[Dict[str, str]] = None, ) -> RowTypeConstraint: return GeneratedClassRowTypeConstraint( fields, schema_id=schema_id, schema_options=schema_options, field_options=field_options, - schema_registry=schema_registry) + schema_registry=schema_registry, + field_descriptions=field_descriptions) def __call__(self, *args, **kwargs): # We make RowTypeConstraint callable (defers to constructing the user type) @@ -206,6 +214,7 @@ def __init__( schema_options: Optional[Sequence[Tuple[str, Any]]] = None, field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, schema_registry: Optional[SchemaTypeRegistry] = None, + field_descriptions: Optional[Dict[str, str]] = None, ): from apache_beam.typehints.schemas import named_fields_to_schema from apache_beam.typehints.schemas import named_tuple_from_schema @@ -224,7 +233,8 @@ def __init__( fields, user_type, schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) def __reduce__(self): return ( diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index ef82ca91044c..fea9b3534b0c 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -274,6 +274,7 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: self.option_to_runner_api(option_tuple) for option_tuple in type_.field_options(field_name) ], + description=type_._field_descriptions.get(field_name, None), ) for (field_name, field_type) in type_._fields ], id=schema_id, diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 5d38b16d9783..15144c6c2c17 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -489,6 +489,44 @@ def test_row_type_constraint_to_schema_with_field_options(self): ] self.assertEqual(list(field.options), expected) + def test_row_type_constraint_to_schema_with_field_descriptions(self): + row_type_with_options = row_type.RowTypeConstraint.from_fields( + [ + ('foo', np.int8), + ('bar', 
float), + ('baz', bytes), + ], + field_descriptions={ + 'foo': 'foo description', + 'bar': 'bar description', + 'baz': 'baz description', + }) + result_type = typing_to_runner_api(row_type_with_options) + + self.assertIsInstance(result_type, schema_pb2.FieldType) + self.assertEqual(result_type.WhichOneof("type_info"), "row_type") + + fields = result_type.row_type.schema.fields + + expected = [ + schema_pb2.Field( + name='foo', + description='foo description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.BYTE), + ), + schema_pb2.Field( + name='bar', + description='bar description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE), + ), + schema_pb2.Field( + name='baz', + description='baz description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.BYTES), + ), + ] + self.assertEqual(list(fields), expected) + def assert_namedtuple_equivalent(self, actual, expected): # Two types are only considered equal if they are literally the same # object (i.e. `actual == expected` is the same as `actual is expected` in diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 72aed46f5e78..57e7f44f6922 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -88,11 +88,11 @@ def process(self, element): self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_dofn_method(self): @@ -104,11 +104,11 @@ def process(self, element: int) -> typehints.Tuple[str]: self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_dofn_method_with_class_decorators(self): @@ -124,12 +124,12 @@ def process(self, element: int) -> typehints.Tuple[str]: with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires.*Tuple\[, \].*got.*str'): + r'requires.*Tuple\[, \].*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires.*Tuple\[, \].*got.*int'): + r'requires.*Tuple\[, \].*applied.*int'): _ = [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_callable_iterable_output(self): @@ -156,11 +156,11 @@ def process(self, element: typehints.Tuple[int, int]) -> \ self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(my_do_fn) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (beam.ParDo(my_do_fn) | 'again' >> beam.ParDo(my_do_fn)) def test_typed_callable_instance(self): @@ -177,11 +177,11 @@ def do_fn(element: typehints.Tuple[int, int]) -> typehints.Generator[str]: self.assertEqual(['1', '2', '3'], sorted(result)) with 
self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | pardo with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (pardo | 'again' >> pardo) def test_filter_type_hint(self): @@ -430,7 +430,7 @@ def fn(element: float): return pcoll | beam.ParDo(fn) with self.assertRaisesRegex(typehints.TypeCheckError, - r'ParDo.*requires.*float.*got.*str'): + r'ParDo.*requires.*float.*applied.*str'): _ = ['1', '2', '3'] | MyMap() with self.assertRaisesRegex(typehints.TypeCheckError, r'MyMap.*expected.*str.*got.*bytes'): @@ -632,14 +632,14 @@ def produces_unkown(e): return e @typehints.with_input_types(int) - def requires_int(e): + def accepts_int(e): return e class MyPTransform(beam.PTransform): def expand(self, pcoll): unknowns = pcoll | beam.Map(produces_unkown) ints = pcoll | beam.Map(int) - return (unknowns, ints) | beam.Flatten() | beam.Map(requires_int) + return (unknowns, ints) | beam.Flatten() | beam.Map(accepts_int) _ = [1, 2, 3] | MyPTransform() @@ -761,8 +761,8 @@ def test_var_positional_only_side_input_hint(self): with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires Tuple\[Union\[, \], ...\] but ' - r'got Tuple\[Union\[, \], ...\]'): + r'requires.*Tuple\[Union\[, \], ...\].*' + r'applied.*Tuple\[Union\[, \], ...\]'): _ = [1.2] | beam.Map(lambda *_: 'a', 5).with_input_types(int, str) def test_var_keyword_side_input_hint(self): diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index c75fdcc5878d..61c2371bd07d 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -104,7 +104,9 @@ def __exit__(self, *args): self.profile.create_stats() self.profile_output = self._upload_profile_data( # typing: seems stats attr is missing from typeshed - self.profile_location, 'cpu_profile', self.profile.stats) # type: ignore[attr-defined] + self.profile_location, + 'cpu_profile', + self.profile.stats) if self.enable_memory_profiling: if not self.hpy: diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py index 9a93c9e48ea3..60c0af2ebac0 100644 --- a/sdks/python/apache_beam/utils/proto_utils.py +++ b/sdks/python/apache_beam/utils/proto_utils.py @@ -46,7 +46,7 @@ def pack_Any(msg: message.Message) -> any_pb2.Any: @overload -def pack_Any(msg: None) -> None: +def pack_Any(msg: None) -> None: # type: ignore[overload-cannot-match] pass diff --git a/sdks/python/apache_beam/utils/subprocess_server.py b/sdks/python/apache_beam/utils/subprocess_server.py index 944c12625d7c..b1080cb643af 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -266,11 +266,17 @@ class JavaJarServer(SubprocessServer): 'local', (threading.local, ), dict(__init__=lambda self: setattr(self, 'replacements', {})))() - def __init__(self, stub_class, path_to_jar, java_arguments, classpath=None): + def __init__( + self, + stub_class, + path_to_jar, + java_arguments, + classpath=None, + cache_dir=None): if classpath: # java -jar ignores the classpath, so we make a new jar that embeds # the requested classpath. 
- path_to_jar = self.make_classpath_jar(path_to_jar, classpath) + path_to_jar = self.make_classpath_jar(path_to_jar, classpath, cache_dir) super().__init__( stub_class, ['java', '-jar', path_to_jar] + list(java_arguments)) self._existing_service = path_to_jar if is_service_endpoint( diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 4719bc3e66aa..4088e17afe2c 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -20,12 +20,17 @@ import itertools import re +import docstring_parser import yaml from apache_beam.portability.api import schema_pb2 +from apache_beam.typehints import schemas from apache_beam.utils import subprocess_server +from apache_beam.utils.python_callable import PythonCallableWithSource +from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils from apache_beam.yaml import yaml_provider +from apache_beam.yaml.yaml_mapping import ErrorHandlingConfig def _singular(name): @@ -134,8 +139,28 @@ def maybe_row_parameters(t): def maybe_optional(t): return " (Optional)" if t.nullable else "" + def normalize_error_handling(f): + doc = docstring_parser.parse( + ErrorHandlingConfig.__doc__, docstring_parser.DocstringStyle.GOOGLE) + if f.name == "error_handling": + f = schema_pb2.Field( + name="error_handling", + type=schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + fields=[ + schemas.schema_field( + param.arg_name, + PythonCallableWithSource.load_from_expression( + param.type_name), + param.description) for param in doc.params + ]))), + description=f.description) + return f + def lines(): for f in schema.fields: + f = normalize_error_handling(f) yield ''.join([ f'**{f.name}** `{pretty_type(f.type)}`', maybe_optional(f.type), @@ -284,42 +309,143 @@ def main(): markdown.extensions.toc.TocExtension(toc_depth=2), 'codehilite', ]) - html = md.convert(markdown_out.getvalue()) pygments_style = pygments.formatters.HtmlFormatter().get_style_defs( '.codehilite') extra_style = ''' - .nav { - height: 100%; - width: 12em; + * { + box-sizing: border-box; + } + body { + font-family: 'Roboto', sans-serif; + font-weight: normal; + color: #404040; + background: #edf0f2; + } + .body-for-nav { + background: #fcfcfc; + } + .grid-for-nav { + width: 100%; + } + .nav-side { position: fixed; top: 0; left: 0; - overflow-x: hidden; + width: 300px; + height: 100%; + padding-bottom: 2em; + color: #9b9b9b; + background: #343131; } - .nav a { - color: #333; - padding: .2em; + .nav-header { display: block; - text-decoration: none; + width: 300px; + padding: 1em; + background-color: #2980B9; + text-align: center; + color: #fcfcfc; + } + .nav-header a { + color: #fcfcfc; + font-weight: bold; + display: inline-block; + padding: 4px 6px; + margin-bottom: 1em; + text-decoration:none; } - .nav a:hover { - color: #888; + .nav-header>div.version { + margin-top: -.5em; + margin-bottom: 1em; + font-weight: normal; + color: rgba(255, 255, 255, 0.3); } - .nav li { - list-style-type: none; + .toc { + width: 300px; + text-align: left; + overflow-y: auto; + max-height: calc(100% - 4.3em); + scrollbar-width: thin; + scrollbar-color: #9b9b9b #343131; + } + .toc ul { margin: 0; padding: 0; + list-style: none; + } + .toc li { + border-bottom: 1px solid #4e4a4a; + margin-left: 1em; + } + .toc a { + display: block; + line-height: 36px; + font-size: 90%; + color: #d9d9d9; + padding: .1em 0.6em; + text-decoration: none; + 
transition: background-color 0.3s ease, color 0.3s ease; + } + .toc a:hover { + background-color: #4e4a4a; + color: #ffffff; + } + .transform-content-wrap { + margin-left: 300px; + background: #fcfcfc; + } + .transform-content { + padding: 1.5em 3em; + margin: 20px; + padding-bottom: 2em; + } + .transform-content li::marker { + display: inline-block; + width: 0.5em; + } + .transform-content h1 { + font-size: 40px; + } + .transform-content ul { + margin-left: 0.75em; + text-align: left; + list-style-type: disc; + } + hr { + color: gray; + display: block; + height: 1px; + border: 0; + border-top: 1px solid #e1e4e5; + margin-bottom: 3em; + margin-top: 3em; + padding: 0; } - .content { - margin-left: 12em; + .codehilite { + background: #f5f5f5; + border: 1px solid #ccc; + border-radius: 4px; + padding: 0.2em 1em; + overflow: auto; + font-family: monospace; + font-size: 14px; + line-height: 1.5; } - h2 { - margin-top: 2em; + p code, li code { + white-space: nowrap; + max-width: 100%; + background: #fff; + border: solid 1px #e1e4e5; + padding: 0 5px; + font-family: monospace; + color: #404040; + font-weight: bold; + padding: 2px 5px; } ''' - with open(options.html_file, 'w') as fout: - fout.write( + html = md.convert(markdown_out.getvalue()) + with open(options.html_file, 'w') as html_out: + html_out.write( f''' @@ -329,13 +455,23 @@ def main(): {extra_style} - - -
    [HTML page template for the generated YAML transform docs: the old layout, a plain .nav list beside a .content block rendering {title} and {html}, is replaced by a fixed .nav-side sidebar (.nav-header with a version label plus a .toc) wrapping the rendered {title} and {html} inside .transform-content-wrap / .transform-content; the literal markup was lost in extraction.]
    diff --git a/sdks/python/apache_beam/yaml/integration_tests.py b/sdks/python/apache_beam/yaml/integration_tests.py index af1be7b1e8e5..72b3918195da 100644 --- a/sdks/python/apache_beam/yaml/integration_tests.py +++ b/sdks/python/apache_beam/yaml/integration_tests.py @@ -69,7 +69,7 @@ def temp_bigquery_table(project, prefix='yaml_bq_it_'): dataset_id = '%s_%s' % (prefix, uuid.uuid4().hex) bigquery_client.get_or_create_dataset(project, dataset_id) logging.info("Created dataset %s in project %s", dataset_id, project) - yield f'{project}:{dataset_id}.tmp_table' + yield f'{project}.{dataset_id}.tmp_table' request = bigquery.BigqueryDatasetsDeleteRequest( projectId=project, datasetId=dataset_id, deleteContents=True) logging.info("Deleting dataset %s in project %s", dataset_id, project) diff --git a/sdks/python/apache_beam/yaml/json_utils.py b/sdks/python/apache_beam/yaml/json_utils.py index 40e515ee6946..76cc80bc2036 100644 --- a/sdks/python/apache_beam/yaml/json_utils.py +++ b/sdks/python/apache_beam/yaml/json_utils.py @@ -106,6 +106,18 @@ def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType: raise ValueError(f'Unable to convert {json_type} to a Beam schema.') +def beam_schema_to_json_schema( + beam_schema: schema_pb2.Schema) -> Dict[str, Any]: + return { + 'type': 'object', + 'properties': { + field.name: beam_type_to_json_type(field.type) + for field in beam_schema.fields + }, + 'additionalProperties': False + } + + def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> Dict[str, Any]: type_info = beam_type.WhichOneof("type_info") if type_info == "atomic_type": @@ -267,3 +279,52 @@ def json_formater( convert = row_to_json( schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) return lambda row: json.dumps(convert(row), sort_keys=True).encode('utf-8') + + +def _validate_compatible(weak_schema, strong_schema): + if not weak_schema: + return + if weak_schema['type'] != strong_schema['type']: + raise ValueError( + 'Incompatible types: %r vs %r' % + (weak_schema['type'] != strong_schema['type'])) + if weak_schema['type'] == 'array': + _validate_compatible(weak_schema['items'], strong_schema['items']) + elif weak_schema == 'object': + for required in strong_schema.get('required', []): + if required not in weak_schema['properties']: + raise ValueError('Missing or unkown property %r' % required) + for name, spec in weak_schema.get('properties', {}): + if name in strong_schema['properties']: + try: + _validate_compatible(spec, strong_schema['properties'][name]) + except Exception as exn: + raise ValueError('Incompatible schema for %r' % name) from exn + elif not strong_schema.get('additionalProperties'): + raise ValueError( + 'Prohibited property: {property}; ' + 'perhaps additionalProperties: False is missing?') + + +def row_validator(beam_schema: schema_pb2.Schema, + json_schema: Dict[str, Any]) -> Callable[[Any], Any]: + """Returns a callable that will fail on elements not respecting json_schema. + """ + if not json_schema: + return lambda x: None + + # Validate that this compiles, but avoid pickling the validator itself. 
+ _ = jsonschema.validators.validator_for(json_schema)(json_schema) + _validate_compatible(beam_schema_to_json_schema(beam_schema), json_schema) + validator = None + + convert = row_to_json( + schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) + + def validate(row): + nonlocal validator + if validator is None: + validator = jsonschema.validators.validator_for(json_schema)(json_schema) + validator.validate(convert(row)) + + return validate diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index 4de36b3dc9e0..400ab07a41fa 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -271,6 +271,8 @@ table: 'table_id' query: 'query' columns: 'columns' + index: 'index' + batching: 'batching' 'WriteToSpanner': project: 'project_id' instance: 'instance_id' diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 574179805959..242faaa9a77b 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -101,3 +101,8 @@ Explode: "beam:schematransform:org.apache.beam:yaml:explode:v1" config: gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar' + +- type: 'python' + config: {} + transforms: + Enrichment: 'apache_beam.yaml.yaml_enrichment.enrichment_transform' diff --git a/sdks/python/apache_beam/yaml/tests/enrichment.yaml b/sdks/python/apache_beam/yaml/tests/enrichment.yaml new file mode 100644 index 000000000000..6469c094b8b4 --- /dev/null +++ b/sdks/python/apache_beam/yaml/tests/enrichment.yaml @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +fixtures: + - name: BQ_TABLE + type: "apache_beam.yaml.integration_tests.temp_bigquery_table" + config: + project: "apache-beam-testing" + - name: TEMP_DIR + type: "apache_beam.yaml.integration_tests.gcs_temp_dir" + config: + bucket: "gs://temp-storage-for-end-to-end-tests/temp-it" + +pipelines: + - pipeline: + type: chain + transforms: + - type: Create + name: Rows + config: + elements: + - {label: '11a', rank: 0} + - {label: '37a', rank: 1} + - {label: '389a', rank: 2} + + - type: WriteToBigQuery + config: + table: "{BQ_TABLE}" + + - pipeline: + type: chain + transforms: + - type: Create + name: Data + config: + elements: + - {label: '11a', name: 'S1'} + - {label: '37a', name: 'S2'} + - {label: '389a', name: 'S3'} + - type: Enrichment + name: Enriched + config: + enrichment_handler: 'BigQuery' + handler_config: + project: apache-beam-testing + table_name: "{BQ_TABLE}" + fields: ['label'] + row_restriction_template: "label = '37a'" + timeout: 30 + + - type: MapToFields + config: + language: python + fields: + label: + callable: 'lambda x: x.label' + output_type: string + rank: + callable: 'lambda x: x.rank' + output_type: integer + name: + callable: 'lambda x: x.name' + output_type: string + + - type: AssertEqual + config: + elements: + - {label: '37a', rank: 1, name: 'S2'} + options: + yaml_experimental_features: [ 'Enrichment' ] \ No newline at end of file diff --git a/sdks/python/apache_beam/yaml/tests/map.yaml b/sdks/python/apache_beam/yaml/tests/map.yaml index b676966ad6bd..04f057cb2e82 100644 --- a/sdks/python/apache_beam/yaml/tests/map.yaml +++ b/sdks/python/apache_beam/yaml/tests/map.yaml @@ -30,6 +30,7 @@ pipelines: config: append: true fields: + # TODO(https://github.com/apache/beam/issues/32832): Figure out why Java sometimes re-orders these fields. named_field: element literal_int: 10 literal_float: 1.5 diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py new file mode 100644 index 000000000000..9bea17f78fdd --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -0,0 +1,106 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Any +from typing import Dict +from typing import Optional + +import apache_beam as beam +from apache_beam.yaml import options + +try: + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import VertexAIFeatureStoreEnrichmentHandler +except ImportError: + Enrichment = None # type: ignore + BigQueryEnrichmentHandler = None # type: ignore + BigTableEnrichmentHandler = None # type: ignore + VertexAIFeatureStoreEnrichmentHandler = None # type: ignore + +try: + from apache_beam.transforms.enrichment_handlers.feast_feature_store import FeastFeatureStoreEnrichmentHandler +except ImportError: + FeastFeatureStoreEnrichmentHandler = None # type: ignore + + +@beam.ptransform.ptransform_fn +def enrichment_transform( + pcoll, + enrichment_handler: str, + handler_config: Dict[str, Any], + timeout: Optional[float] = 30): + """ + The Enrichment transform allows you to dynamically + enhance elements in a pipeline by performing key-value + lookups against external services like APIs or databases. + + Example Usage:: + + - type: Enrichment + config: + enrichment_handler: 'BigTable' + handler_config: + project_id: 'apache-beam-testing' + instance_id: 'beam-test' + table_id: 'bigtable-enrichment-test' + row_key: 'product_id' + timeout: 30 + + Args: + enrichment_handler: Specifies the source from + where data needs to be extracted + into the pipeline for enriching data. + It can be a string value in ["BigQuery", + "BigTable", "FeastFeatureStore", + "VertexAIFeatureStore"]. + handler_config: Specifies the parameters for + the respective enrichment_handler in a dictionary format. + To see the full set of handler_config parameters, see + their corresponding doc pages: + + - :class:`~apache_beam.transforms.enrichment_handlers.bigquery.BigQueryEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.feast_feature_store.FeastFeatureStoreEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store.VertexAIFeatureStoreEnrichmentHandler` # pylint: disable=line-too-long + """ + options.YamlOptions.check_enabled(pcoll.pipeline, 'Enrichment') + + if not Enrichment: + raise ValueError( + f"gcp dependencies not installed. Cannot use {enrichment_handler} " + f"handler. Please install using 'pip install apache-beam[gcp]'.") + + if (enrichment_handler == 'FeastFeatureStore' and + not FeastFeatureStoreEnrichmentHandler): + raise ValueError( + "FeastFeatureStore handler requires 'feast' package to be installed. 
" + + "Please install using 'pip install feast[gcp]' and try again.") + + handler_map = { + 'BigQuery': BigQueryEnrichmentHandler, + 'BigTable': BigTableEnrichmentHandler, + 'FeastFeatureStore': FeastFeatureStoreEnrichmentHandler, + 'VertexAIFeatureStore': VertexAIFeatureStoreEnrichmentHandler + } + + if enrichment_handler not in handler_map: + raise ValueError(f"Unknown enrichment source: {enrichment_handler}") + + handler = handler_map[enrichment_handler](**handler_config) + return pcoll | Enrichment(source_handler=handler, timeout=timeout) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py new file mode 100644 index 000000000000..e26d6140af23 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import unittest + +import mock + +import apache_beam as beam +from apache_beam import Row +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.yaml.yaml_transform import YamlTransform + + +class FakeEnrichmentTransform: + def __init__(self, enrichment_handler, handler_config, timeout=30): + self._enrichment_handler = enrichment_handler + self._handler_config = handler_config + self._timeout = timeout + + def __call__(self, enrichment_handler, *, handler_config, timeout=30): + assert enrichment_handler == self._enrichment_handler + assert handler_config == self._handler_config + assert timeout == self._timeout + return beam.Map(lambda x: beam.Row(**x._asdict())) + + +class EnrichmentTransformTest(unittest.TestCase): + def test_enrichment_with_bigquery(self): + input_data = [ + Row(label="item1", rank=0), + Row(label="item2", rank=1), + ] + + handler = 'BigQuery' + config = { + "project": "apache-beam-testing", + "table_name": "project.database.table", + "row_restriction_template": "label='item1' or label='item2'", + "fields": ["label"] + } + + with beam.Pipeline() as p: + with mock.patch('apache_beam.yaml.yaml_enrichment.enrichment_transform', + FakeEnrichmentTransform(enrichment_handler=handler, + handler_config=config)): + input_pcoll = p | 'CreateInput' >> beam.Create(input_data) + result = input_pcoll | YamlTransform( + f''' + type: Enrichment + config: + enrichment_handler: {handler} + handler_config: {config} + ''') + assert_that(result, equal_to(input_data)) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 377bcac0e31a..5c14b0f5ea79 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -43,6 +43,7 @@ from 
apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import named_fields_from_element_type +from apache_beam.typehints.schemas import schema_from_element_type from apache_beam.utils import python_callable from apache_beam.yaml import json_utils from apache_beam.yaml import options @@ -417,6 +418,11 @@ def checking_func(row): class ErrorHandlingConfig(NamedTuple): + """Class to define Error Handling parameters. + + Args: + output (str): Name to use for the output error collection + """ output: str # TODO: Other parameters are valid here too, but not common to Java. @@ -435,7 +441,8 @@ def _map_errors_to_standard_format(input_type): # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. return beam.Map( - lambda x: beam.Row(element=x[0], msg=str(x[1][1]), stack=str(x[1][2])) + lambda x: beam.Row( + element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2])) ).with_output_types( RowTypeConstraint.from_fields([("element", input_type), ("msg", str), ("stack", str)])) @@ -475,6 +482,40 @@ def expand(pcoll, error_handling=None, **kwargs): return expand +class _Validate(beam.PTransform): + """Validates each element of a PCollection against a json schema. + + Args: + schema: A json schema against which to validate each element. + error_handling: Whether and how to handle errors during iteration. + If this is not set, invalid elements will fail the pipeline, otherwise + invalid elements will be passed to the specified error output along + with information about how the schema was invalidated. + """ + def __init__( + self, + schema: Dict[str, Any], + error_handling: Optional[Mapping[str, Any]] = None): + self._schema = schema + self._exception_handling_args = exception_handling_args(error_handling) + + @maybe_with_exception_handling + def expand(self, pcoll): + validator = json_utils.row_validator( + schema_from_element_type(pcoll.element_type), self._schema) + + def invoke_validator(x): + validator(x) + return x + + return pcoll | beam.Map(invoke_validator) + + def with_exception_handling(self, **kwargs): + # It's possible there's an error in iteration... + self._exception_handling_args = kwargs + return self + + class _Explode(beam.PTransform): """Explodes (aka unnest/flatten) one or more fields producing multiple rows. 
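The `row_validator` helper added to `json_utils` (used by `_Validate` above) builds its `jsonschema` validator lazily: only the plain schema dict is captured when the transform is pickled, and the validator object is constructed on first use on the worker. A standalone sketch of that pattern, assuming the `jsonschema` package is installed:

```python
import jsonschema


def make_validator(json_schema):
  # Fail fast if no validator class can be built for this schema, but do not
  # keep the instance around: only the schema dict travels with the closure
  # when it is pickled into the pipeline.
  _ = jsonschema.validators.validator_for(json_schema)(json_schema)
  validator = None

  def validate(obj):
    nonlocal validator
    if validator is None:  # built once per worker, after unpickling
      validator = jsonschema.validators.validator_for(json_schema)(json_schema)
    validator.validate(obj)  # raises jsonschema.ValidationError on failure

  return validate


check = make_validator(
    {'type': 'object',
     'properties': {'rank': {'type': 'integer', 'minimum': 0}}})
check({'rank': 3})     # passes silently
# check({'rank': -1})  # would raise jsonschema.ValidationError
```

In the actual helper the element is first converted with `row_to_json` before validation, so schema'd `beam.Row` values can be checked against an ordinary JSON schema.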
@@ -797,6 +838,7 @@ def create_mapping_providers(): 'Partition-python': _Partition, 'Partition-javascript': _Partition, 'Partition-generic': _Partition, + 'ValidateWithSchema': _Validate, }), yaml_provider.SqlBackedProvider({ 'Filter-sql': _SqlFilterTransform, diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 1b74a765e54b..2c5feec18278 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -134,6 +134,43 @@ def test_explode(self): beam.Row(a=3, b='y', c=.125, range=2), ])) + def test_validate(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(key='good', small=[5], nested=beam.Row(big=100)), + beam.Row(key='bad1', small=[500], nested=beam.Row(big=100)), + beam.Row(key='bad2', small=[5], nested=beam.Row(big=1)), + ]) + result = elements | YamlTransform( + ''' + type: ValidateWithSchema + config: + schema: + type: object + properties: + small: + type: array + items: + type: integer + maximum: 10 + nested: + type: object + properties: + big: + type: integer + minimum: 10 + error_handling: + output: bad + ''') + + assert_that( + result['good'] | beam.Map(lambda x: x.key), equal_to(['good'])) + assert_that( + result['bad'] | beam.Map(lambda x: x.element.key), + equal_to(['bad1', 'bad2']), + label='Errors') + def test_validate_explicit_types(self): with self.assertRaisesRegex(TypeError, r'.*violates schema.*'): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( diff --git a/sdks/python/container/build.gradle b/sdks/python/container/build.gradle index f07b6f743fa4..14c08a3a539b 100644 --- a/sdks/python/container/build.gradle +++ b/sdks/python/container/build.gradle @@ -20,7 +20,7 @@ plugins { id 'org.apache.beam.module' } applyGoNature() description = "Apache Beam :: SDKs :: Python :: Container" -int min_python_version=8 +int min_python_version=9 int max_python_version=12 configurations { diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt deleted file mode 100644 index 0a67a3666d25..000000000000 --- a/sdks/python/container/py38/base_image_requirements.txt +++ /dev/null @@ -1,172 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Autogenerated requirements file for Apache Beam py38 container image. -# Run ./gradlew :sdks:python:container:generatePythonRequirementsAll to update. -# Do not edit manually, adjust ../base_image_requirements_manual.txt or -# Apache Beam's setup.py instead, and regenerate the list. 
-# You will need Python interpreters for all versions supported by Beam, see: -# https://s.apache.org/beam-python-dev-wiki -# Reach out to a committer if you need help. - -annotated-types==0.7.0 -async-timeout==4.0.3 -attrs==24.2.0 -backports.tarfile==1.2.0 -beautifulsoup4==4.12.3 -bs4==0.0.2 -build==1.2.2 -cachetools==5.5.0 -certifi==2024.8.30 -cffi==1.17.1 -charset-normalizer==3.3.2 -click==8.1.7 -cloudpickle==2.2.1 -cramjam==2.8.4 -crcmod==1.7 -cryptography==43.0.1 -Cython==3.0.11 -Deprecated==1.2.14 -deprecation==2.1.0 -dill==0.3.1.1 -dnspython==2.6.1 -docker==7.1.0 -docopt==0.6.2 -docstring_parser==0.16 -exceptiongroup==1.2.2 -execnet==2.1.1 -fastavro==1.9.7 -fasteners==0.19 -freezegun==1.5.1 -future==1.0.0 -google-api-core==2.20.0 -google-api-python-client==2.147.0 -google-apitools==0.5.31 -google-auth==2.35.0 -google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.69.0 -google-cloud-bigquery==3.26.0 -google-cloud-bigquery-storage==2.26.0 -google-cloud-bigtable==2.26.0 -google-cloud-core==2.4.1 -google-cloud-datastore==2.20.1 -google-cloud-dlp==3.23.0 -google-cloud-language==2.14.0 -google-cloud-profiler==4.1.0 -google-cloud-pubsub==2.25.2 -google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.12 -google-cloud-resource-manager==1.12.5 -google-cloud-spanner==3.49.1 -google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.13.5 -google-cloud-vision==3.7.4 -google-crc32c==1.5.0 -google-resumable-media==2.7.2 -googleapis-common-protos==1.65.0 -greenlet==3.1.1 -grpc-google-iam-v1==0.13.1 -grpc-interceptor==0.15.4 -grpcio==1.65.5 -grpcio-status==1.62.3 -guppy3==3.1.4.post1 -hdfs==2.7.3 -httplib2==0.22.0 -hypothesis==6.112.3 -idna==3.10 -importlib_metadata==8.4.0 -importlib_resources==6.4.5 -iniconfig==2.0.0 -jaraco.classes==3.4.0 -jaraco.context==6.0.1 -jaraco.functools==4.1.0 -jeepney==0.8.0 -Jinja2==3.1.4 -joblib==1.4.2 -jsonpickle==3.3.0 -jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -keyring==25.4.1 -keyrings.google-artifactregistry-auth==1.1.2 -MarkupSafe==2.1.5 -mmh3==5.0.1 -mock==5.1.0 -more-itertools==10.5.0 -nltk==3.9.1 -nose==1.3.7 -numpy==1.24.4 -oauth2client==4.1.3 -objsize==0.7.0 -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 -opentelemetry-semantic-conventions==0.48b0 -orjson==3.10.7 -overrides==7.7.0 -packaging==24.1 -pandas==2.0.3 -parameterized==0.9.0 -pkgutil_resolve_name==1.3.10 -pluggy==1.5.0 -proto-plus==1.24.0 -protobuf==4.25.5 -psycopg2-binary==2.9.9 -pyarrow==16.1.0 -pyarrow-hotfix==0.6 -pyasn1==0.6.1 -pyasn1_modules==0.4.1 -pycparser==2.22 -pydantic==2.9.2 -pydantic_core==2.23.4 -pydot==1.4.2 -PyHamcrest==2.1.0 -pymongo==4.10.1 -PyMySQL==1.1.1 -pyparsing==3.1.4 -pyproject_hooks==1.2.0 -pytest==7.4.4 -pytest-timeout==2.3.1 -pytest-xdist==3.6.1 -python-dateutil==2.9.0.post0 -python-snappy==0.7.3 -pytz==2024.2 -PyYAML==6.0.2 -redis==5.1.1 -referencing==0.35.1 -regex==2024.9.11 -requests==2.32.3 -requests-mock==1.12.1 -rpds-py==0.20.0 -rsa==4.9 -scikit-learn==1.3.2 -scipy==1.10.1 -SecretStorage==3.3.3 -shapely==2.0.6 -six==1.16.0 -sortedcontainers==2.4.0 -soupsieve==2.6 -SQLAlchemy==2.0.35 -sqlparse==0.5.1 -tenacity==8.5.0 -testcontainers==3.7.1 -threadpoolctl==3.5.0 -tomli==2.0.2 -tqdm==4.66.5 -typing_extensions==4.12.2 -tzdata==2024.2 -uritemplate==4.1.1 -urllib3==2.2.3 -wrapt==1.16.0 -zipp==3.20.2 -zstandard==0.23.0 diff --git a/sdks/python/container/py38/build.gradle b/sdks/python/container/py38/build.gradle deleted file mode 100644 index 304895a83718..000000000000 --- a/sdks/python/container/py38/build.gradle +++ 
/dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -plugins { - id 'base' - id 'org.apache.beam.module' -} -applyDockerNature() -applyPythonNature() - -pythonVersion = '3.8' - -apply from: "../common.gradle" diff --git a/sdks/python/expansion-service-container/Dockerfile b/sdks/python/expansion-service-container/Dockerfile index 5a5ef0f410bc..4e82165f594c 100644 --- a/sdks/python/expansion-service-container/Dockerfile +++ b/sdks/python/expansion-service-container/Dockerfile @@ -17,7 +17,7 @@ ############################################################################### # We just need to support one Python version supported by Beam. -# Picking the current default Beam Python version which is Python 3.8. +# Picking the current default Beam Python version which is Python 3.9. FROM python:3.9-bookworm as expansion-service LABEL Author "Apache Beam " ARG TARGETOS diff --git a/sdks/python/expansion-service-container/build.gradle b/sdks/python/expansion-service-container/build.gradle index 3edcaee35b4a..4e46f060e59f 100644 --- a/sdks/python/expansion-service-container/build.gradle +++ b/sdks/python/expansion-service-container/build.gradle @@ -40,7 +40,7 @@ task copyDockerfileDependencies(type: Copy) { } task copyRequirementsFile(type: Copy) { - from project(':sdks:python:container:py38').fileTree("./") + from project(':sdks:python:container:py39').fileTree("./") include 'base_image_requirements.txt' rename 'base_image_requirements.txt', 'requirements.txt' setDuplicatesStrategy(DuplicatesStrategy.INCLUDE) diff --git a/sdks/python/mypy.ini b/sdks/python/mypy.ini index 562cb8d56dcc..ee76089fec0b 100644 --- a/sdks/python/mypy.ini +++ b/sdks/python/mypy.ini @@ -28,11 +28,17 @@ files = apache_beam color_output = true # uncomment this to see how close we are to being complete # check_untyped_defs = true -disable_error_code = var-annotated +disable_error_code = var-annotated, import-untyped, valid-type, truthy-function, attr-defined, annotation-unchecked + +[tool.mypy] +ignore_missing_imports = true [mypy-apache_beam.coders.proto2_coder_test_messages_pb2] ignore_errors = true +[mypy-apache_beam.dataframe.*] +ignore_errors = true + [mypy-apache_beam.examples.*] ignore_errors = true diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index a99599a2ce2b..4eb827297019 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -21,7 +21,7 @@ requires = [ "setuptools", "wheel>=0.36.0", - "grpcio-tools==1.65.5", + "grpcio-tools==1.62.1", "mypy-protobuf==3.5.0", # Avoid https://github.com/pypa/virtualenv/issues/2006 "distlib==0.3.7", diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 3462429190c8..21561e1bf6a9 100755 --- 
a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -124,7 +124,6 @@ extensions = [ ] master_doc = 'index' html_theme = 'sphinx_rtd_theme' -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] project = 'Apache Beam' version = beam_version.__version__ release = version @@ -244,6 +243,9 @@ ignore_identifiers = [ # IPython Magics py:class reference target not found 'IPython.core.magic.Magics', + + # Type variables. + 'apache_beam.transforms.util.T', ] ignore_references = [ 'BeamIOError', diff --git a/sdks/python/setup.py b/sdks/python/setup.py index cac27db69803..9ae5d3153f51 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -155,7 +155,7 @@ def cythonize(*args, **kwargs): # Exclude 1.5.0 and 1.5.1 because of # https://github.com/pandas-dev/pandas/issues/45725 dataframe_dependency = [ - 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3;python_version>="3.8"', + 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3', ] @@ -271,18 +271,13 @@ def get_portability_package_data(): return files -python_requires = '>=3.8' +python_requires = '>=3.9' -if sys.version_info.major == 3 and sys.version_info.minor >= 12: +if sys.version_info.major == 3 and sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' % (sys.version_info.major, sys.version_info.minor)) -elif sys.version_info.major == 3 and sys.version_info.minor == 8: - warnings.warn('Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') if __name__ == '__main__': # In order to find the tree of proto packages, the directory @@ -534,7 +529,6 @@ def get_portability_package_data(): 'Intended Audience :: End Users/Desktop', 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 6bca904c1a64..71d44652bc7e 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -543,12 +543,13 @@ task mockAPITests { } // add all RunInference E2E tests that run on DataflowRunner -// As of now, this test suite is enable in py38 suite as the base NVIDIA image used for Tensor RT -// contains Python 3.8. +// As of now, this test suite is enable in py310 suite as the base NVIDIA image used for Tensor RT +// contains Python 3.10. // TODO: https://github.com/apache/beam/issues/22651 project.tasks.register("inferencePostCommitIT") { dependsOn = [ - 'tensorRTtests', + // Temporarily disabled because of a container issue + // 'tensorRTtests', 'vertexAIInferenceTest', 'mockAPITests', ] diff --git a/sdks/python/test-suites/dataflow/py38/build.gradle b/sdks/python/test-suites/dataflow/py38/build.gradle deleted file mode 100644 index b3c3a5bfb8a6..000000000000 --- a/sdks/python/test-suites/dataflow/py38/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -apply plugin: org.apache.beam.gradle.BeamModulePlugin -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' -apply from: "../common.gradle" diff --git a/sdks/python/test-suites/direct/py38/build.gradle b/sdks/python/test-suites/direct/py38/build.gradle deleted file mode 100644 index edf86a7bf5a8..000000000000 --- a/sdks/python/test-suites/direct/py38/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -plugins { id 'org.apache.beam.module' } -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. 
-pythonVersion = '3.8' -apply from: '../common.gradle' diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle index 5fd1b182a471..fbd65a1657cb 100644 --- a/sdks/python/test-suites/portable/common.gradle +++ b/sdks/python/test-suites/portable/common.gradle @@ -23,6 +23,7 @@ import org.apache.tools.ant.taskdefs.condition.Os def pythonRootDir = "${rootDir}/sdks/python" def pythonVersionSuffix = project.ext.pythonVersion.replace('.', '') def latestFlinkVersion = project.ext.latestFlinkVersion +def currentJavaVersion = project.ext.currentJavaVersion ext { pythonContainerTask = ":sdks:python:container:py${pythonVersionSuffix}:docker" @@ -369,7 +370,7 @@ project.tasks.register("postCommitPy${pythonVersionSuffix}IT") { 'setupVirtualenv', 'installGcpTest', ":runners:flink:${latestFlinkVersion}:job-server:shadowJar", - ':sdks:java:container:java8:docker', + ":sdks:java:container:${currentJavaVersion}:docker", ':sdks:java:testing:kafka-service:buildTestKafkaServiceJar', ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', @@ -420,7 +421,7 @@ project.tasks.register("xlangSpannerIOIT") { 'setupVirtualenv', 'installGcpTest', ":runners:flink:${latestFlinkVersion}:job-server:shadowJar", - ':sdks:java:container:java8:docker', + ":sdks:java:container:${currentJavaVersion}:docker", ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', ':sdks:java:io:kinesis:expansion-service:shadowJar', diff --git a/sdks/python/test-suites/portable/py38/build.gradle b/sdks/python/test-suites/portable/py38/build.gradle deleted file mode 100644 index e15443fa935f..000000000000 --- a/sdks/python/test-suites/portable/py38/build.gradle +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -apply plugin: org.apache.beam.gradle.BeamModulePlugin -applyPythonNature() - -addPortableWordCountTasks() - -// Required to setup a Python 3.8 virtualenv and task names. -pythonVersion = '3.8' -apply from: "../common.gradle" diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle deleted file mode 100644 index 2ca82d3d9268..000000000000 --- a/sdks/python/test-suites/tox/py38/build.gradle +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Unit tests for Python 3.8 - */ - -plugins { id 'org.apache.beam.module' } -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' - -def posargs = project.findProperty("posargs") ?: "" - -apply from: "../common.gradle" - -toxTask "testPy38CloudCoverage", "py38-cloudcoverage", "${posargs}" -test.dependsOn "testPy38CloudCoverage" -project.tasks.register("preCommitPyCoverage") { - dependsOn = ["testPy38CloudCoverage"] -} - -// Dep Postcommit runs test suites that evaluate compatibility of particular -// dependencies. It is exercised on a single Python version. -// -// Should still leave at least one version in PreCommit unless the marked tests -// are also exercised by existing PreCommit -// e.g. pyarrow and pandas also run on PreCommit Dataframe and Coverage -project.tasks.register("postCommitPyDep") {} - -// Create a test task for supported major versions of pyarrow -// We should have a test for the lowest supported version and -// For versions that we would like to prioritize for testing, -// for example versions released in a timeframe of last 1-2 years. - -toxTask "testPy38pyarrow-3", "py38-pyarrow-3", "${posargs}" -test.dependsOn "testPy38pyarrow-3" -postCommitPyDep.dependsOn "testPy38pyarrow-3" - -toxTask "testPy38pyarrow-9", "py38-pyarrow-9", "${posargs}" -test.dependsOn "testPy38pyarrow-9" -postCommitPyDep.dependsOn "testPy38pyarrow-9" - -toxTask "testPy38pyarrow-10", "py38-pyarrow-10", "${posargs}" -test.dependsOn "testPy38pyarrow-10" -postCommitPyDep.dependsOn "testPy38pyarrow-10" - -toxTask "testPy38pyarrow-11", "py38-pyarrow-11", "${posargs}" -test.dependsOn "testPy38pyarrow-11" -postCommitPyDep.dependsOn "testPy38pyarrow-11" - -toxTask "testPy38pyarrow-12", "py38-pyarrow-12", "${posargs}" -test.dependsOn "testPy38pyarrow-12" -postCommitPyDep.dependsOn "testPy38pyarrow-12" - -toxTask "testPy38pyarrow-13", "py38-pyarrow-13", "${posargs}" -test.dependsOn "testPy38pyarrow-13" -postCommitPyDep.dependsOn "testPy38pyarrow-13" - -toxTask "testPy38pyarrow-14", "py38-pyarrow-14", "${posargs}" -test.dependsOn "testPy38pyarrow-14" -postCommitPyDep.dependsOn "testPy38pyarrow-14" - -toxTask "testPy38pyarrow-15", "py38-pyarrow-15", "${posargs}" -test.dependsOn "testPy38pyarrow-15" -postCommitPyDep.dependsOn "testPy38pyarrow-15" - -toxTask "testPy38pyarrow-16", "py38-pyarrow-16", "${posargs}" -test.dependsOn "testPy38pyarrow-16" -postCommitPyDep.dependsOn "testPy38pyarrow-16" - -// Create a test task for each supported minor version of pandas -toxTask "testPy38pandas-14", "py38-pandas-14", "${posargs}" -test.dependsOn "testPy38pandas-14" -postCommitPyDep.dependsOn "testPy38pandas-14" - -toxTask "testPy38pandas-15", "py38-pandas-15", "${posargs}" -test.dependsOn "testPy38pandas-15" -postCommitPyDep.dependsOn "testPy38pandas-15" - -toxTask "testPy38pandas-20", "py38-pandas-20", "${posargs}" -test.dependsOn "testPy38pandas-20" -postCommitPyDep.dependsOn "testPy38pandas-20" - -// TODO(https://github.com/apache/beam/issues/31192): Add below suites -// after dependency compat tests suite switches to Python 3.9 or we add -// 
Python 2.2 support. - -// toxTask "testPy39pandas-21", "py39-pandas-21", "${posargs}" -// test.dependsOn "testPy39pandas-21" -// postCommitPyDep.dependsOn "testPy39pandas-21" - -// toxTask "testPy39pandas-22", "py39-pandas-22", "${posargs}" -// test.dependsOn "testPy39pandas-22" -// postCommitPyDep.dependsOn "testPy39pandas-22" - -// TODO(https://github.com/apache/beam/issues/30908): Revise what are we testing - -// Create a test task for each minor version of pytorch -toxTask "testPy38pytorch-19", "py38-pytorch-19", "${posargs}" -test.dependsOn "testPy38pytorch-19" -postCommitPyDep.dependsOn "testPy38pytorch-19" - -toxTask "testPy38pytorch-110", "py38-pytorch-110", "${posargs}" -test.dependsOn "testPy38pytorch-110" -postCommitPyDep.dependsOn "testPy38pytorch-110" - -toxTask "testPy38pytorch-111", "py38-pytorch-111", "${posargs}" -test.dependsOn "testPy38pytorch-111" -postCommitPyDep.dependsOn "testPy38pytorch-111" - -toxTask "testPy38pytorch-112", "py38-pytorch-112", "${posargs}" -test.dependsOn "testPy38pytorch-112" -postCommitPyDep.dependsOn "testPy38pytorch-112" - -toxTask "testPy38pytorch-113", "py38-pytorch-113", "${posargs}" -test.dependsOn "testPy38pytorch-113" -postCommitPyDep.dependsOn "testPy38pytorch-113" - -// run on precommit -toxTask "testPy38pytorch-200", "py38-pytorch-200", "${posargs}" -test.dependsOn "testPy38pytorch-200" -postCommitPyDep.dependsOn "testPy38pytorch-200" - -toxTask "testPy38tft-113", "py38-tft-113", "${posargs}" -test.dependsOn "testPy38tft-113" -postCommitPyDep.dependsOn "testPy38tft-113" - -// TODO(https://github.com/apache/beam/issues/25796) - uncomment onnx tox task once onnx supports protobuf 4.x.x -// Create a test task for each minor version of onnx -// toxTask "testPy38onnx-113", "py38-onnx-113", "${posargs}" -// test.dependsOn "testPy38onnx-113" -// postCommitPyDep.dependsOn "testPy38onnx-113" - -// Create a test task for each minor version of tensorflow -toxTask "testPy38tensorflow-212", "py38-tensorflow-212", "${posargs}" -test.dependsOn "testPy38tensorflow-212" -postCommitPyDep.dependsOn "testPy38tensorflow-212" - -// Create a test task for each minor version of transformers -toxTask "testPy38transformers-428", "py38-transformers-428", "${posargs}" -test.dependsOn "testPy38transformers-428" -postCommitPyDep.dependsOn "testPy38transformers-428" - -toxTask "testPy38transformers-429", "py38-transformers-429", "${posargs}" -test.dependsOn "testPy38transformers-429" -postCommitPyDep.dependsOn "testPy38transformers-429" - -toxTask "testPy38transformers-430", "py38-transformers-430", "${posargs}" -test.dependsOn "testPy38transformers-430" -postCommitPyDep.dependsOn "testPy38transformers-430" - -toxTask "testPy38embeddingsMLTransform", "py38-embeddings", "${posargs}" -test.dependsOn "testPy38embeddingsMLTransform" -postCommitPyDep.dependsOn "testPy38embeddingsMLTransform" - -// Part of MLTransform embeddings test suite but requires tensorflow hub, which we need to test on -// mutliple versions so keeping this suite separate. 
-toxTask "testPy38TensorflowHubEmbeddings-014", "py38-TFHubEmbeddings-014", "${posargs}" -test.dependsOn "testPy38TensorflowHubEmbeddings-014" -postCommitPyDep.dependsOn "testPy38TensorflowHubEmbeddings-014" - -toxTask "testPy38TensorflowHubEmbeddings-015", "py38-TFHubEmbeddings-015", "${posargs}" -test.dependsOn "testPy38TensorflowHubEmbeddings-015" -postCommitPyDep.dependsOn "testPy38TensorflowHubEmbeddings-015" - -toxTask "whitespacelint", "whitespacelint", "${posargs}" - -task archiveFilesToLint(type: Zip) { - archiveFileName = "files-to-whitespacelint.zip" - destinationDirectory = file("$buildDir/dist") - - from ("$rootProject.projectDir") { - include "**/*.md" - include "**/build.gradle" - include '**/build.gradle.kts' - exclude '**/build/**' // intermediate build directory - exclude 'website/www/site/themes/docsy/**' // fork to google/docsy - exclude "**/node_modules/*" - exclude "**/.gogradle/*" - } -} - -task unpackFilesToLint(type: Copy) { - from zipTree("$buildDir/dist/files-to-whitespacelint.zip") - into "$buildDir/files-to-whitespacelint" -} - -whitespacelint.dependsOn archiveFilesToLint, unpackFilesToLint -unpackFilesToLint.dependsOn archiveFilesToLint -archiveFilesToLint.dependsOn cleanPython - -toxTask "jest", "jest", "${posargs}" - -toxTask "eslint", "eslint", "${posargs}" - -task copyTsSource(type: Copy) { - from ("$rootProject.projectDir") { - include "sdks/python/apache_beam/runners/interactive/extensions/**/*" - exclude "sdks/python/apache_beam/runners/interactive/extensions/**/lib/*" - exclude "sdks/python/apache_beam/runners/interactive/extensions/**/node_modules/*" - } - into "$buildDir/ts" -} - -jest.dependsOn copyTsSource -eslint.dependsOn copyTsSource -copyTsSource.dependsOn cleanPython diff --git a/sdks/python/test-suites/tox/py39/build.gradle b/sdks/python/test-suites/tox/py39/build.gradle index ea02e9d5b1e8..e9624f8e810e 100644 --- a/sdks/python/test-suites/tox/py39/build.gradle +++ b/sdks/python/test-suites/tox/py39/build.gradle @@ -168,7 +168,9 @@ postCommitPyDep.dependsOn "testPy39transformers-430" toxTask "testPy39embeddingsMLTransform", "py39-embeddings", "${posargs}" test.dependsOn "testPy39embeddingsMLTransform" -postCommitPyDep.dependsOn "testPy39embeddingsMLTransform" +// TODO(https://github.com/apache/beam/issues/32965): re-enable this suite for the dep +// postcommit once the sentence-transformers import error is debugged +// postCommitPyDep.dependsOn "testPy39embeddingsMLTransform" // Part of MLTransform embeddings test suite but requires tensorflow hub, which we need to test on // mutliple versions so keeping this suite separate. diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index 8cdc4a98bbfe..c7713498d87d 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -149,7 +149,7 @@ commands = [testenv:mypy] deps = - mypy==0.790 + mypy==1.13.0 dask==2022.01.0 distributed==2022.01.0 # make extras available in case any of these libs are typed diff --git a/sdks/typescript/src/apache_beam/runners/flink.ts b/sdks/typescript/src/apache_beam/runners/flink.ts index ad4339b431f5..ab2d641b3302 100644 --- a/sdks/typescript/src/apache_beam/runners/flink.ts +++ b/sdks/typescript/src/apache_beam/runners/flink.ts @@ -28,7 +28,7 @@ import { JavaJarService } from "../utils/service"; const MAGIC_HOST_NAMES = ["[local]", "[auto]"]; // These should stay in sync with gradle.properties. 
-const PUBLISHED_FLINK_VERSIONS = ["1.15", "1.16", "1.17", "1.18"]; +const PUBLISHED_FLINK_VERSIONS = ["1.17", "1.18", "1.19"]; const defaultOptions = { flinkMaster: "[local]", diff --git a/settings.gradle.kts b/settings.gradle.kts index 9701b4dbc06f..ca30a5ea750a 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -125,14 +125,6 @@ include(":runners:extensions-java:metrics") * verify versions in website/www/site/content/en/documentation/runners/flink.md * verify version in sdks/python/apache_beam/runners/interactive/interactive_beam.py */ -// Flink 1.15 -include(":runners:flink:1.15") -include(":runners:flink:1.15:job-server") -include(":runners:flink:1.15:job-server-container") -// Flink 1.16 -include(":runners:flink:1.16") -include(":runners:flink:1.16:job-server") -include(":runners:flink:1.16:job-server-container") // Flink 1.17 include(":runners:flink:1.17") include(":runners:flink:1.17:job-server") @@ -141,6 +133,10 @@ include(":runners:flink:1.17:job-server-container") include(":runners:flink:1.18") include(":runners:flink:1.18:job-server") include(":runners:flink:1.18:job-server-container") +// Flink 1.19 +include(":runners:flink:1.19") +include(":runners:flink:1.19:job-server") +include(":runners:flink:1.19:job-server-container") /* End Flink Runner related settings */ include(":runners:twister2") include(":runners:google-cloud-dataflow-java") @@ -337,6 +333,8 @@ project(":beam-test-gha").projectDir = file(".github") include("beam-validate-runner") project(":beam-validate-runner").projectDir = file(".test-infra/validate-runner") include("com.google.api.gax.batching") +include("sdks:java:io:kafka:kafka-312") +findProject(":sdks:java:io:kafka:kafka-312")?.name = "kafka-312" include("sdks:java:io:kafka:kafka-251") findProject(":sdks:java:io:kafka:kafka-251")?.name = "kafka-251" include("sdks:java:io:kafka:kafka-241") @@ -349,12 +347,6 @@ include("sdks:java:io:kafka:kafka-211") findProject(":sdks:java:io:kafka:kafka-211")?.name = "kafka-211" include("sdks:java:io:kafka:kafka-201") findProject(":sdks:java:io:kafka:kafka-201")?.name = "kafka-201" -include("sdks:java:io:kafka:kafka-111") -findProject(":sdks:java:io:kafka:kafka-111")?.name = "kafka-111" -include("sdks:java:io:kafka:kafka-100") -findProject(":sdks:java:io:kafka:kafka-100")?.name = "kafka-100" -include("sdks:java:io:kafka:kafka-01103") -findProject(":sdks:java:io:kafka:kafka-01103")?.name = "kafka-01103" include("sdks:java:managed") findProject(":sdks:java:managed")?.name = "managed" include("sdks:java:io:iceberg") diff --git a/vendor/grpc-1_60_1/build.gradle b/vendor/grpc-1_60_1/build.gradle index 834c496d9ca4..da152ef10f75 100644 --- a/vendor/grpc-1_60_1/build.gradle +++ b/vendor/grpc-1_60_1/build.gradle @@ -23,7 +23,7 @@ plugins { id 'org.apache.beam.vendor-java' } description = "Apache Beam :: Vendored Dependencies :: gRPC :: 1.60.1" group = "org.apache.beam" -version = "0.2" +version = "0.3" vendorJava( dependencies: GrpcVendoring_1_60_1.dependencies(), diff --git a/website/www/site/config.toml b/website/www/site/config.toml index e937289fbde7..d769f8434a7f 100644 --- a/website/www/site/config.toml +++ b/website/www/site/config.toml @@ -104,7 +104,7 @@ github_project_repo = "https://github.com/apache/beam" [params] description = "Apache Beam is an open source, unified model and set of language-specific SDKs for defining and executing data processing workflows, and also data ingestion and integration flows, supporting Enterprise Integration Patterns (EIPs) and Domain Specific Languages (DSLs). 
Dataflow pipelines simplify the mechanics of large-scale batch and streaming data processing and can run on a number of runtimes like Apache Flink, Apache Spark, and Google Cloud Dataflow (a cloud service). Beam also brings DSL in different languages, allowing users to easily implement their data integration processes." -release_latest = "2.59.0" +release_latest = "2.60.0" # The repository and branch where the files live in Github or Colab. This is used # to serve and stage from your local branch, but publish to the master branch. # e.g. https://github.com/{{< param branch_repo >}}/path/to/notebook.ipynb diff --git a/website/www/site/content/en/blog/beam-2.60.0.md b/website/www/site/content/en/blog/beam-2.60.0.md new file mode 100644 index 000000000000..462bdaf16798 --- /dev/null +++ b/website/www/site/content/en/blog/beam-2.60.0.md @@ -0,0 +1,81 @@ +--- +title: "Apache Beam 2.60.0" +date: 2024-10-17 15:00:00 -0500 +categories: + - blog + - release +authors: + - yhu +--- + + +We are happy to present the new 2.60.0 release of Beam. +This release includes both improvements and new functionality. +See the [download page](/get-started/downloads/#2600-2024-10-17) for this release. + + + +For more information on changes in 2.60.0, check out the [detailed release notes](https://github.com/apache/beam/milestone/24). + +## Highlights + +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) +* [Managed Iceberg] Added support for streaming writes ([#32451](https://github.com/apache/beam/pull/32451)) +* [Managed Iceberg] Added auto-sharding for streaming writes ([#32612](https://github.com/apache/beam/pull/32612)) +* [Managed Iceberg] Added support for writing to dynamic destinations ([#32565](https://github.com/apache/beam/pull/32565)) + +## New Features / Improvements + +* Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). +* Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) +* Prism release binaries and container bootloaders are now being built with the latest Go 1.23 patch. ([#32575](https://github.com/apache/beam/pull/32575)) +* Prism + * Prism now supports Bundle Finalization. ([#32425](https://github.com/apache/beam/pull/32425)) +* Significantly improved performance of Kafka IO reads that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) by removing the data reshuffle from SDF implementation. ([#31682](https://github.com/apache/beam/pull/31682)). +* Added support for dynamic writing in MqttIO (Java) ([#19376](https://github.com/apache/beam/issues/19376)) +* Optimized Spark Runner parDo transform evaluator (Java) ([#32537](https://github.com/apache/beam/issues/32537)) +* [Managed Iceberg] More efficient manifest file writes/commits ([#32666](https://github.com/apache/beam/issues/32666)) + +## Breaking Changes + +* In Python, assert_that now throws if it is not in a pipeline context instead of silently succeeding ([#30771](https://github.com/apache/beam/pull/30771)) +* In Python and YAML, ReadFromJson now override the dtype from None to + an explicit False. 
Most notably, string values like `"123"` are preserved + as strings rather than silently coerced (and possibly truncated) to numeric + values. To retain the old behavior, pass `dtype=True` (or any other value + accepted by `pandas.read_json`). +* Users of KafkaIO Read transform that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) might encounter pipeline graph compatibility issues when updating the pipeline. To mitigate, set the `updateCompatibilityVersion` option to the SDK version used for the original pipeline, example `--updateCompatabilityVersion=2.58.1` + +## Deprecations + +* Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users +when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) + +## Bugfixes + +* (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)). +* (Java, Python, Go) Fixed PeriodicSequence backlog bytes reporting, which was preventing Dataflow Runner autoscaling from functioning properly ([#32506](https://github.com/apache/beam/issues/32506)). +* (Java) Fix improper decoding of rows with schemas containing nullable fields when encoded with a schema with equal encoding positions but modified field order. ([#32388](https://github.com/apache/beam/issues/32388)). + +## Known Issues + +N/A + +For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md + +## List of Contributors + +According to git shortlog, the following people contributed to the 2.60.0 release. Thank you to all contributors! + +Ahmed Abualsaud, Aiden Grossman, Arun Pandian, Bartosz Zablocki, Chamikara Jayalath, Claire McGinty, DKPHUONG, Damon Douglass, Danny McCormick, Dip Patel, Ferran Fernández Garrido, Hai Joey Tran, Hyeonho Kim, Igor Bernstein, Israel Herraiz, Jack McCluskey, Jaehyeon Kim, Jeff Kinard, Jeffrey Kinard, Joey Tran, Kenneth Knowles, Kirill Berezin, Michel Davit, Minbo Bae, Naireen Hussain, Niel Markwick, Nito Buendia, Reeba Qureshi, Reuven Lax, Robert Bradshaw, Robert Burke, Rohit Sinha, Ryan Fu, Sam Whittle, Shunping Huang, Svetak Sundhar, Udaya Chathuranga, Vitaly Terentyev, Vlado Djerek, Yi Hu, Claude van der Merwe, XQ Hu, Martin Trieu, Valentyn Tymofieiev, twosom diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md new file mode 100644 index 000000000000..995b59b978c1 --- /dev/null +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -0,0 +1,277 @@ +--- +title: "Efficient Streaming Data Processing with Beam YAML and Protobuf" +date: "2024-09-20T11:53:38+02:00" +categories: + - blog +authors: + - ffernandez92 +--- + + +# Efficient Streaming Data Processing with Beam YAML and Protobuf + +As streaming data processing grows, so do its maintenance, complexity, and costs. +This post explains how to efficiently scale pipelines by using [Protobuf](https://protobuf.dev/), +which ensures that pipelines are reusable and quick to deploy. The goal is to keep this process simple +for engineers to implement using [Beam YAML](https://beam.apache.org/documentation/sdks/yaml/). + + + +## Simplify pipelines with Beam YAML + +Creating a pipeline in Beam can be somewhat difficult, especially for new Apache Beam users. +Setting up the project, managing dependencies, and so on can be challenging. 
+Beam YAML eliminates most of the boilerplate code, +which allows you to focus on the most important part of the work: data transformation. + +Some of the key benefits of Beam YAML include: + +* **Readability:** By using a declarative language ([YAML](https://yaml.org/)), the pipeline configuration is more human readable. +* **Reusability:** Reusing the same components across different pipelines is simplified. +* **Maintainability:** Pipeline maintenance and updates are easier. + +The following template shows an example of reading events from a [Kafka](https://kafka.apache.org/intro) topic and +writing them into [BigQuery](https://cloud.google.com/bigquery?hl=en). + +```yaml +pipeline: + transforms: + - type: ReadFromKafka + name: ReadProtoMovieEvents + config: + topic: 'TOPIC_NAME' + format: RAW/AVRO/JSON/PROTO + bootstrap_servers: 'BOOTSTRAP_SERVERS' + schema: 'SCHEMA' + - type: WriteToBigQuery + name: WriteMovieEvents + input: ReadProtoMovieEvents + config: + table: 'PROJECT_ID.DATASET.MOVIE_EVENTS_TABLE' + useAtLeastOnceSemantics: true + +options: + streaming: true + dataflow_service_options: [streaming_mode_at_least_once] +``` + +## The complete workflow + +This section demonstrates the complete workflow for this pipeline. + +### Create a simple proto event + +The following code creates a simple movie event. + +```protobuf +// events/v1/movie_event.proto + +syntax = "proto3"; + +package event.v1; + +import "bq_field.proto"; +import "bq_table.proto"; +import "buf/validate/validate.proto"; +import "google/protobuf/wrappers.proto"; + +message MovieEvent { + option (gen_bq_schema.bigquery_opts).table_name = "movie_table"; + google.protobuf.StringValue event_id = 1 [(gen_bq_schema.bigquery).description = "Unique Event ID"]; + google.protobuf.StringValue user_id = 2 [(gen_bq_schema.bigquery).description = "Unique User ID"]; + google.protobuf.StringValue movie_id = 3 [(gen_bq_schema.bigquery).description = "Unique Movie ID"]; + google.protobuf.Int32Value rating = 4 [(buf.validate.field).int32 = { + // validates the average rating is at least 0 + gte: 0, + // validates the average rating is at most 100 + lte: 100 + }, (gen_bq_schema.bigquery).description = "Movie rating"]; + string event_dt = 5 [ + (gen_bq_schema.bigquery).type_override = "DATETIME", + (gen_bq_schema.bigquery).description = "UTC Datetime representing when we received this event. Format: YYYY-MM-DDTHH:MM:SS", + (buf.validate.field) = { + string: { + pattern: "^\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}$" + }, + ignore_empty: false, + } + ]; +} +``` + +Because these events are written to BigQuery, +the [`bq_field`](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_field.proto) proto +and the [`bq_table`](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_table.proto) proto are imported. +These proto files help generate the BigQuery JSON schema. +This example also demonstrates a shift-left approach, which moves testing, quality, +and performance as early as possible in the development process. For example, to ensure that only valid events are generated from the source, the `buf.validate` elements are included. + +After you create the `movie_event.proto` proto in the `events/v1` folder, you can generate +the necessary [file descriptor](https://buf.build/docs/reference/descriptors). +A file descriptor is a compiled representation of the schema that allows various tools and systems +to understand and work with protobuf data dynamically. 
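As an aside, a compiled descriptor set is simply a serialized `FileDescriptorSet` proto, so once it has been generated (with the Buf commands shown below) it can be inspected from Python in a few lines. The file name here matches the `descriptor.binp` used later in this post and is otherwise an assumption.

```python
# Sketch: list the messages contained in a compiled descriptor set.
# "descriptor.binp" is the file produced by `buf build` later in this post.
from google.protobuf import descriptor_pb2

with open('descriptor.binp', 'rb') as f:
    descriptor_set = descriptor_pb2.FileDescriptorSet.FromString(f.read())

for file_proto in descriptor_set.file:
    for message in file_proto.message_type:
        # Prints fully qualified names such as "event.v1.MovieEvent".
        print(f'{file_proto.package}.{message.name}')
```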
To simplify the process, this example uses Buf, +which requires the following configuration files. + + +Buf configuration: + +```yaml +# buf.yaml + +version: v2 +deps: + - buf.build/googlecloudplatform/bq-schema-api + - buf.build/bufbuild/protovalidate +breaking: + use: + - FILE +lint: + use: + - DEFAULT +``` + +```yaml +# buf.gen.yaml + +version: v2 +managed: + enabled: true +plugins: + # Python Plugins + - remote: buf.build/protocolbuffers/python + out: gen/python + - remote: buf.build/grpc/python + out: gen/python + + # Java Plugins + - remote: buf.build/protocolbuffers/java:v25.2 + out: gen/maven/src/main/java + - remote: buf.build/grpc/java + out: gen/maven/src/main/java + + # BQ Schemas + - remote: buf.build/googlecloudplatform/bq-schema:v1.1.0 + out: protoc-gen/bq_schema + +``` + +Run the following two commands to generate the necessary Java, Python, BigQuery schema, and Descriptor file: + +```bash +// Generate the buf.lock file +buf deps update + +// It generates the descriptor in descriptor.binp. +buf build . -o descriptor.binp --exclude-imports + +// It generates the Java, Python and BigQuery schema as described in buf.gen.yaml +buf generate --include-imports +``` + +### Make the Beam YAML read proto + +Make the following modifications to the to the YAML file: + +```yaml +# movie_events_pipeline.yml + +pipeline: + transforms: + - type: ReadFromKafka + name: ReadProtoMovieEvents + config: + topic: 'movie_proto' + format: PROTO + bootstrap_servers: '' + file_descriptor_path: 'gs://my_proto_bucket/movie/v1.0.0/descriptor.binp' + message_name: 'event.v1.MovieEvent' + - type: WriteToBigQuery + name: WriteMovieEvents + input: ReadProtoMovieEvents + config: + table: '.raw.movie_table' + useAtLeastOnceSemantics: true +options: + streaming: true + dataflow_service_options: [streaming_mode_at_least_once] +``` + +This step changes the format to `PROTO` and adds the `file_descriptor_path` and the `message_name`. + +### Deploy the pipeline with Terraform + +You can use [Terraform](https://www.terraform.io/) to deploy the Beam YAML pipeline +with [Dataflow](https://cloud.google.com/products/dataflow?hl=en) as the runner. +The following Terraform code example demonstrates how to achieve this: + +```hcl +// Enable Dataflow API. +resource "google_project_service" "enable_dataflow_api" { + project = var.gcp_project_id + service = "dataflow.googleapis.com" +} + +// DF Beam YAML +resource "google_dataflow_flex_template_job" "data_movie_job" { + provider = google-beta + project = var.gcp_project_id + name = "movie-proto-events" + container_spec_gcs_path = "gs://dataflow-templates-${var.gcp_region}/latest/flex/Yaml_Template" + region = var.gcp_region + on_delete = "drain" + machine_type = "n2d-standard-4" + enable_streaming_engine = true + subnetwork = var.subnetwork + skip_wait_on_job_termination = true + parameters = { + yaml_pipeline_file = "gs://${var.bucket_name}/yamls/${var.package_version}/movie_events_pipeline.yml" + max_num_workers = 40 + worker_zone = var.gcp_zone + } + depends_on = [google_project_service.enable_dataflow_api] +} +``` + +Assuming the BigQuery table exists, which you can do by using Terraform and Proto, +this code creates a Dataflow job by using the Beam YAML code that reads Proto events from +Kafka and writes them into BigQuery. 
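Before deploying, the YAML fragments can also be smoke-tested locally from the Python SDK, mirroring how the yaml_mapping unit tests in this change drive `YamlTransform` on the DirectRunner. The snippet below is a sketch under that assumption; the `Create` step stands in for the Kafka source, and `LogForTesting` is assumed to be available as a built-in Beam YAML transform.

```python
# Sketch: exercise a YAML fragment locally before deploying with Terraform.
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.yaml.yaml_transform import YamlTransform

with beam.Pipeline(options=PipelineOptions(pickle_library='cloudpickle')) as p:
    _ = (
        p
        | beam.Create([beam.Row(user_id='u1', movie_id='m1', rating=87)])  # stand-in for the Kafka source
        | YamlTransform('type: LogForTesting')  # assumed built-in logging transform
    )
```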
+ +## Improvements and conclusions + +The following community contributions could improve the Beam YAML code in this example: + +* **Support schema registries:** Integrate with schema registries such as Buf Registry or Apicurio for +better schema management. The current workflow generates the descriptors by using Buf and store them in Google Cloud Storage. +The descriptors could be stored in a schema registry instead. + + +* **Enhanced Monitoring:** Implement advanced monitoring and alerting mechanisms to quickly identify and address +issues in the data pipeline. + +Leveraging Beam YAML and Protobuf lets us streamline the creation and maintenance of +data processing pipelines, significantly reducing complexity. This approach ensures that engineers can more +efficiently implement and scale robust, reusable pipelines without needs to manually write Beam code. + +## Contribute + +Developers who want to help build out and add functionalities are welcome to start contributing to the effort in the +[Beam YAML module](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml). + +There is also a list of open [bugs](https://github.com/apache/beam/issues?q=is%3Aopen+is%3Aissue+label%3Ayaml) found +on the GitHub repo - now marked with the `yaml` tag. + +Although Beam YAML is marked stable as of Beam 2.52, it is still under heavy development, with new features being +added with each release. Those who want to be part of the design decisions and give insights to how the framework is +being used are highly encouraged to join the [dev mailing list](https://beam.apache.org/community/contact-us/), where those discussions are occurring. diff --git a/website/www/site/content/en/case-studies/behalf.md b/website/www/site/content/en/case-studies/behalf.md deleted file mode 100644 index e5a240a03d4f..000000000000 --- a/website/www/site/content/en/case-studies/behalf.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -title: "Behalf" -icon: /images/logos/powered-by/behalf.png -hasLink: "https://www.behalf.com/" ---- - - diff --git a/website/www/site/content/en/documentation/ml/large-language-modeling.md b/website/www/site/content/en/documentation/ml/large-language-modeling.md index 90bbd43383c0..b8bd0704d20e 100644 --- a/website/www/site/content/en/documentation/ml/large-language-modeling.md +++ b/website/www/site/content/en/documentation/ml/large-language-modeling.md @@ -170,3 +170,32 @@ class MyModelHandler(): def run_inference(self, batch: Sequence[str], model: MyWrapper, inference_args): return model.predict(unpickleable_object) ``` + +## RAG and Prompt Engineering in Beam + +Beam is also an excellent tool for improving the quality of your LLM prompts using Retrieval Augmented Generation (RAG). +Retrieval augmented generation is a technique that enhances large language models (LLMs) by connecting them to external knowledge sources. +This allows the LLM to access and process real-time information, improving the accuracy, relevance, and factuality of its responses. + +Beam has several mechanisms to make this process simpler: + +1. Beam's [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) provides an embeddings package to generate the embeddings used for RAG. You can also use RunInference to generate embeddings if you have a model without an embeddings handler. +2. 
Beam's [Enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) makes it easy to look up embeddings or other information in an external storage system like a [vector database](https://www.pinecone.io/learn/vector-database/). + +Collectively, you can use these to perform RAG using the following steps: + +**Pipeline 1 - generate knowledge base:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Write those embeddings to a vector DB using a [ParDo](https://beam.apache.org/documentation/programming-guide/#pardo) + +**Pipeline 2 - use knowledge base to perform RAG:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Enrich that data with additional embeddings from your vector DB using [Enrichment](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) +4. Use that enriched data to prompt your LLM with [RunInference](https://beam.apache.org/documentation/transforms/python/elementwise/runinference/) +5. Write that data to your desired sink using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) + +To view an example pipeline performing RAG, see https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/rag_usecase/beam_rag_notebook.ipynb diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index c716c7554db4..955c2b8797d1 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -35,7 +35,7 @@ programming guide, take a look at the {{< language-switcher java py go typescript yaml >}} {{< paragraph class="language-py" >}} -The Python SDK supports Python 3.8, 3.9, 3.10, and 3.11. +The Python SDK supports Python 3.8, 3.9, 3.10, 3.11, and 3.12. {{< /paragraph >}} {{< paragraph class="language-go">}} @@ -2024,7 +2024,7 @@ playerAccuracies := ... // PCollection #### 4.2.5. Flatten {#flatten} [`Flatten`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) -[`Flatten`](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/transforms/core.py) +[`Flatten`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.Flatten) [`Flatten`](https://github.com/apache/beam/blob/master/sdks/go/pkg/beam/flatten.go) `Flatten` is a Beam transform for `PCollection` objects that store the same data type. @@ -2045,6 +2045,22 @@ PCollectionList collections = PCollectionList.of(pc1).and(pc2).and(pc3); PCollection merged = collections.apply(Flatten.pCollections()); {{< /highlight >}} +{{< paragraph class="language-java" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. 
+{{< /paragraph >}} + +{{< highlight java >}} +PCollection merged = pc1 + .apply(...) + // Merges the elements of pc2 in at this point... + .apply(FlattenWith.of(pc2)) + .apply(...) + // and the elements of pc3 at this point. + .apply(FlattenWith.of(pc3)) + .apply(...); +{{< /highlight >}} + {{< highlight py >}} # Flatten takes a tuple of PCollection objects. @@ -2052,6 +2068,26 @@ PCollection merged = collections.apply(Flatten.pCollections()); {{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten >}} {{< /highlight >}} +{{< paragraph class="language-py" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.FlattenWith) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. +{{< /paragraph >}} + +{{< highlight py >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with >}} +{{< /highlight >}} + +{{< paragraph class="language-py" >}} +`FlattenWith` can take root `PCollection`-producing transforms +(such as `Create` and `Read`) as well as already constructed PCollections, +and will apply them and flatten their outputs into the resulting output +PCollection. +{{< /paragraph >}} + +{{< highlight py >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with_transform >}} +{{< /highlight >}} + {{< highlight go >}} // Flatten accepts any number of PCollections of the same element type. // Returns a single PCollection that contains all of the elements in input PCollections. @@ -6173,7 +6209,7 @@ class MyDoFn(beam.DoFn): self.gauge = metrics.Metrics.gauge("namespace", "gauge1") def process(self, element): - self.gaguge.set(element) + self.gauge.set(element) yield element {{< /highlight >}} diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 7325c480955c..af73751c256a 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -93,7 +93,7 @@ from the [compatibility table](#flink-version-compatibility) below. For example: {{< highlight java >}} org.apache.beam - beam-runners-flink-1.17 + beam-runners-flink-1.18 {{< param release_latest >}} {{< /highlight >}} @@ -196,7 +196,6 @@ The optional `flink_version` option may be required as well for older versions o {{< paragraph class="language-portable" >}} Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: -[Flink 1.15](https://hub.docker.com/r/apache/beam_flink1.15_job_server). [Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). [Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). @@ -208,12 +207,17 @@ To run a pipeline on an embedded Flink cluster: {{< /paragraph >}} {{< paragraph class="language-portable" >}} -(1) Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +(1) Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` {{< /paragraph >}} {{< paragraph class="language-portable" >}} The JobService is the central instance where you submit your Beam pipeline to. 
-The JobService will create a Flink job for the pipeline and execute the job. +It creates a Flink job from your pipeline and executes it. +You might encounter an error message like `Caused by: java.io.IOException: Insufficient number of network buffers:...`. +This can be resolved by providing a Flink configuration file to override the default settings. +You can find an example configuration file [here](https://github.com/apache/beam/blob/master/runners/flink/src/test/resources/flink-conf.yaml). +To start the Job Service endpoint with your custom configuration, mount a local directory containing your Flink configuration to the `/flink-conf` path in the Docker container: +`docker run --net=host -v :/flink-conf beam-flink-runner apache/beam_flink1.18_job_server:latest` {{< /paragraph >}} {{< paragraph class="language-portable" >}} @@ -244,7 +248,7 @@ To run on a separate [Flink cluster](https://ci.apache.org/projects/flink/flink- {{< /paragraph >}} {{< paragraph class="language-portable" >}} -(2) Start JobService with Flink Rest endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest --flink-master=localhost:8081`. +(2) Start JobService with Flink Rest endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest --flink-master=localhost:8081`. {{< /paragraph >}} {{< paragraph class="language-portable" >}} @@ -312,8 +316,8 @@ reference. ## Flink Version Compatibility The Flink cluster version has to match the minor version used by the FlinkRunner. -The minor version is the first two numbers in the version string, e.g. in `1.16.0` the -minor version is `1.16`. +The minor version is the first two numbers in the version string, e.g. in `1.18.0` the +minor version is `1.18`. We try to track the latest version of Apache Flink at the time of the Beam release. A Flink version is supported by Beam for the time it is supported by the Flink community. @@ -326,6 +330,11 @@ To find out which version of Flink is compatible with Beam please see the table Artifact Id Supported Beam Versions + + 1.19.x + beam-runners-flink-1.19 + ≥ 2.61.0 + 1.18.x beam-runners-flink-1.18 @@ -339,12 +348,12 @@ To find out which version of Flink is compatible with Beam please see the table 1.16.x beam-runners-flink-1.16 - ≥ 2.47.0 + 2.47.0 - 2.60.0 1.15.x beam-runners-flink-1.15 - ≥ 2.40.0 + 2.40.0 - 2.60.0 1.14.x diff --git a/website/www/site/content/en/documentation/runtime/environments.md b/website/www/site/content/en/documentation/runtime/environments.md index d9a42db29e24..48039d50a10b 100644 --- a/website/www/site/content/en/documentation/runtime/environments.md +++ b/website/www/site/content/en/documentation/runtime/environments.md @@ -105,20 +105,22 @@ This method requires building image artifacts from Beam source. For additional i 2. Customize the `Dockerfile` for a given language, typically `sdks//container/Dockerfile` directory (e.g. the [Dockerfile for Python](https://github.com/apache/beam/blob/master/sdks/python/container/Dockerfile). -3. Return to the root Beam directory and run the Gradle `docker` target for your image. +3. Return to the root Beam directory and run the Gradle `docker` target for your + image. For self-contained instructions on building a container image, + follow [this guide](/documentation/sdks/python-sdk-image-build). 
``` cd $BEAM_WORKDIR # The default repository of each SDK - ./gradlew :sdks:java:container:java8:docker ./gradlew :sdks:java:container:java11:docker ./gradlew :sdks:java:container:java17:docker + ./gradlew :sdks:java:container:java21:docker ./gradlew :sdks:go:container:docker - ./gradlew :sdks:python:container:py38:docker ./gradlew :sdks:python:container:py39:docker ./gradlew :sdks:python:container:py310:docker ./gradlew :sdks:python:container:py311:docker + ./gradlew :sdks:python:container:py312:docker # Shortcut for building all Python SDKs ./gradlew :sdks:python:container:buildAll @@ -168,9 +170,9 @@ builds the Python 3.6 container and tags it as `example-repo/beam_python3.6_sdk: From Beam 2.21.0 and later, a `docker-pull-licenses` flag was introduced to add licenses/notices for third party dependencies to the docker images. For example: ``` -./gradlew :sdks:java:container:java8:docker -Pdocker-pull-licenses +./gradlew :sdks:java:container:java11:docker -Pdocker-pull-licenses ``` -creates a Java 8 SDK image with appropriate licenses in `/opt/apache/beam/third_party_licenses/`. +creates a Java 11 SDK image with appropriate licenses in `/opt/apache/beam/third_party_licenses/`. By default, no licenses/notices are added to the docker images. diff --git a/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md b/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md new file mode 100644 index 000000000000..f456a686afea --- /dev/null +++ b/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md @@ -0,0 +1,306 @@ + + +# Building Beam Python SDK Image Guide + +There are two options for building the Beam Python SDK image. If you only need to modify +[the Python SDK boot entrypoint binary](https://github.com/apache/beam/blob/master/sdks/python/container/boot.go), +read [Update Boot Entrypoint Application Only](#update-boot-entrypoint-application-only). +If you need to build a Beam Python SDK image fully, +read [Build Beam Python SDK Image Fully](#build-beam-python-sdk-image-fully). + + +## Update Boot Entrypoint Application Only + +If you only need to make a change to [the Python SDK boot entrypoint binary](https://github.com/apache/beam/blob/master/sdks/python/container/boot.go), you +can rebuild just the boot application and include the updated binary +in the preexisting image. +Read [the Python container Dockerfile](https://github.com/apache/beam/blob/master/sdks/python/container/Dockerfile) +for reference. + +```shell +# From beam repo root, make changes to boot.go. +your_editor sdks/python/container/boot.go + +# Rebuild the entrypoint +./gradlew :sdks:python:container:gobuild + +cd sdks/python/container/build/target/launcher/linux_amd64 + +# Create a simple Dockerfile to use the custom boot entrypoint. +cat >Dockerfile <<EOF +FROM apache/beam_python3.10_sdk:2.60.0 +COPY boot /opt/apache/beam/boot +EOF + +docker build -t us-central1-docker.pkg.dev/<project>/<repository>/beam_python3.10_sdk:2.60.0-custom-boot . +docker push us-central1-docker.pkg.dev/<project>/<repository>/beam_python3.10_sdk:2.60.0-custom-boot +``` + +You can build a Docker image locally if your environment has Java, Python, Golang +and Docker installed. Try +`./gradlew :sdks:python:container:py<version>:docker`. For example, +`:sdks:python:container:py310:docker` builds `apache/beam_python3.10_sdk` +locally if successful. If the build fails in your local environment, you can follow this guide +to build a custom image from a VM. + +## Build Beam Python SDK Image Fully + +This section introduces a way to build everything from scratch. + +### Prepare VM + +Prepare a VM with Debian 11. This guide was tested on Debian 11.
+ +#### Google Compute Engine + +One option for creating a Debian 11 VM is to use a GCE instance. + +```shell +gcloud compute instances create beam-builder \ + --zone=us-central1-a \ + --image-project=debian-cloud \ + --image-family=debian-11 \ + --machine-type=n1-standard-8 \ + --boot-disk-size=20GB \ + --scopes=cloud-platform +``` + +Log in to the VM. All the following steps are executed inside the VM. + +```shell +gcloud compute ssh beam-builder --zone=us-central1-a --tunnel-through-iap +``` + +Update the apt package list. + +```shell +sudo apt-get update +``` + +> [!NOTE] +> * A high-CPU machine is recommended to reduce the compile time. +> * The image build needs a large disk. The build will fail with "no space left + on device" if you keep the default disk size of 10GB. +> * The `cloud-platform` scope is recommended to avoid permission issues with Google + Cloud Artifact Registry. You can use the default scopes if you don't push + the image to Google Cloud Artifact Registry. +> * If you push the image to Artifact Registry, use a zone in the same region as + your Artifact Registry Docker repository. + +### Prerequisite Packages + +#### Java + +You need Java to run Gradle tasks. + +```shell +sudo apt-get install -y openjdk-11-jdk +``` + +#### Golang + +Download and install Go. Reference: https://go.dev/doc/install. + +```shell +# Download and install +curl -OL https://go.dev/dl/go1.23.2.linux-amd64.tar.gz +sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go1.23.2.linux-amd64.tar.gz + +# Add go to PATH. +export PATH=/usr/local/go/bin:$PATH +``` + +Confirm the Golang version. + +```shell +go version +``` + +Expected output: + +```text +go version go1.23.2 linux/amd64 +``` + +> [!NOTE] +> Older Go versions (e.g. 1.16) will fail at `:sdks:python:container:goBuild`. + +#### Python + +This guide uses Pyenv to manage multiple Python versions. +Reference: https://realpython.com/intro-to-pyenv/#build-dependencies + +```shell +# Install dependencies +sudo apt-get install -y make build-essential libssl-dev zlib1g-dev \ +libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev \ +libncursesw5-dev xz-utils tk-dev libffi-dev liblzma-dev + +# Install Pyenv +curl https://pyenv.run | bash + +# Add pyenv to PATH. +export PATH="$HOME/.pyenv/bin:$PATH" +eval "$(pyenv init -)" +eval "$(pyenv virtualenv-init -)" +``` + +Install Python 3.9 and set it as the active Python version. This will take several minutes. + +```shell +pyenv install 3.9 +pyenv global 3.9 +``` + +Confirm the Python version. + +```shell +python --version +``` + +Expected output example: + +```text +Python 3.9.17 +``` + +> [!NOTE] +> You can build with a different Python version by passing the [`-PpythonVersion` option](https://github.com/apache/beam/blob/v2.60.0/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy#L2956-L2961) +> to the Gradle task. Otherwise, you need `python3.9` in the build +> environment for Apache Beam 2.60.0 or later (`python3.8` for older Apache Beam +> versions). If you use the wrong version, the Gradle task +`:sdks:python:setupVirtualenv` fails. + +#### Docker + +Install Docker +following [the reference](https://docs.docker.com/engine/install/debian/#install-using-the-repository). + +```shell +# Add GPG keys. +sudo apt-get update +sudo apt-get install ca-certificates curl +sudo install -m 0755 -d /etc/apt/keyrings +sudo curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc +sudo chmod a+r /etc/apt/keyrings/docker.asc + +# Add the Apt repository.
+echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian \ + $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null +sudo apt-get update + +# Install docker packages. +sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin +``` + +The Beam Python SDK image build needs to run the `docker` command without root privileges. You can enable this +by [adding your account to the docker group](https://docs.docker.com/engine/install/linux-postinstall/). + +```shell +sudo usermod -aG docker $USER +newgrp docker +``` + +Confirm that you can run a container without root privileges. + +```shell +docker run hello-world +``` + +#### Git + +Git is not required to build the Python SDK image; this guide only uses it to download +the Apache Beam code. + +```shell +sudo apt-get install -y git +``` + +### Build Beam Python SDK Image + +Download Apache Beam +from [the GitHub repository](https://github.com/apache/beam). + +```shell +git clone https://github.com/apache/beam beam +cd beam +``` + +Make changes to the Apache Beam code. + +Run the Gradle task to start the Docker image build. This will take several minutes. +You can run `:sdks:python:container:py<version>:docker` to build an image +for a different Python version. +See [the supported Python version list](https://github.com/apache/beam/tree/master/sdks/python/container). +For example, `py310` is for Python 3.10. + +```shell +./gradlew :sdks:python:container:py310:docker +``` + +If the build succeeds, you can see the built image locally. + +```shell +docker images +``` + +Expected output: + +```text +REPOSITORY TAG IMAGE ID CREATED SIZE +apache/beam_python3.10_sdk 2.60.0 33db45f57f25 About a minute ago 2.79GB +``` + +> [!NOTE] +> If you run the build in your local environment and the Gradle task +`:sdks:python:setupVirtualenv` fails due to an incompatible Python version, please +> try passing `-PpythonVersion` with the Python version installed in your local +> environment (e.g. `-PpythonVersion=3.10`). + +### Push to Repository + +You can push the custom image to an image repository. The image can then be used +as a [Dataflow custom container](https://cloud.google.com/dataflow/docs/guides/run-custom-container#usage). + +#### Google Cloud Artifact Registry + +You can push the image to Artifact Registry. No additional authentication is +necessary if you use Google Compute Engine. + +```shell +docker tag apache/beam_python3.10_sdk:2.60.0 us-central1-docker.pkg.dev/<project>/<repository>/beam_python3.10_sdk:2.60.0-custom +docker push us-central1-docker.pkg.dev/<project>/<repository>/beam_python3.10_sdk:2.60.0-custom +``` + +If you push the image from an environment other than a Google Cloud VM, you +should configure [docker authentication with +`gcloud`](https://cloud.google.com/artifact-registry/docs/docker/authentication#gcloud-helper) +before `docker push`.
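Once the image is pushed, a pipeline can reference it through the `sdk_container_image` pipeline option. The snippet below is an illustrative sketch rather than part of the original guide; the project, region, bucket, and image path are placeholder values you would substitute with your own.

```python
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Placeholder values -- replace the project, region, bucket, and image path
# with your own before running.
options = PipelineOptions([
    "--runner=DataflowRunner",
    "--project=<project>",
    "--region=us-central1",
    "--temp_location=gs://<bucket>/tmp",
    "--sdk_container_image=us-central1-docker.pkg.dev/<project>/<repository>/beam_python3.10_sdk:2.60.0-custom",
])

# A trivial pipeline that runs its workers on the custom SDK container image.
with beam.Pipeline(options=options) as pipeline:
    (pipeline
     | beam.Create([1, 2, 3])
     | beam.Map(lambda x: x * x)
     | beam.Map(print))
```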
#### Docker Hub + +You can push to your Docker Hub repository +after [docker login](https://docs.docker.com/reference/cli/docker/login/). + +```shell +docker tag apache/beam_python3.10_sdk:2.60.0 <docker-hub-user>/beam_python3.10_sdk:2.60.0-custom +docker push <docker-hub-user>/beam_python3.10_sdk:2.60.0-custom +``` + diff --git a/website/www/site/content/en/documentation/sdks/yaml-errors.md b/website/www/site/content/en/documentation/sdks/yaml-errors.md index 8c0d9f06ade3..903e18d6b3c7 100644 --- a/website/www/site/content/en/documentation/sdks/yaml-errors.md +++ b/website/www/site/content/en/documentation/sdks/yaml-errors.md @@ -37,7 +37,8 @@ The `output` parameter is a name that must be referenced as an input to another transform that will process the errors (e.g. by writing them out). For example, the following code will write all "good" processed records to one file and -any "bad" records to a separate file. +any "bad" records, along with metadata about what error was encountered, +to a separate file. ``` pipeline: @@ -77,6 +78,8 @@ for a robust pipeline). Note also that the exact format of the error outputs is still being finalized. They can be safely printed and written to outputs, but their precise schema may change in a future version of Beam and should not yet be depended on. +Currently each error record contains, at the very least, an `element` field which holds +the element that caused the error. Some transforms allow for extra arguments in their error_handling config, e.g. for Python functions one can give a `threshold` which limits the relative number diff --git a/website/www/site/content/en/get-started/downloads.md b/website/www/site/content/en/get-started/downloads.md index 08614b8835c1..ff432996578d 100644 --- a/website/www/site/content/en/get-started/downloads.md +++ b/website/www/site/content/en/get-started/downloads.md @@ -96,10 +96,18 @@ versions denoted `0.x.y`. ## Releases +### 2.60.0 (2024-10-17) + +Official [source code download](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip). +[SHA-512](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip.sha512). +[signature](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip.asc). + +[Release notes](https://github.com/apache/beam/releases/tag/v2.60.0) + ### 2.59.0 (2024-09-11) -Official [source code download](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip). -[SHA-512](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip.sha512). -[signature](https://downloads.apache.org/beam/2.59.0/apache-beam-2.59.0-source-release.zip.asc). +Official [source code download](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip). +[SHA-512](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip.sha512). +[signature](https://archive.apache.org/dist/beam/2.59.0/apache-beam-2.59.0-source-release.zip.asc). [Release notes](https://github.com/apache/beam/releases/tag/v2.59.0) diff --git a/website/www/site/content/en/get-started/quickstart-py.md b/website/www/site/content/en/get-started/quickstart-py.md index 3428f5346e02..d7f896153483 100644 --- a/website/www/site/content/en/get-started/quickstart-py.md +++ b/website/www/site/content/en/get-started/quickstart-py.md @@ -23,7 +23,7 @@ If you're interested in contributing to the Apache Beam Python codebase, see the {{< toc >}} -The Python SDK supports Python 3.8, 3.9, 3.10 and 3.11. Beam 2.48.0 was the last release with support for Python 3.7. +The Python SDK supports Python 3.8, 3.9, 3.10, 3.11, and 3.12. Beam 2.48.0 was the last release with support for Python 3.7.
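If you are not sure which interpreter your environment will pick up, the following quick check (an illustrative snippet, not part of the official quickstart) confirms that the version falls within the supported range before you continue.

{{< highlight py >}}
import sys

# Apache Beam currently supports Python 3.8 through 3.12.
if not ((3, 8) <= sys.version_info[:2] <= (3, 12)):
    raise RuntimeError(f"Unsupported Python version: {sys.version}")
print(f"Python {sys.version_info.major}.{sys.version_info.minor} is supported by the Beam SDK.")
{{< /highlight >}}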
## Set up your environment diff --git a/website/www/site/data/authors.yml b/website/www/site/data/authors.yml index 53903354bd4d..f3645f7475bf 100644 --- a/website/www/site/data/authors.yml +++ b/website/www/site/data/authors.yml @@ -287,3 +287,6 @@ jkinard: jkim: name: Jaehyeon Kim email: dottami@gmail.com +ffernandez92: + name: Ferran Fernandez + email: ffernandez.upc@gmail.com diff --git a/website/www/site/data/capability_matrix.yaml b/website/www/site/data/capability_matrix.yaml index dcbbca438b6e..e6fd51a9bb17 100644 --- a/website/www/site/data/capability_matrix.yaml +++ b/website/www/site/data/capability_matrix.yaml @@ -393,7 +393,7 @@ capability-matrix: - class: dataflow l1: "Partially" l2: non-merging windows - l3: State is supported for non-merging windows. SetState and MapState are not yet supported. + l3: "State is supported for non-merging windows. The MapState, SetState, and MultimapState state types are supported in the following scenarios: Java pipelines that don't use Streaming Engine; Java pipelines that use Streaming Engine and version 2.58.0 or later of the Java SDK. SetState, MapState, and MultimapState are not supported for pipelines that use Runner v2." - class: flink l1: "Partially" l2: non-merging windows diff --git a/website/www/site/layouts/partials/section-menu/en/sdks.html b/website/www/site/layouts/partials/section-menu/en/sdks.html index ea48eb6f40d9..243bbd92a465 100644 --- a/website/www/site/layouts/partials/section-menu/en/sdks.html +++ b/website/www/site/layouts/partials/section-menu/en/sdks.html @@ -44,6 +44,7 @@
  Managing pipeline dependencies
  Python multi-language pipelines quickstart
  Python Unrecoverable Errors
+ Python SDK image build
diff --git a/website/www/site/static/images/logos/powered-by/behalf.png b/website/www/site/static/images/logos/powered-by/behalf.png deleted file mode 100644 index 346ec880d764..000000000000 Binary files a/website/www/site/static/images/logos/powered-by/behalf.png and /dev/null differ