diff --git a/.github/trigger_files/IO_Iceberg_Integration_Tests.json b/.github/trigger_files/IO_Iceberg_Integration_Tests.json index bbdc3a3910ef..3f63c0c9975f 100644 --- a/.github/trigger_files/IO_Iceberg_Integration_Tests.json +++ b/.github/trigger_files/IO_Iceberg_Integration_Tests.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 3 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java.json b/.github/trigger_files/beam_PostCommit_Java.json index 9e26dfeeb6e6..920c8d132e4a 100644 --- a/.github/trigger_files/beam_PostCommit_Java.json +++ b/.github/trigger_files/beam_PostCommit_Java.json @@ -1 +1,4 @@ -{} \ No newline at end of file +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 +} \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json index 1efc8e9e4405..3f63c0c9975f 100644 --- a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json +++ b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java_Nexmark_Dataflow.json b/.github/trigger_files/beam_PostCommit_Java_Nexmark_Dataflow.json new file mode 100644 index 000000000000..0967ef424bce --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Nexmark_Dataflow.json @@ -0,0 +1 @@ +{} diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json new file mode 100644 index 000000000000..b26833333238 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json @@ -0,0 +1,4 @@ +{ + "comment": 
"Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PostCommit_Python_Dependency.json b/.github/trigger_files/beam_PostCommit_Python_Dependency.json index e69de29bb2d1..a7fc54b3e4bb 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Dependency.json +++ b/.github/trigger_files/beam_PostCommit_Python_Dependency.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 + } \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json new file mode 100644 index 000000000000..3f63c0c9975f --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/trigger_files/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.json b/.github/trigger_files/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.json new file mode 100644 index 000000000000..9e26dfeeb6e6 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_PythonUsingJavaSQL_Dataflow.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/.github/trigger_files/beam_PostCommit_XVR_Samza.json b/.github/trigger_files/beam_PostCommit_XVR_Samza.json new file mode 100644 index 000000000000..9e26dfeeb6e6 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Samza.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/.github/trigger_files/beam_PreCommit_Flink_Container.json b/.github/trigger_files/beam_PreCommit_Flink_Container.json new file mode 100644 index 000000000000..3f63c0c9975f --- /dev/null +++ b/.github/trigger_files/beam_PreCommit_Flink_Container.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file 
in a trivial way to cause this test suite to run", + "modification": 2 +} diff --git a/.github/workflows/IO_Iceberg_Integration_Tests.yml b/.github/workflows/IO_Iceberg_Integration_Tests.yml index 22b2b4f9287d..68a72790006f 100644 --- a/.github/workflows/IO_Iceberg_Integration_Tests.yml +++ b/.github/workflows/IO_Iceberg_Integration_Tests.yml @@ -75,4 +75,4 @@ jobs: - name: Run IcebergIO Integration Test uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :sdks:java:io:iceberg:catalogTests \ No newline at end of file + gradle-command: :sdks:java:io:iceberg:catalogTests --info \ No newline at end of file diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 971bfd857b27..206364f416f7 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -331,7 +331,6 @@ PostCommit Jobs run in a schedule against master branch and generally do not get | [ PostCommit Java SingleStoreIO IT ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml) | N/A |`beam_PostCommit_Java_SingleStoreIO_IT.json`| [![.github/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_SingleStoreIO_IT.yml?query=event%3Aschedule) | | [ PostCommit Java PVR Spark3 Streaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml) | N/A |`beam_PostCommit_Java_PVR_Spark3_Streaming.json`| [![.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml?query=event%3Aschedule) | | [ PostCommit Java PVR Spark Batch 
](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml) | N/A |`beam_PostCommit_Java_PVR_Spark_Batch.json`| [![.github/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_PVR_Spark_Batch.yml?query=event%3Aschedule) | -| [ PostCommit Java Sickbay ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Sickbay.yml) | N/A |`beam_PostCommit_Java_Sickbay.json`| [![.github/workflows/beam_PostCommit_Java_Sickbay.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Sickbay.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Sickbay.yml?query=event%3Aschedule) | | [ PostCommit Java Tpcds Dataflow ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml) | N/A |`beam_PostCommit_Java_Tpcds_Dataflow.json`| [![.github/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Dataflow.yml?query=event%3Aschedule) | | [ PostCommit Java Tpcds Flink ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Flink.yml) | N/A |`beam_PostCommit_Java_Tpcds_Flink.json`| [![.github/workflows/beam_PostCommit_Java_Tpcds_Flink.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Flink.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Flink.yml?query=event%3Aschedule) | | [ PostCommit Java Tpcds Spark ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Spark.yml) | N/A |`beam_PostCommit_Java_Tpcds_Spark.json`| 
[![.github/workflows/beam_PostCommit_Java_Tpcds_Spark.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Spark.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_Tpcds_Spark.yml?query=event%3Aschedule) | @@ -372,7 +371,6 @@ PostCommit Jobs run in a schedule against master branch and generally do not get | [ PostCommit Python Xlang Gcp Dataflow ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml) | N/A |`beam_PostCommit_Python_Xlang_Gcp_Dataflow.json`| [![.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Dataflow.yml?query=event%3Aschedule) | | [ PostCommit Python Xlang Gcp Direct ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml) | N/A |`beam_PostCommit_Python_Xlang_Gcp_Direct.json`| [![.github/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_Gcp_Direct.yml?query=event%3Aschedule) | | [ PostCommit Python Xlang IO Dataflow ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml) | N/A |`beam_PostCommit_Python_Xlang_IO_Dataflow.json`| [![.github/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Python_Xlang_IO_Dataflow.yml?query=event%3Aschedule) | -| [ PostCommit Sickbay Python 
](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Sickbay_Python.yml) | ['3.8','3.9','3.10','3.11'] |`beam_PostCommit_Sickbay_Python.json`| [![.github/workflows/beam_PostCommit_Sickbay_Python.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Sickbay_Python.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Sickbay_Python.yml?query=event%3Aschedule) | | [ PostCommit SQL ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_SQL.yml) | N/A |`beam_PostCommit_SQL.json`| [![.github/workflows/beam_PostCommit_SQL.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_SQL.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_SQL.yml?query=event%3Aschedule) | | [ PostCommit TransformService Direct ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_TransformService_Direct.yml) | N/A |`beam_PostCommit_TransformService_Direct.json`| [![.github/workflows/beam_PostCommit_TransformService_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_TransformService_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_TransformService_Direct.yml?query=event%3Aschedule) | [ PostCommit Website Test](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Website_Test.yml) | N/A |`beam_PostCommit_Website_Test.json`| [![.github/workflows/beam_PostCommit_Website_Test.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Website_Test.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Website_Test.yml?query=event%3Aschedule) | diff --git a/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml b/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml index baf950589c8e..243e9d32c066 100644 --- a/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml +++ 
b/.github/workflows/beam_LoadTests_Python_Combine_Flink_Streaming.yml @@ -65,7 +65,7 @@ jobs: (github.event_name == 'schedule' && github.repository == 'apache/beam') || github.event.comment.body == 'Run Load Tests Python Combine Flink Streaming' runs-on: [self-hosted, ubuntu-20.04, main] - timeout-minutes: 720 + timeout-minutes: 80 name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) strategy: matrix: @@ -89,17 +89,22 @@ jobs: test-type: load test-language: python argument-file-paths: | - ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt - ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt + ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt + ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt + # large loads do not work now + # ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt + # ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt - name: Start Flink with parallelism 16 env: FLINK_NUM_WORKERS: 16 + HIGH_MEM_MACHINE: n1-highmem-16 + HIGH_MEM_FLINK_PROPS: flink:taskmanager.memory.process.size=16g,flink:taskmanager.memory.flink.size=12g,flink:taskmanager.memory.jvm-overhead.max=4g,flink:jobmanager.memory.process.size=6g,flink:jobmanager.memory.jvm-overhead.max=2g,flink:jobmanager.memory.flink.size=4g run: | cd ${{ github.workspace }}/.test-infra/dataproc; ./flink_cluster.sh create - name: get current time run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV - # The env variables are created and populated in the test-arguments-action as "_test_arguments_" - - name: run Load test 2GB Fanout 4 + # The env variables are created and populated in the 
test-arguments-action as "_test_arguments_" + - name: run Load test small Fanout 1 uses: ./.github/actions/gradle-command-self-hosted-action with: gradle-command: :sdks:python:apache_beam:testing:load_tests:run @@ -108,7 +113,7 @@ jobs: -PloadTest.mainClass=apache_beam.testing.load_tests.combine_test \ -Prunner=PortableRunner \ '-PloadTest.args=${{ env.beam_LoadTests_Python_Combine_Flink_Streaming_test_arguments_1 }} --job_name=load-tests-python-flink-streaming-combine-4-${{env.NOW_UTC}}' \ - - name: run Load test 2GB Fanout 8 + - name: run Load test small Fanout 2 uses: ./.github/actions/gradle-command-self-hosted-action with: gradle-command: :sdks:python:apache_beam:testing:load_tests:run @@ -123,4 +128,4 @@ jobs: ${{ github.workspace }}/.test-infra/dataproc/flink_cluster.sh delete # // TODO(https://github.com/apache/beam/issues/20402). Skipping some cases because they are too slow: - # load-tests-python-flink-streaming-combine-1' \ No newline at end of file + # load-tests-python-flink-streaming-combine-1' diff --git a/.github/workflows/beam_PostCommit_Java_Sickbay.yml b/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml similarity index 64% rename from .github/workflows/beam_PostCommit_Java_Sickbay.yml rename to .github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml index 95c36fc863cf..4fb236c7c991 100644 --- a/.github/workflows/beam_PostCommit_Java_Sickbay.yml +++ b/.github/workflows/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.yml @@ -15,13 +15,17 @@ # specific language governing permissions and limitations # under the License. 
-name: PostCommit Java Sickbay +name: PostCommit Java ValidatesDistrolessContainer Dataflow on: schedule: - - cron: '30 4/6 * * *' + - cron: '30 6/8 * * *' pull_request_target: - paths: ['.github/trigger_files/beam_PostCommit_Java_Sickbay.json'] + paths: + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json' + - '.github/trigger_files/beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow.json' + workflow_dispatch: # This allows a subsequently queued workflow run to interrupt previous runs @@ -51,19 +55,19 @@ env: GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - beam_PostCommit_Java_Sickbay: + beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow: name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) runs-on: [self-hosted, ubuntu-20.04, main] - timeout-minutes: 120 + timeout-minutes: 390 strategy: - matrix: - job_name: [beam_PostCommit_Java_Sickbay] - job_phrase: [Run Java Sickbay] + matrix: + job_name: [beam_PostCommit_Java_ValidatesDistrolessContainer_Dataflow] + job_phrase: [Run Java Dataflow ValidatesDistrolessContainer] if: | github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request_target' || (github.event_name == 'schedule' && github.repository == 'apache/beam') || - github.event.comment.body == 'Run Java Sickbay' + github.event.comment.body == 'Run Java Dataflow ValidatesDistrolessContainer' steps: - uses: actions/checkout@v4 - name: Setup repository @@ -74,10 +78,28 @@ jobs: github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action - - name: run PostCommit Java Sickbay script + with: + java-version: | + 17 + 21 + - name: Setup docker + run: | + gcloud auth configure-docker us-docker.pkg.dev --quiet + gcloud auth configure-docker us.gcr.io --quiet + gcloud auth configure-docker gcr.io --quiet + gcloud auth configure-docker us-central1-docker.pkg.dev --quiet + 
- name: run validatesDistrolessContainer script (Java 17) + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :runners:google-cloud-dataflow-java:examplesJavaRunnerV2IntegrationTestDistroless + arguments: '-PtestJavaVersion=java17 -PdockerTag=$(date +%s)' + max-workers: 12 + - name: run validatesDistrolessContainer script (Java 21) uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: :javaPostCommitSickbay + gradle-command: :runners:google-cloud-dataflow-java:examplesJavaRunnerV2IntegrationTestDistroless + arguments: '-PtestJavaVersion=java21 -PdockerTag=$(date +%s)' + max-workers: 12 - name: Archive JUnit Test Results uses: actions/upload-artifact@v4 if: ${{ !success() }} @@ -90,4 +112,4 @@ jobs: with: commit: '${{ env.prsha || env.GITHUB_SHA }}' comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} - files: '**/build/test-results/**/*.xml' \ No newline at end of file + files: '**/build/test-results/**/*.xml' diff --git a/.github/workflows/beam_PostCommit_Sickbay_Python.yml b/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml similarity index 52% rename from .github/workflows/beam_PostCommit_Sickbay_Python.yml rename to .github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml index 6d253e03723d..6f8a7bdd0631 100644 --- a/.github/workflows/beam_PostCommit_Sickbay_Python.yml +++ b/.github/workflows/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.yml @@ -1,31 +1,32 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -name: PostCommit Sickbay Python +name: PostCommit Python ValidatesDistrolessContainer Dataflow on: + schedule: + - cron: '15 5/6 * * *' pull_request_target: - paths: ['.github/trigger_files/beam_PostCommit_Sickbay_Python.json'] + paths: + - 'release/trigger_all_tests.json' + # Since distroless is based on original sdk container images, we want to also trigger distroless checks here. 
+ - '.github/trigger_files/beam_PostCommit_Python_ValidatesContainer_Dataflow.json' + - '.github/trigger_files/beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow.json' workflow_dispatch: - -# This allows a subsequently queued workflow run to interrupt previous runs -concurrency: - group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' - cancel-in-progress: true + issue_comment: + types: [created] #Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event permissions: @@ -43,51 +44,65 @@ permissions: security-events: read statuses: read +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + env: DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - beam_PostCommit_Sickbay_Python: - name: ${{ matrix.job_name }} (${{ matrix.job_phrase_1 }} ${{ matrix.python_version }} ${{ matrix.job_phrase_2 }}) + beam_PostCommit_Python_ValidatesContainer_Dataflow: + if: | + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + startsWith(github.event.comment.body, 'Run Python Dataflow ValidatesDistrolessContainer') runs-on: [self-hosted, ubuntu-20.04, main] - timeout-minutes: 180 + timeout-minutes: 100 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) strategy: fail-fast: false matrix: - job_name: 
[beam_PostCommit_Sickbay_Python] - job_phrase_1: [Run Python] - job_phrase_2: [PostCommit Sickbay] - python_version: ['3.9', '3.10', '3.11', '3.12'] - if: | - github.event_name == 'workflow_dispatch' || - github.event_name == 'pull_request_target' || - (github.event_name == 'schedule' && github.repository == 'apache/beam') || - (startswith(github.event.comment.body, 'Run Python') && - endswith(github.event.comment.body, 'PostCommit Sickbay')) + job_name: ["beam_PostCommit_Python_ValidatesDistrolessContainer_Dataflow"] + job_phrase: ["Run Python Dataflow ValidatesDistrolessContainer"] + python_version: ['3.9','3.10','3.11','3.12'] steps: - uses: actions/checkout@v4 - name: Setup repository uses: ./.github/actions/setup-action with: - comment_phrase: ${{ matrix.job_phrase_1 }} ${{ matrix.python_version }} ${{ matrix.job_phrase_2 }} + comment_phrase: ${{ matrix.job_phrase }} ${{ matrix.python_version }} github_token: ${{ secrets.GITHUB_TOKEN }} - github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase_1 }} ${{ matrix.python_version }} ${{ matrix.job_phrase_2 }}) + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) - name: Setup environment uses: ./.github/actions/setup-environment-action with: + java-version: | + 11 + 8 python-version: ${{ matrix.python_version }} + - name: Setup docker + run: | + gcloud auth configure-docker us-docker.pkg.dev --quiet + gcloud auth configure-docker us.gcr.io --quiet + gcloud auth configure-docker gcr.io --quiet + gcloud auth configure-docker us-central1-docker.pkg.dev --quiet - name: Set PY_VER_CLEAN id: set_py_ver_clean run: | PY_VER=${{ matrix.python_version }} PY_VER_CLEAN=${PY_VER//.} echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT - - name: run PostCommit Python ${{ matrix.python_version }} script + - name: Run validatesDistrolessContainer script + env: + USER: github-actions uses: ./.github/actions/gradle-command-self-hosted-action with: - gradle-command: 
:sdks:python:test-suites:dataflow:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:postCommitSickbay + gradle-command: :sdks:python:test-suites:dataflow:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:validatesDistrolessContainer arguments: | -PpythonVersion=${{ matrix.python_version }} \ - name: Archive Python Test Results diff --git a/.github/workflows/beam_PreCommit_Flink_Container.yml b/.github/workflows/beam_PreCommit_Flink_Container.yml new file mode 100644 index 000000000000..519b0273420a --- /dev/null +++ b/.github/workflows/beam_PreCommit_Flink_Container.yml @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: PreCommit Flink Container + +on: + pull_request_target: + paths: + - 'model/**' + - 'sdks/python/**' + - 'release/**' + - 'sdks/java/io/kafka/**' + - 'runners/core-construction-java/**' + - 'runners/core-java/**' + - 'runners/extensions-java/**' + - 'runners/flink/**' + - 'runners/java-fn-execution/**' + - 'runners/reference/**' + - '.github/trigger_files/beam_PreCommit_Flink_Container.json' + - 'release/trigger_all_tests.json' + push: + branches: ['master', 'release-*'] + tags: 'v*' + schedule: + - cron: '0 */6 * * *' + workflow_dispatch: + +# Setting explicit permissions for the action to avoid the default permissions which are `write-all` +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.sha || github.head_ref || github.ref }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + INFLUXDB_USER: ${{ secrets.INFLUXDB_USER }} + INFLUXDB_USER_PASSWORD: ${{ secrets.INFLUXDB_USER_PASSWORD }} + GCLOUD_ZONE: us-central1-a + CLUSTER_NAME: beam-precommit-flink-container-${{ github.run_id }} + GCS_BUCKET: gs://beam-flink-cluster + FLINK_DOWNLOAD_URL: https://archive.apache.org/dist/flink/flink-1.17.0/flink-1.17.0-bin-scala_2.12.tgz + HADOOP_DOWNLOAD_URL: https://repo.maven.apache.org/maven2/org/apache/flink/flink-shaded-hadoop-2-uber/2.8.3-10.0/flink-shaded-hadoop-2-uber-2.8.3-10.0.jar + FLINK_TASKMANAGER_SLOTS: 1 + DETACHED_MODE: true + HARNESS_IMAGES_TO_PULL: gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest + 
JOB_SERVER_IMAGE: gcr.io/apache-beam-testing/beam_portability/beam_flink1.17_job_server:latest + ARTIFACTS_DIR: gs://beam-flink-cluster/beam-precommit-flink-container-${{ github.run_id }} + +jobs: + beam_PreCommit_Flink_Container: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'push' || + github.event_name == 'schedule' || + github.event_name == 'pull_request_target' || + github.event.comment.body == 'Run Flink Container PreCommit' + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 45 + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_PreCommit_Flink_Container"] + job_phrase: ["Run Flink Container PreCommit"] + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + python-version: default + - name: Prepare test arguments + uses: ./.github/actions/test-arguments-action + with: + test-type: precommit + test-language: go,python,java + argument-file-paths: | + ${{ github.workspace }}/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt + ${{ github.workspace }}/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt + ${{ github.workspace }}/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt + - name: get current time + run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV + - name: Start Flink with 2 workers + env: + FLINK_NUM_WORKERS: 2 + run: | + cd ${{ github.workspace }}/.test-infra/dataproc; ./flink_cluster.sh create + # Run a simple Go Combine load test to verify the Flink container + - name: Run Flink Container Test with Go Combine + timeout-minutes: 10 + uses: 
./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:go:test:load:run + arguments: | + -PloadTest.mainClass=combine \ + -Prunner=PortableRunner \ + '-PloadTest.args=${{ env.beam_PreCommit_Flink_Container_test_arguments_1 }} --job_name=flink-tests-go-${{env.NOW_UTC}}' + + # Run a Python Combine load test to verify the Flink container + - name: Run Flink Container Test with Python Combine + timeout-minutes: 20 + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:apache_beam:testing:load_tests:run + arguments: | + -PloadTest.mainClass=apache_beam.testing.load_tests.combine_test \ + -Prunner=FlinkRunner \ + '-PloadTest.args=${{ env.beam_PreCommit_Flink_Container_test_arguments_2 }} --job_name=flink-tests-python-${{env.NOW_UTC}}' + + # Run a Java Combine load test to verify the Flink container + - name: Run Flink Container Test with Java Combine + timeout-minutes: 10 + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:java:testing:load-tests:run + arguments: | + -PloadTest.mainClass=org.apache.beam.sdk.loadtests.CombineLoadTest \ + -Prunner=:runners:flink:1.17 \ + '-PloadTest.args=${{ env.beam_PreCommit_Flink_Container_test_arguments_3 }} --jobName=flink-tests-java11-${{env.NOW_UTC}}' + + - name: Teardown Flink + if: always() + run: | + ${{ github.workspace }}/.test-infra/dataproc/flink_cluster.sh delete diff --git a/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt b/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt new file mode 100644 index 000000000000..6b44f53886b2 --- /dev/null +++ b/.github/workflows/flink-tests-pipeline-options/go_Combine_Flink_Batch_small.txt @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--input_options=''{\"num_records\":200,\"key_size\":1,\"value_size\":9}'' +--fanout=1 +--top_count=10 +--parallelism=2 +--endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_go_sdk:latest +--runner=FlinkRunner \ No newline at end of file diff --git a/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt b/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt new file mode 100644 index 000000000000..e792682bfbc4 --- /dev/null +++ b/.github/workflows/flink-tests-pipeline-options/java_Combine_Flink_Batch_small.txt @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +--sourceOptions={"numRecords":200,"keySizeBytes":1,"valueSizeBytes":9} +--fanout=1 +--iterations=1 +--topCount=10 +--parallelism=2 +--jobEndpoint=localhost:8099 +--defaultEnvironmentType=DOCKER +--defaultEnvironmentConfig=gcr.io/apache-beam-testing/beam-sdk/beam_java11_sdk:latest +--runner=FlinkRunner \ No newline at end of file diff --git a/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt b/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt new file mode 100644 index 000000000000..5522a8f9b823 --- /dev/null +++ b/.github/workflows/flink-tests-pipeline-options/python_Combine_Flink_Batch_small.txt @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +--input_options=''{\\"num_records\\":200,\\"key_size\\":1,\\"value_size\\":9,\\"algorithm\\":\\"lcg\\"}'' +--parallelism=2 +--job_endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_python3.9_sdk:latest +--top_count=10 +--runner=PortableRunner \ No newline at end of file diff --git a/.github/workflows/go_tests.yml b/.github/workflows/go_tests.yml index e85c4eba866b..5ae3609ed997 100644 --- a/.github/workflows/go_tests.yml +++ b/.github/workflows/go_tests.yml @@ -50,7 +50,7 @@ jobs: - name: Delete old coverage run: "cd sdks && rm -rf .coverage.txt || :" - name: Run coverage - run: cd sdks && go test -coverprofile=coverage.txt -covermode=atomic ./go/pkg/... ./go/container/... ./java/container/... ./python/container/... ./typescript/container/... + run: cd sdks && go test -timeout=25m -coverprofile=coverage.txt -covermode=atomic ./go/pkg/... ./go/container/... ./java/container/... ./python/container/... ./typescript/container/... 
- uses: codecov/codecov-action@v3 with: flags: go diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt index 650236a9c500..6280e01dccdb 100644 --- a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_4.txt @@ -27,4 +27,5 @@ --top_count=20 --streaming --use_stateful_load_generator ---runner=PortableRunner \ No newline at end of file +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt index 4208571fef62..e1b77d15b95b 100644 --- a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_2GB_Fanout_8.txt @@ -27,4 +27,5 @@ --top_count=20 --streaming --use_stateful_load_generator ---runner=PortableRunner \ No newline at end of file +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt new file mode 100644 index 000000000000..f16e9e4b06ef --- /dev/null +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_1.txt @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--publish_to_big_query=true +--metrics_dataset=load_test +--metrics_table=python_flink_streaming_combine_4 +--influx_measurement=python_streaming_combine_4 +--input_options=''{\\"num_records\\":200000,\\"key_size\\":10,\\"value_size\\":90,\\"algorithm\\":\\"lcg\\"}'' +--parallelism=16 +--job_endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_python3.9_sdk:latest +--fanout=1 +--top_count=20 +--streaming +--use_stateful_load_generator +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt new file mode 100644 index 000000000000..5f66e519c31a --- /dev/null +++ b/.github/workflows/load-tests-pipeline-options/python_Combine_Flink_Streaming_small_Fanout_2.txt @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--publish_to_big_query=true +--metrics_dataset=load_test +--metrics_table=python_flink_streaming_combine_5 +--influx_measurement=python_streaming_combine_5 +--input_options=''{\\"num_records\\":200000,\\"key_size\\":10,\\"value_size\\":90,\\"algorithm\\":\\"lcg\\"}'' +--parallelism=16 +--job_endpoint=localhost:8099 +--environment_type=DOCKER +--environment_config=gcr.io/apache-beam-testing/beam-sdk/beam_python3.9_sdk:latest +--fanout=2 +--top_count=20 +--streaming +--use_stateful_load_generator +--runner=PortableRunner +--max_cache_memory_usage_mb=256 \ No newline at end of file diff --git a/.github/workflows/republish_released_docker_containers.yml b/.github/workflows/republish_released_docker_containers.yml index d359ba2c4489..ed6e74ecf13d 100644 --- a/.github/workflows/republish_released_docker_containers.yml +++ b/.github/workflows/republish_released_docker_containers.yml @@ -34,8 +34,8 @@ on: - cron: "0 6 * * 1" env: docker_registry: gcr.io - release: ${{ github.event.inputs.RELEASE || "2.60.0" }} - rc: ${{ github.event.inputs.RC || "2" }} + release: ${{ github.event.inputs.RELEASE || "2.61.0" }} + rc: ${{ github.event.inputs.RC || "3" }} jobs: diff --git a/.test-infra/dataproc/flink_cluster.sh b/.test-infra/dataproc/flink_cluster.sh index 759d7a6fcc38..4a97850f5ac1 100755 --- a/.test-infra/dataproc/flink_cluster.sh +++ b/.test-infra/dataproc/flink_cluster.sh @@ -129,13 +129,26 @@ function create_cluster() { local image_version=$DATAPROC_VERSION echo "Starting dataproc cluster. 
Dataproc version: $image_version" - # Docker init action restarts yarn so we need to start yarn session after this restart happens. - # This is why flink init action is invoked last. - # TODO(11/11/2022) remove --worker-machine-type and --master-machine-type once N2 CPUs quota relaxed - # Dataproc 2.1 uses n2-standard-2 by default but there is N2 CPUs=24 quota limit - gcloud dataproc clusters create $CLUSTER_NAME --region=$GCLOUD_REGION --num-workers=$FLINK_NUM_WORKERS --public-ip-address \ - --master-machine-type=n1-standard-2 --worker-machine-type=n1-standard-2 --metadata "${metadata}", \ - --image-version=$image_version --zone=$GCLOUD_ZONE --optional-components=FLINK,DOCKER --quiet + local worker_machine_type="n1-standard-2" # Default worker type + local master_machine_type="n1-standard-2" # Default master type + + if [[ -n "${HIGH_MEM_MACHINE:=}" ]]; then + worker_machine_type="${HIGH_MEM_MACHINE}" + master_machine_type="${HIGH_MEM_MACHINE}" + + gcloud dataproc clusters create $CLUSTER_NAME --enable-component-gateway --region=$GCLOUD_REGION --num-workers=$FLINK_NUM_WORKERS --public-ip-address \ + --master-machine-type=${master_machine_type} --worker-machine-type=${worker_machine_type} --metadata "${metadata}", \ + --image-version=$image_version --zone=$GCLOUD_ZONE --optional-components=FLINK,DOCKER \ + --properties="${HIGH_MEM_FLINK_PROPS}" + else + # Docker init action restarts yarn so we need to start yarn session after this restart happens. + # This is why flink init action is invoked last. 
+ # TODO(11/22/2024) remove --worker-machine-type and --master-machine-type once N2 CPUs quota relaxed + # Dataproc 2.1 uses n2-standard-2 by default but there is N2 CPUs=24 quota limit for this project + gcloud dataproc clusters create $CLUSTER_NAME --enable-component-gateway --region=$GCLOUD_REGION --num-workers=$FLINK_NUM_WORKERS --public-ip-address \ + --master-machine-type=${master_machine_type} --worker-machine-type=${worker_machine_type} --metadata "${metadata}", \ + --image-version=$image_version --zone=$GCLOUD_ZONE --optional-components=FLINK,DOCKER --quiet + fi } # Runs init actions for Docker, Portability framework (Beam) and Flink cluster diff --git a/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml b/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml index 722cc4e8f29e..eccaaa5f3b17 100644 --- a/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml +++ b/.test-infra/metrics/sync/github/github_runs_prefetcher/code/config.yaml @@ -124,7 +124,6 @@ categories: - "PostCommit Java PVR Samza" - "PreCommit Java Tika IO Direct" - "PostCommit Java SingleStoreIO IT" - - "PostCommit Java Sickbay" - "PostCommit Java ValidatesRunner Direct" - "PreCommit Java SingleStore IO Direct" - "PreCommit Java InfluxDb IO Direct" @@ -227,7 +226,6 @@ categories: - "PreCommit Python Transforms" - "Build python source distribution and wheels" - "Python tests" - - "PostCommit Sickbay Python" - "PreCommit Portable Python" - "PreCommit Python Coverage" - "PreCommit Python Docker" @@ -317,7 +315,6 @@ categories: - "PostCommit PortableJar Spark" - "PreCommit Integration and Load Test Framework" - "pr-bot-update-reviewers" - - "Cut Release Branch" - "Generate issue report" - "Dask Runner Tests" - "PreCommit Typescript" @@ -328,6 +325,12 @@ categories: - "Assign Milestone on issue close" - "Local environment tests" - "PreCommit SQL" - - "LabelPrs" + - "LabelPrs" + - name: safe_to_ignore + groupThreshold: 0 + tests: - 
"build_release_candidate" + - "Cut Release Branch" + - "PostCommit Java Sickbay" + - "PostCommit Sickbay Python" diff --git a/CHANGES.md b/CHANGES.md index 979cbbd67329..fc32398a7a5a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -62,6 +62,7 @@ ## I/Os +* gcs-connector config options can be set via GcsOptions (Java) ([#32769](https://github.com/apache/beam/pull/32769)). * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## New Features / Improvements @@ -70,6 +71,7 @@ ## Breaking Changes +* Upgraded ZetaSQL to 2024.11.1 ([#32902](https://github.com/apache/beam/pull/32902)). Java11+ is now needed if Beam's ZetaSQL component is used. * X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). ## Deprecations @@ -88,18 +90,15 @@ * ([#X](https://github.com/apache/beam/issues/X)). -# [2.61.0] - Unreleased +# [2.61.0] - 2024-11-25 ## Highlights -* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). -* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). * [Python] Introduce Managed Transforms API ([#31495](https://github.com/apache/beam/pull/31495)) * Flink 1.19 support added ([#32648](https://github.com/apache/beam/pull/32648)) ## I/Os -* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * [Managed Iceberg] Support creating tables if needed ([#32686](https://github.com/apache/beam/pull/32686)) * [Managed Iceberg] Now available in Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) * [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) @@ -111,36 +110,23 @@ ## New Features / Improvements * Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195)) -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). 
-* Added support for processing events which use a global sequence to "ordered" extension (Java) [#32540](https://github.com/apache/beam/pull/32540) +* Added support for processing events which use a global sequence to "ordered" extension (Java) ([#32540](https://github.com/apache/beam/pull/32540)) * Add new meta-transform FlattenWith and Tee that allow one to introduce branching without breaking the linear/chaining style of pipeline construction. - -## Breaking Changes - -* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). +* Use Prism as a fallback to the Python Portable runner when running a pipeline with the Python Direct runner ([#32876](https://github.com/apache/beam/pull/32876)) ## Deprecations * Removed support for Flink 1.15 and 1.16 * Removed support for Python 3.8 -* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). ## Bugfixes -* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * (Java) Fixed tearDown not invoked when DoFn throws on Portable Runners ([#18592](https://github.com/apache/beam/issues/18592), [#31381](https://github.com/apache/beam/issues/31381)). * (Java) Fixed protobuf error with MapState.remove() in Dataflow Streaming Java Legacy Runner without Streaming Engine ([#32892](https://github.com/apache/beam/issues/32892)). * Adding flag to support conditionally disabling auto-commit in JdbcIO ReadFn ([#31111](https://github.com/apache/beam/issues/31111)) * (Python) Fixed BigQuery Enrichment bug that can lead to multiple conditions returning duplicate rows, batching returning incorrect results and conditions not scoped by row during batching ([#32780](https://github.com/apache/beam/pull/32780)). -## Security Fixes -* Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). - -## Known Issues - -* ([#X](https://github.com/apache/beam/issues/X)). 
- # [2.60.0] - 2024-10-17 ## Highlights @@ -187,6 +173,7 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) * (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)). * (Java, Python, Go) Fixed PeriodicSequence backlog bytes reporting, which was preventing Dataflow Runner autoscaling from functioning properly ([#32506](https://github.com/apache/beam/issues/32506)). * (Java) Fix improper decoding of rows with schemas containing nullable fields when encoded with a schema with equal encoding positions but modified field order. ([#32388](https://github.com/apache/beam/issues/32388)). +* (Java) Skip close on bundles in BigtableIO.Read ([#32661](https://github.com/apache/beam/pull/32661), [#32759](https://github.com/apache/beam/pull/32759)). ## Known Issues diff --git a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb index 96717cfef60c..2f80846f313b 100644 --- a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb +++ b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb @@ -1,22 +1,21 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "OsFaZscKSPvo" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -36,22 +35,13 @@ "# KIND, either express or implied. 
See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "OsFaZscKSPvo" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "ZUSiAR62SgO8" + }, "source": [ "# Update ML models in running pipelines\n", "\n", @@ -63,20 +53,13 @@ " View source on GitHub\n", " \n", "\n" - ], - "metadata": { - "id": "ZUSiAR62SgO8" - }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "tBtqF5UpKJNZ" + }, "source": [ "This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline.\n", "You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a `ModelHandler` configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the `WatchFilePattern`, or by configuring a custom side input `PCollection` that defines the logic for the model update.\n", @@ -85,36 +68,19 @@ "For more information about side inputs, see the [Side inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs) section in the Apache Beam Programming Guide.\n", "\n", "This example uses `WatchFilePattern` as a side input. `WatchFilePattern` is used to watch for file updates that match the `file_pattern` based on timestamps. 
It emits the latest `ModelMetadata`, which is used in the RunInference `PTransform` to automatically update the ML model without stopping the Apache Beam pipeline.\n" - ], - "metadata": { - "id": "tBtqF5UpKJNZ" - }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "SPuXFowiTpWx" + }, "source": [ "## Before you begin\n", "Install the dependencies required to run this notebook.\n", "\n", "To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later." - ], - "metadata": { - "id": "SPuXFowiTpWx" - }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", @@ -122,23 +88,39 @@ "metadata": { "id": "1RyTYsFEIOlA" }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "!pip install apache_beam[gcp]>=2.46.0 tensorflow==2.15.0 tensorflow_hub==0.16.1 keras==2.15.0 Pillow==11.0.0 --quiet" ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Rs4cwwNrIV9H" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# Imports required for the notebook.\n", "import logging\n", "import time\n", + "import os\n", "from typing import Iterable\n", "from typing import Tuple\n", "\n", @@ -156,21 +138,23 @@ "import numpy\n", "from PIL import Image\n", "import tensorflow as tf" - ], - "metadata": { - "id": "Rs4cwwNrIV9H" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jAKpPcmmGm03" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# 
Authenticate to your Google Cloud account.\n", "def auth_to_colab():\n", @@ -178,21 +162,13 @@ " auth.authenticate_user()\n", "\n", "auth_to_colab()" - ], - "metadata": { - "id": "jAKpPcmmGm03" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "ORYNKhH3WQyP" + }, "source": [ "## Configure the runner\n", "\n", @@ -202,24 +178,37 @@ "* Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.\n", "\n", "In the following code, replace `BUCKET_NAME` with the the name of your Cloud Storage bucket." - ], - "metadata": { - "id": "ORYNKhH3WQyP" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wWjbnq6X-4uE" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "options = PipelineOptions()\n", "options.view_as(StandardOptions).streaming = True\n", "\n", - "BUCKET_NAME = '' # Replace with your bucket name.\n", + "# Replace with your bucket name.\n", + "BUCKET_NAME = '' # @param {type:'string'} \n", + "os.environ['BUCKET_NAME'] = BUCKET_NAME\n", "\n", "# Provide required pipeline options for the Dataflow Runner.\n", "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n", "\n", "# Set the project to the default project in your current Google Cloud environment.\n", - "options.view_as(GoogleCloudOptions).project = ''\n", + "PROJECT_NAME = '' # @param {type:'string'}\n", + "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n", "\n", "# Set the Google Cloud region that you want to run Dataflow in.\n", "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", @@ -244,113 +233,120 @@ "# To expedite the model update process, it's recommended to set num_workers>1.\n", "# https://github.com/apache/beam/issues/28776\n", "options.view_as(WorkerOptions).num_workers = 5" - ], - 
"metadata": { - "id": "wWjbnq6X-4uE" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", - "source": [ - "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies." - ], "metadata": { "id": "HTJV8pO2Wcw4" - } + }, + "source": [ + "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lEy4PkluWbdm" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# In a requirements file, define the dependencies required for the pipeline.\n", "!printf 'tensorflow==2.15.0\\ntensorflow_hub==0.16.1\\nkeras==2.15.0\\nPillow==11.0.0' > ./requirements.txt\n", "# Install the pipeline dependencies on Dataflow.\n", "options.view_as(SetupOptions).requirements_file = './requirements.txt'" - ], - "metadata": { - "id": "lEy4PkluWbdm" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "_AUNH_GJk_NE" + }, "source": [ "## Use the TensorFlow model handler\n", " This example uses `TFModelHandlerTensor` as the model handler and the `resnet_101` model trained on [ImageNet](https://www.image-net.org/).\n", "\n", "\n", "For the Dataflow runner, you need to store the model in a remote location that the Apache Beam pipeline can access. 
For this example, download the `ResNet101` model, and upload it to the Google Cloud Storage bucket.\n" - ], - "metadata": { - "id": "_AUNH_GJk_NE" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ibkWiwVNvyrn" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "model = tf.keras.applications.resnet.ResNet101()\n", "model.save('resnet101_weights_tf_dim_ordering_tf_kernels.keras')\n", "# After saving the model locally, upload the model to GCS bucket and provide that gcs bucket `URI` as `model_uri` to the `TFModelHandler`\n", - "# Replace `BUCKET_NAME` value with actual bucket name.\n", - "!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs:///dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras" - ], - "metadata": { - "id": "ibkWiwVNvyrn" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras" + ] }, { "cell_type": "code", - "source": [ - "model_handler = TFModelHandlerTensor(\n", - " model_uri=dataflow_gcs_location + \"/resnet101_weights_tf_dim_ordering_tf_kernels.keras\")" - ], + "execution_count": null, "metadata": { "id": "kkSnsxwUk-Sp" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model_handler = TFModelHandlerTensor(\n", + " model_uri=dataflow_gcs_location + \"/resnet101_weights_tf_dim_ordering_tf_kernels.keras\")" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "tZH0r0sL-if5" + }, "source": [ "## Preprocess images\n", "\n", "Use `preprocess_image` to run the inference, read the image, and convert the image to a TensorFlow tensor." 
- ], - "metadata": { - "id": "tZH0r0sL-if5" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dU5imgTt-8Ne" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "def preprocess_image(image_name, image_dir):\n", " img = tf.keras.utils.get_file(image_name, image_dir + image_name)\n", @@ -358,21 +354,23 @@ " img = numpy.array(img) / 255.0\n", " img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)\n", " return img_tensor" - ], - "metadata": { - "id": "dU5imgTt-8Ne" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6V5tJxO6-gyt" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "class PostProcessor(beam.DoFn):\n", " \"\"\"Process the PredictionResult to get the predicted label.\n", @@ -387,62 +385,66 @@ " imagenet_labels = numpy.array(open(labels_path).read().splitlines())\n", " predicted_class_name = imagenet_labels[predicted_class]\n", " yield predicted_class_name.title(), element.model_id" - ], - "metadata": { - "id": "6V5tJxO6-gyt" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "code", - "source": [ - "# Define the pipeline object.\n", - "pipeline = beam.Pipeline(options=options)" - ], + "execution_count": null, "metadata": { "id": "GpdKk72O_NXT" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Define the pipeline object.\n", + "pipeline = beam.Pipeline(options=options)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "elZ53uxc_9Hv" + }, 
"source": [ "Next, review the pipeline steps and examine the code.\n", "\n", "### Pipeline steps\n" - ], - "metadata": { - "id": "elZ53uxc_9Hv" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "305tkV2sAD-S" + }, "source": [ "1. Create a `PeriodicImpulse` transform, which emits output every `n` seconds. The `PeriodicImpulse` transform generates an infinite sequence of elements with a given runtime interval.\n", "\n", " In this example, `PeriodicImpulse` mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use `PeriodicImpulse` to output elements at `m` intervals.\n", "To learn more about `PeriodicImpulse`, see the [`PeriodicImpulse` code](https://github.com/apache/beam/blob/9c52e0594d6f0e59cd17ee005acfb41da508e0d5/sdks/python/apache_beam/transforms/periodicsequence.py#L150)." - ], - "metadata": { - "id": "305tkV2sAD-S" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vUFStz66_Tbb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "start_timestamp = time.time() # start timestamp of the periodic impulse\n", "end_timestamp = start_timestamp + 60 * 20 # end timestamp of the periodic impulse (will run for 20 minutes).\n", @@ -455,72 +457,76 @@ " start_timestamp=start_timestamp,\n", " stop_timestamp=end_timestamp,\n", " fire_interval=main_input_fire_interval))" - ], - "metadata": { - "id": "vUFStz66_Tbb" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "8-sal2rFAxP2" + }, "source": [ "2. To read and preprocess the images, use the `preprocess_image` function. This example uses `Cat-with-beanie.jpg` for all inferences.\n", "\n", " **Note**: The image used for prediction is licensed in CC-BY. 
The creator is listed in the [LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt) file." - ], - "metadata": { - "id": "8-sal2rFAxP2" - } + ] }, { "cell_type": "markdown", - "source": [ - "![download.png]()" - ], "metadata": { "id": "gW4cE8bhXS-d" - } + }, + "source": [ + "![download.png]()" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dGg11TpV_aV6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "image_data = (periodic_impulse | beam.Map(lambda x: \"Cat-with-beanie.jpg\")\n", " | \"ReadImage\" >> beam.Map(lambda image_name: preprocess_image(\n", " image_name=image_name, image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))" - ], - "metadata": { - "id": "dGg11TpV_aV6" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "eB0-ewd-BCKE" + }, "source": [ "3. Pass the images to the RunInference `PTransform`. RunInference takes `model_handler` and `model_metadata_pcoll` as input parameters.\n", " * `model_metadata_pcoll` is a side input `PCollection` to the RunInference `PTransform`. This side input updates the `model_uri` in the `model_handler` while the Apache Beam pipeline runs.\n", " * Use `WatchFilePattern` as side input to watch a `file_pattern` matching `.keras` files. 
In this case, the `file_pattern` is `'gs://BUCKET_NAME/dataflow/*keras'`.\n", "\n" - ], - "metadata": { - "id": "eB0-ewd-BCKE" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_AjvvexJ_hUq" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ " # The side input used to watch for the .keras file and update the model_uri of the TFModelHandlerTensor.\n", "file_pattern = dataflow_gcs_location + '/*.keras'\n", @@ -534,108 +540,117 @@ " | \"ApplyWindowing\" >> beam.WindowInto(beam.window.FixedWindows(10))\n", " | \"RunInference\" >> RunInference(model_handler=model_handler,\n", " model_metadata_pcoll=side_input_pcoll))" - ], - "metadata": { - "id": "_AjvvexJ_hUq" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "lTA4wRWNDVis" + }, "source": [ "4. Post-process the `PredictionResult` object.\n", "When the inference is complete, RunInference outputs a `PredictionResult` object that contains the fields `example`, `inference`, and `model_id`. The `model_id` field identifies the model used to run the inference. The `PostProcessor` returns the predicted label and the model ID used to run the inference on the predicted label." 
- ], - "metadata": { - "id": "lTA4wRWNDVis" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9TB76fo-_vZJ" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "post_processor = (\n", " inferences\n", " | \"PostProcessResults\" >> beam.ParDo(PostProcessor())\n", " | \"LogResults\" >> beam.Map(logging.info))" - ], - "metadata": { - "id": "9TB76fo-_vZJ" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "wYp-mBHHjOjA" + }, "source": [ "### Watch for the model update\n", "\n", "After the pipeline starts processing data, when you see output emitted from the RunInference `PTransform`, upload a `resnet152` model saved in the `.keras` format to a Google Cloud Storage bucket location that matches the `file_pattern` you defined earlier.\n" - ], - "metadata": { - "id": "wYp-mBHHjOjA" - } + ] }, { "cell_type": "code", - "source": [ - "model = tf.keras.applications.resnet.ResNet152()\n", - "model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')\n", - "# Replace the `BUCKET_NAME` with the actual bucket name.\n", - "!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs:///resnet152_weights_tf_dim_ordering_tf_kernels.keras" - ], + "execution_count": null, "metadata": { "id": "FpUfNBSWH9Xy" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model = tf.keras.applications.resnet.ResNet152()\n", + "model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')\n", + "!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs://${BUCKET_NAME}/resnet152_weights_tf_dim_ordering_tf_kernels.keras" + ] }, { "cell_type": "markdown", + "metadata": { + "id": 
"_ty03jDnKdKR" + }, "source": [ "## Run the pipeline\n", "\n", "Use the following code to run the pipeline." - ], - "metadata": { - "id": "_ty03jDnKdKR" - } + ] }, { "cell_type": "code", - "source": [ - "# Run the pipeline.\n", - "result = pipeline.run().wait_until_finish()" - ], + "execution_count": null, "metadata": { "id": "wd0VJLeLEWBU" }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Run the pipeline.\n", + "result = pipeline.run().wait_until_finish()" + ] } - ] + ], + "metadata": { + "colab": { + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb index 182b88b9c72a..dedaa6b65a5e 100644 --- a/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb @@ -1,22 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "55h6JBJeJGqg" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -36,16 +27,13 @@ "# KIND, either express or implied. 
See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "id": "55h6JBJeJGqg", - "cellView": "form" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "YrOuxMeKJZxC" + }, "source": [ "# Use Apache Beam and BigQuery to enrich data\n", "\n", @@ -57,13 +45,13 @@ " View source on GitHub\n", " \n", "\n" - ], - "metadata": { - "id": "YrOuxMeKJZxC" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "pf2bL-PmJScZ" + }, "source": [ "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [BigQuery](https://cloud.google.com/bigquery/docs/overview). The enrichment transform is an Apache Beam turnkey transform that lets you enrich data by using a key-value lookup. This transform has the following features:\n", "\n", @@ -79,38 +67,40 @@ "\n", "### Install Apache Beam\n", "To use the enrichment transform with the built-in BigQuery handler, install the Apache Beam SDK version 2.57.0 or later." 
- ], - "metadata": { - "id": "pf2bL-PmJScZ" - } + ] }, { "cell_type": "code", - "source": [ - "!pip install torch\n", - "!pip install apache_beam[interactive,gcp]==2.57.0 --quiet" - ], + "execution_count": null, "metadata": { "id": "oVbWf73FJSzf" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip install torch\n", + "!pip install apache_beam[interactive,gcp]==2.57.0 --quiet" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "siSUsfR5tKX9" + }, "source": [ "Import the following modules:\n", "- Pub/Sub for streaming data\n", "- BigQuery for enrichment\n", "- Apache Beam for running the streaming pipeline\n", "- PyTorch to predict customer churn" - ], - "metadata": { - "id": "siSUsfR5tKX9" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "p6bruDqFJkXE" + }, + "outputs": [], "source": [ "import datetime\n", "import json\n", @@ -137,49 +127,47 @@ "import pandas as pd\n", "\n", "from sklearn.preprocessing import LabelEncoder" - ], - "metadata": { - "id": "p6bruDqFJkXE" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "t0QfhuUlJozO" + }, "source": [ "### Authenticate with Google Cloud\n", "This notebook reads data from Pub/Sub and BigQuery. To use your Google Cloud account, authenticate this notebook.\n", "To prepare for this step, replace `` with your Google Cloud project ID." 
- ], - "metadata": { - "id": "t0QfhuUlJozO" - } + ] }, { "cell_type": "code", - "source": [ - "PROJECT_ID = \"\"\n" - ], + "execution_count": null, "metadata": { "id": "RwoBZjD1JwnD" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "PROJECT_ID = \"\" # @param {type:'string'}\n" + ] }, { "cell_type": "code", - "source": [ - "from google.colab import auth\n", - "auth.authenticate_user(project_id=PROJECT_ID)" - ], + "execution_count": null, "metadata": { "id": "rVAyQxoeKflB" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user(project_id=PROJECT_ID)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "1vDwknoHKoa-" + }, "source": [ "### Set up the BigQuery tables\n", "\n", @@ -187,36 +175,38 @@ "\n", "- Replace `` with the name of your BigQuery dataset. Only letters (uppercase or lowercase), numbers, and underscores are allowed.\n", "- If the dataset does not exist, a new dataset with this ID is created." - ], - "metadata": { - "id": "1vDwknoHKoa-" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UxeGFqSJu-G6" + }, + "outputs": [], "source": [ - "DATASET_ID = \"\"\n", + "DATASET_ID = \"\" # @param {type:'string'}\n", "\n", "CUSTOMERS_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.customers'\n", "USAGE_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.usage'" - ], - "metadata": { - "id": "UxeGFqSJu-G6" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Create customer and usage tables, and insert fake data." - ], "metadata": { "id": "Gw4RfZavyfpo" - } + }, + "source": [ + "Create customer and usage tables, and insert fake data." 
+ ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-QRZC4v0KipK" + }, + "outputs": [], "source": [ "client = bigquery.Client(project=PROJECT_ID)\n", "\n", @@ -276,33 +266,33 @@ "job.result() # Wait for the job to complete.\n", "\n", "print(f\"Usage table created and populated: {USAGE_TABLE_ID}\")" - ], - "metadata": { - "id": "-QRZC4v0KipK" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "### Train the model" - ], "metadata": { "id": "PZCjCzxaLOJt" - } + }, + "source": [ + "### Train the model" + ] }, { "cell_type": "markdown", - "source": [ - "Create sample data and train a simple model for churn prediction." - ], "metadata": { "id": "R4dIHclDLfIj" - } + }, + "source": [ + "Create sample data and train a simple model for churn prediction." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YoMjdqJ1KxOM" + }, + "outputs": [], "source": [ "# Create fake training data\n", "data = {\n", @@ -319,51 +309,51 @@ "df = pd.DataFrame(data)\n", "df['plan'] = plan_encoder.transform(data['plan'])\n", "\n" - ], - "metadata": { - "id": "YoMjdqJ1KxOM" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "EgIFJx76MF3v" + }, "source": [ "Preprocess the data:\n", "\n", "1. Convert the lists to tensors.\n", "2. Separate the features from the expected prediction." 
- ], - "metadata": { - "id": "EgIFJx76MF3v" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P-8lKzdzLnGo" + }, + "outputs": [], "source": [ "features = ['age', 'plan', 'contract_length', 'avg_monthly_calls', 'avg_monthly_data_usage_gb']\n", "target = 'churned'\n", "\n", "X = torch.tensor(df[features].values, dtype=torch.float)\n", "Y = torch.tensor(df[target], dtype=torch.float)" - ], - "metadata": { - "id": "P-8lKzdzLnGo" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Define a model that has five input features and predicts a single value." - ], "metadata": { "id": "4mcNOez1MQZP" - } + }, + "source": [ + "Define a model that has five input features and predicts a single value." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YvdPNlzoMTtl" + }, + "outputs": [], "source": [ "def build_model(n_inputs, n_outputs):\n", " \"\"\"build_model builds and returns a model that takes\n", @@ -375,24 +365,24 @@ " torch.nn.ReLU(),\n", " torch.nn.Linear(16, n_outputs),\n", " torch.nn.Sigmoid())" - ], - "metadata": { - "id": "YvdPNlzoMTtl" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Train the model." - ], "metadata": { "id": "GaLBmcvrMOWy" - } + }, + "source": [ + "Train the model." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0XqctMiPMaim" + }, + "outputs": [], "source": [ "model = build_model(n_inputs=5, n_outputs=1)\n", "\n", @@ -407,61 +397,61 @@ " loss = loss_fn(pred, Y[i].unsqueeze(0))\n", " loss.backward()\n", " optimizer.step()" - ], - "metadata": { - "id": "0XqctMiPMaim" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Save the model to the `STATE_DICT_PATH` variable." - ], "metadata": { "id": "m7MD6RwGMdyU" - } + }, + "source": [ + "Save the model to the `STATE_DICT_PATH` variable." 
+ ] }, { "cell_type": "code", - "source": [ - "STATE_DICT_PATH = './model.pth'\n", - "torch.save(model.state_dict(), STATE_DICT_PATH)" - ], + "execution_count": null, "metadata": { "id": "Q9WIjw53MgcR" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "STATE_DICT_PATH = './model.pth'\n", + "torch.save(model.state_dict(), STATE_DICT_PATH)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "CJVYA0N0MnZS" + }, "source": [ "### Publish messages to Pub/Sub\n", "Create the Pub/Sub topic and subscription to use for data streaming." - ], - "metadata": { - "id": "CJVYA0N0MnZS" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0uwZz_ijyzL8" + }, + "outputs": [], "source": [ "# Replace with the name of your Pub/Sub topic.\n", - "TOPIC = \"\"\n", + "TOPIC = \"\" # @param {type:'string'}\n", "\n", "# Replace with the subscription for your topic.\n", - "SUBSCRIPTION = \"\"" - ], - "metadata": { - "id": "0uwZz_ijyzL8" - }, - "execution_count": null, - "outputs": [] + "SUBSCRIPTION = \"\" # @param {type:'string'}" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hIgsCWIozdDu" + }, + "outputs": [], "source": [ "from google.api_core.exceptions import AlreadyExists\n", "\n", @@ -482,25 +472,25 @@ " print(f\"Created subscription: {subscription.name}\")\n", "except AlreadyExists:\n", " print(f\"Subscription {subscription_path} already exists.\")" - ], - "metadata": { - "id": "hIgsCWIozdDu" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "VqUaFm_yywjU" + }, "source": [ "\n", "Use the Pub/Sub Python client to publish messages." 
- ], - "metadata": { - "id": "VqUaFm_yywjU" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fOq1uNXvMku-" + }, + "outputs": [], "source": [ "messages = [\n", " {'customer_id': i}\n", @@ -510,15 +500,13 @@ "for message in messages:\n", " data = json.dumps(message).encode('utf-8')\n", " publish_future = publisher.publish(topic_path, data)" - ], - "metadata": { - "id": "fOq1uNXvMku-" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "giXOGruKM8ZL" + }, "source": [ "## Use the BigQuery enrichment handler\n", "\n", @@ -566,13 +554,15 @@ "* One for usage data that uses a custom aggregation query by using the `query_fn` function\n", "\n", "These handlers are used in the Enrichment transforms in this pipeline to fetch and join data from BigQuery with the streaming data." - ], - "metadata": { - "id": "giXOGruKM8ZL" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C8XLmBDeMyrB" + }, + "outputs": [], "source": [ "user_data_handler = BigQueryEnrichmentHandler(\n", " project=PROJECT_ID,\n", @@ -613,37 +603,37 @@ " project=PROJECT_ID,\n", " query_fn=usage_data_query_fn\n", ")" - ], - "metadata": { - "id": "C8XLmBDeMyrB" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "3oPYypvmPiyg" + }, "source": [ "In this example:\n", "1. The `user_data_handler` handler uses the `table_name`, `row_restriction_template`, and `fields` parameter combination to fetch customer data.\n", "2. The `usage_data_handler` handler uses the `query_fn` parameter to execute a more complex query that aggregates usage data." - ], - "metadata": { - "id": "3oPYypvmPiyg" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "ksON9uOBQbZm" + }, "source": [ "## Use the `PytorchModelHandlerTensor` interface to run inference\n", "\n", "Define functions to convert enriched data to the tensor format for the model." 
- ], - "metadata": { - "id": "ksON9uOBQbZm" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XgPontIVP0Cv" + }, + "outputs": [], "source": [ "def convert_row_to_tensor(customer_data):\n", " import pandas as pd\n", @@ -656,69 +646,69 @@ " model_class=build_model,\n", " model_params={'n_inputs':5, 'n_outputs':1}\n", ")).with_preprocess_fn(convert_row_to_tensor)" - ], - "metadata": { - "id": "XgPontIVP0Cv" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Define a `DoFn` to format the output." - ], "metadata": { "id": "O9e7ddgGQxh2" - } + }, + "source": [ + "Define a `DoFn` to format the output." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NMj0V5VyQukk" + }, + "outputs": [], "source": [ "class PostProcessor(beam.DoFn):\n", " def process(self, element, *args, **kwargs):\n", " print('Customer %d churn risk: %s' % (element[0], \"High\" if element[1].inference[0].item() > 0.5 else \"Low\"))" - ], - "metadata": { - "id": "NMj0V5VyQukk" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "-N3a1s2FQ66z" + }, "source": [ "## Run the pipeline\n", "\n", "Configure the pipeline to run in streaming mode." - ], - "metadata": { - "id": "-N3a1s2FQ66z" - } + ] }, { "cell_type": "code", - "source": [ - "options = pipeline_options.PipelineOptions()\n", - "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True" - ], + "execution_count": null, "metadata": { "id": "rgJeV-jWQ4wo" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "options = pipeline_options.PipelineOptions()\n", + "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True" + ] }, { "cell_type": "markdown", - "source": [ - "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`." 
- ], "metadata": { "id": "NRljYVR5RCMi" - } + }, + "source": [ + "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Bb-e3yjtQ2iU" + }, + "outputs": [], "source": [ "class DecodeBytes(beam.DoFn):\n", " \"\"\"\n", @@ -729,26 +719,26 @@ " def process(self, element, *args, **kwargs):\n", " element_dict = json.loads(element.decode('utf-8'))\n", " yield beam.Row(**element_dict)" - ], - "metadata": { - "id": "Bb-e3yjtQ2iU" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Q1HV8wH-RIbj" + }, "source": [ "Use the following code to run the pipeline.\n", "\n", "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run." - ], - "metadata": { - "id": "Q1HV8wH-RIbj" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "y6HBH8yoRFp2" + }, + "outputs": [], "source": [ "with beam.Pipeline(options=options) as p:\n", " _ = (p\n", @@ -760,12 +750,22 @@ " | \"RunInference\" >> RunInference(keyed_model_handler)\n", " | \"Format Output\" >> beam.ParDo(PostProcessor())\n", " )" - ], - "metadata": { - "id": "y6HBH8yoRFp2" - }, - "execution_count": null, - "outputs": [] + ] } - ] + ], + "metadata": { + "colab": { + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb index 95be8b1d957c..f2e63d2e4f06 100644 --- a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb @@ -151,9 +151,9 @@ }, "outputs": [], "source": [ - 
"PROJECT_ID = \"\"\n", - "INSTANCE_ID = \"\"\n", - "TABLE_ID = \"\"" + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "INSTANCE_ID = \"\" # @param {type:'string'}\n", + "TABLE_ID = \"\" # @param {type:'string'}" ] }, { @@ -457,10 +457,10 @@ "outputs": [], "source": [ "# Replace with the name of your Pub/Sub topic.\n", - "TOPIC = \"\"\n", + "TOPIC = \"\" # @param {type:'string'}\n", "\n", "# Replace with the subscription for your topic.\n", - "SUBSCRIPTION = \"\"\n" + "SUBSCRIPTION = \"\" # @param {type:'string'}\n" ] }, { @@ -532,16 +532,16 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "UEpjy_IsW4P4" + }, "source": [ "The `row_key` parameter represents the field in input schema (`beam.Row`) that contains the row key for a row in the table.\n", "\n", "Starting with Apache Beam version 2.54.0, you can perform either of the following tasks when a table uses composite row keys:\n", "* Modify the input schema to contain the row key in the format required by Bigtable.\n", "* Use a custom enrichment handler. For more information, see the [example handler with composite row key support](https://www.toptal.com/developers/paste-gd/BYFGUL08#)." 
- ], - "metadata": { - "id": "UEpjy_IsW4P4" - } + ] }, { "cell_type": "code", @@ -636,6 +636,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "fe3bIclV1jZ5" + }, "source": [ "To provide a `lambda` function for using a custom join with the enrichment transform, see the following example.\n", "\n", @@ -648,13 +651,13 @@ " ...\n", " )\n", "```" - ], - "metadata": { - "id": "fe3bIclV1jZ5" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "uilxdknE3ihO" + }, "source": [ "Because the enrichment transform makes API calls to the remote service, use the `timeout` parameter to specify a timeout duration of 10 seconds:\n", "\n", @@ -667,10 +670,7 @@ " ...\n", " )\n", "```" - ], - "metadata": { - "id": "uilxdknE3ihO" - } + ] }, { "cell_type": "markdown", @@ -855,11 +855,11 @@ ], "metadata": { "colab": { - "provenance": [], - "toc_visible": true, "collapsed_sections": [ "RpqZFfFfA_Dt" - ] + ], + "provenance": [], + "toc_visible": true }, "kernelspec": { "display_name": "Python 3", diff --git a/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb b/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb index 49e2f35b13be..4d816ef97fb0 100644 --- a/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb +++ b/examples/notebooks/beam-ml/data_preprocessing/vertex_ai_text_embeddings.ipynb @@ -1,18 +1,4 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", @@ -44,6 +30,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "ZUSiAR62SgO8" + }, "source": [ "# Generate text embeddings by using the Vertex AI API\n", "\n", @@ -55,13 +44,13 @@ " View source on GitHub\n", " \n", "\n" - ], - "metadata": { - "id": "ZUSiAR62SgO8" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": 
"bkpSCGCWlqAf" + }, "source": [ "Text embeddings are a way to represent text as numerical vectors. This process lets computers understand and process text data, which is essential for many natural language processing (NLP) tasks.\n", "\n", @@ -84,71 +73,72 @@ "* Do one of the following tasks:\n", " * Configure credentials for your Google Cloud project. For more information, see [Google Auth Library for Python](https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth).\n", " * Store the path to a service account JSON file by using the [GOOGLE_APPLICATION_CREDENTIALS](https://cloud.google.com/docs/authentication/application-default-credentials#GAC) environment variable." - ], - "metadata": { - "id": "bkpSCGCWlqAf" - } + ] }, { "cell_type": "markdown", - "source": [ - "To use your Google Cloud account, authenticate this notebook." - ], "metadata": { "id": "W29FgO5Qv2ew" - } + }, + "source": [ + "To use your Google Cloud account, authenticate this notebook." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nYyyGYt3licq" + }, + "outputs": [], "source": [ "from google.colab import auth\n", "auth.authenticate_user()\n", "\n", - "project = '' # Replace with a valid Google Cloud project ID." - ], - "metadata": { - "id": "nYyyGYt3licq" - }, - "execution_count": null, - "outputs": [] + "# Replace with a valid Google Cloud project ID.\n", + "project = '' # @param {type:'string'}" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "UQROd16ZDN5y" + }, "source": [ "## Install dependencies\n", " Install Apache Beam and the dependencies required for the Vertex AI text-embeddings API." - ], - "metadata": { - "id": "UQROd16ZDN5y" - } + ] }, { "cell_type": "code", - "source": [ - "! pip install apache_beam[gcp]>=2.53.0 --quiet" - ], + "execution_count": null, "metadata": { "id": "BTxob7d5DLBM" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "! 
pip install apache_beam[gcp]>=2.53.0 --quiet" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SkMhR7H6n1P0" + }, + "outputs": [], "source": [ "import tempfile\n", "import apache_beam as beam\n", "from apache_beam.ml.transforms.base import MLTransform\n", "from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAITextEmbeddings" - ], - "metadata": { - "id": "SkMhR7H6n1P0" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "cokOaX2kzyke" + }, "source": [ "## Transform the data\n", "\n", @@ -156,25 +146,27 @@ "\n", "### Use MLTransform in write mode\n", "\n", - "In `write` mode, `MLTransform` saves the transforms and their attributes to an artifact location. Then, when you run `MLTransform` in `read` mode, these transforms are used. This process ensures that you're applying the same preprocessing steps when you train your model and when you serve the model in production or test its accuracy." - ], - "metadata": { - "id": "cokOaX2kzyke" - } + "In `write` mode, `MLTransform` saves the transforms and their attributes to an artifact location. Then, when you run `MLTransform` in `read` mode, these transforms are used. This process ensures that you're applying the same preprocessing steps when you train your model and when you serve the model in production or test its accuracy." + ] }, { "cell_type": "markdown", + "metadata": { + "id": "-x7fVvuy-aDs" + }, "source": [ "### Get the data\n", "\n", "`MLTransform` processes dictionaries that include column names and their associated text data. To generate embeddings for specific columns, specify these column names in the `columns` argument of `VertexAITextEmbeddings`. This transform uses the the Vertex AI text-embeddings API for online predictions to generate an embeddings vector for each sentence." 
- ], - "metadata": { - "id": "-x7fVvuy-aDs" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "be-vR159pylF" + }, + "outputs": [], "source": [ "artifact_location = tempfile.mkdtemp(prefix='vertex_ai')\n", "\n", @@ -201,32 +193,11 @@ " for key in d.keys():\n", " d[key] = d[key][:10]\n", " return d" - ], - "metadata": { - "id": "be-vR159pylF" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "embedding_transform = VertexAITextEmbeddings(\n", - " model_name=text_embedding_model_name, columns=['x'], project=project)\n", - "\n", - "with beam.Pipeline() as pipeline:\n", - " data_pcoll = (\n", - " pipeline\n", - " | \"CreateData\" >> beam.Create(content))\n", - " transformed_pcoll = (\n", - " data_pcoll\n", - " | \"MLTransform\" >> MLTransform(write_artifact_location=artifact_location).with_transform(embedding_transform))\n", - "\n", - " # Show only the first ten elements of the embeddings to prevent clutter in the output.\n", - " transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n", - "\n", - " transformed_pcoll | \"PrintEmbeddingShape\" >> beam.Map(lambda x: print(f\"Embedding shape: {len(x['x'])}\"))" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -234,11 +205,10 @@ "id": "UQGm1be3p7lM", "outputId": "b41172ca-1c73-4952-ca87-bfe45ca88a6c" }, - "execution_count": null, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "{'x': [0.041293490678071976, -0.010302993468940258, -0.048611514270305634, -0.01360565796494484, 0.06441926211118698, 0.022573700174689293, 0.016446372494101524, -0.033894773572683334, 0.004581860266625881, 0.060710687190294266]}\n", "Embedding shape: 10\n", @@ -248,23 +218,58 @@ "Embedding shape: 10\n" ] } + ], + "source": [ + "embedding_transform = VertexAITextEmbeddings(\n", + " model_name=text_embedding_model_name, columns=['x'], 
project=project)\n", + "\n", + "with beam.Pipeline() as pipeline:\n", + " data_pcoll = (\n", + " pipeline\n", + " | \"CreateData\" >> beam.Create(content))\n", + " transformed_pcoll = (\n", + " data_pcoll\n", + " | \"MLTransform\" >> MLTransform(write_artifact_location=artifact_location).with_transform(embedding_transform))\n", + "\n", + " # Show only the first ten elements of the embeddings to prevent clutter in the output.\n", + " transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n", + "\n", + " transformed_pcoll | \"PrintEmbeddingShape\" >> beam.Map(lambda x: print(f\"Embedding shape: {len(x['x'])}\"))" ] }, { "cell_type": "markdown", + "metadata": { + "id": "JLkmQkiLx_6h" + }, "source": [ "### Use MLTransform in read mode\n", "\n", "In `read` mode, `MLTransform` uses the artifacts saved during `write` mode. In this example, the transform and its attributes are loaded from the saved artifacts. You don't need to specify artifacts again during `read` mode.\n", "\n", "In this way, `MLTransform` provides consistent preprocessing steps for training and inference workloads." 
- ], - "metadata": { - "id": "JLkmQkiLx_6h" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "r8Y5vgfLx_Xu", + "outputId": "e7cbf6b7-5c31-4efa-90cf-7a8a108ecc77" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'x': [0.04782044142484665, -0.010078949853777885, -0.05793016776442528, -0.026060665026307106, 0.05756739526987076, 0.02292264811694622, 0.014818413183093071, -0.03718176111578941, -0.005486017093062401, 0.04709304869174957]}\n", + "{'x': [0.042911216616630554, -0.007554919924587011, -0.08996245265007019, -0.02607591263949871, 0.0008614308317191899, -0.023671219125390053, 0.03999944031238556, -0.02983051724731922, -0.015057179145514965, 0.022963201627135277]}\n" + ] + } + ], "source": [ "test_content = [\n", " {\n", @@ -284,25 +289,21 @@ " | \"MLTransform\" >> MLTransform(read_artifact_location=artifact_location))\n", "\n", " transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n" - ], - "metadata": { - "id": "r8Y5vgfLx_Xu", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "e7cbf6b7-5c31-4efa-90cf-7a8a108ecc77" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'x': [0.04782044142484665, -0.010078949853777885, -0.05793016776442528, -0.026060665026307106, 0.05756739526987076, 0.02292264811694622, 0.014818413183093071, -0.03718176111578941, -0.005486017093062401, 0.04709304869174957]}\n", - "{'x': [0.042911216616630554, -0.007554919924587011, -0.08996245265007019, -0.02607591263949871, 0.0008614308317191899, -0.023671219125390053, 0.03999944031238556, -0.02983051724731922, -0.015057179145514965, 0.022963201627135277]}\n" - ] - } ] } - ] + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + 
"nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb b/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb index a488caf7d3ac..3af7455222a9 100644 --- a/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb +++ b/examples/notebooks/beam-ml/dataframe_api_preprocessing.ipynb @@ -2,6 +2,12 @@ "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "sARMhsXz8yR1" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -21,16 +27,13 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "sARMhsXz8yR1" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "A8xNRyZMW1yK" + }, "source": [ "# Preprocessing with the Apache Beam DataFrames API\n", "\n", @@ -44,13 +47,13 @@ " View source on GitHub\n", " \n", "\n" - ], - "metadata": { - "id": "A8xNRyZMW1yK" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "iFZC1inKuUCy" + }, "source": [ "For rapid execution, Pandas loads all of the data into memory on a single machine (one node). This configuration works well when dealing with small-scale datasets. However, many projects involve datasets that are too big to fit in memory. 
These use cases generally require parallel data processing frameworks, such as Apache Beam.\n", "\n", @@ -71,21 +74,18 @@ "\n", "In this example, the first section demonstrates how to build and execute a pipeline locally using the interactive runner.\n", "The second section uses a distributed runner to demonstrate how to run the pipeline on the full dataset.\n" - ], - "metadata": { - "id": "iFZC1inKuUCy" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "A0f2HJ22D4lt" + }, "source": [ "## Install Apache Beam\n", "\n", "To explore the elements within a `PCollection`, install Apache Beam with the `interactive` component to use the Interactive runner. The DataFrames API methods invoked in this example are available in Apache Beam SDK versions 2.43 and later.\n" - ], - "metadata": { - "id": "A0f2HJ22D4lt" - } + ] }, { "cell_type": "markdown", @@ -100,8 +100,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "-OJC0Xn5Um-C", - "beam:comment": "TODO(https://github.com/apache/issues/23961): Just install 2.43.0 once it's released, [`issue 23276`](https://github.com/apache/beam/issues/23276) is currently not implemented for Beam 2.42 (required fix for implementing `str.get_dummies()`" + "beam:comment": "TODO(https://github.com/apache/issues/23961): Just install 2.43.0 once it's released, [`issue 23276`](https://github.com/apache/beam/issues/23276) is currently not implemented for Beam 2.42 (required fix for implementing `str.get_dummies()`", + "id": "-OJC0Xn5Um-C" }, "outputs": [], "source": [ @@ -114,6 +114,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "3NO6RgB7GkkE" + }, "source": [ "## Local exploration with the Interactive Beam runner\n", "Use the [Interactive Beam](https://beam.apache.org/releases/pydoc/2.20.0/apache_beam.runners.interactive.interactive_beam.html) runner to explore and develop your pipeline.\n", @@ -121,10 +124,7 @@ "\n", "\n", "This section uses a subset of the original dataset, because the notebook 
instance has limited compute resources.\n" - ], - "metadata": { - "id": "3NO6RgB7GkkE" - } + ] }, { "cell_type": "markdown", @@ -186,13 +186,13 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "cvAu5T0ENjuQ" + }, "source": [ "\n", "Inspect the dataset columns and their types." - ], - "metadata": { - "id": "cvAu5T0ENjuQ" - } + ] }, { "cell_type": "code", @@ -206,7 +206,6 @@ }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "spk_id int64\n", @@ -225,8 +224,9 @@ "dtype: object" ] }, + "execution_count": 27, "metadata": {}, - "execution_count": 27 + "output_type": "execute_result" } ], "source": [ @@ -235,12 +235,12 @@ }, { "cell_type": "markdown", - "source": [ - "When using Interactive Beam, to bring a Beam DataFrame into local memory as a Pandas DataFrame, use `ib.collect()`." - ], "metadata": { "id": "1Wa6fpbyQige" - } + }, + "source": [ + "When using Interactive Beam, to bring a Beam DataFrame into local memory as a Pandas DataFrame, use `ib.collect()`." + ] }, { "cell_type": "code", @@ -255,11 +255,7 @@ }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -268,101 +264,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_79206f341d7de09f6cacdd05be309575\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " spk_id full_name near_earth_object \\\n", - "0 2000001 1 Ceres N \n", - "1 2000002 2 Pallas N \n", - "2 2000003 3 Juno N \n", - "3 2000004 4 Vesta N \n", - "4 2000005 5 Astraea N \n", - "... ... ... ... \n", - "9994 2009995 9995 Alouette (4805 P-L) N \n", - "9995 2009996 9996 ANS (9070 P-L) N \n", - "9996 2009997 9997 COBE (1217 T-1) N \n", - "9997 2009998 9998 ISO (1293 T-1) N \n", - "9998 2009999 9999 Wiles (4196 T-2) N \n", - "\n", - " absolute_magnitude diameter albedo diameter_sigma eccentricity \\\n", - "0 3.40 939.400 0.0900 0.200 0.076009 \n", - "1 4.20 545.000 0.1010 18.000 0.229972 \n", - "2 5.33 246.596 0.2140 10.594 0.256936 \n", - "3 3.00 525.400 0.4228 0.200 0.088721 \n", - "4 6.90 106.699 0.2740 3.140 0.190913 \n", - "... ... ... ... ... ... \n", - "9994 15.10 2.564 0.2450 0.550 0.160610 \n", - "9995 13.60 8.978 0.1130 0.376 0.235174 \n", - "9996 14.30 NaN NaN NaN 0.113059 \n", - "9997 15.10 2.235 0.3880 0.373 0.093852 \n", - "9998 13.00 7.148 0.2620 0.065 0.071351 \n", - "\n", - " inclination moid_ld object_class semi_major_axis_au_unit \\\n", - "0 10.594067 620.640533 MBA 2.769165 \n", - "1 34.832932 480.348639 MBA 2.773841 \n", - "2 12.991043 402.514639 MBA 2.668285 \n", - "3 7.141771 443.451432 MBA 2.361418 \n", - "4 5.367427 426.433027 MBA 2.574037 \n", - "... ... ... ... ... 
\n", - "9994 2.311731 388.723233 MBA 2.390249 \n", - "9995 7.657713 444.194746 MBA 2.796605 \n", - "9996 2.459643 495.460110 MBA 2.545674 \n", - "9997 3.912263 373.848377 MBA 2.160961 \n", - "9998 3.198839 632.144398 MBA 2.839917 \n", - "\n", - " hazardous_flag \n", - "0 N \n", - "1 N \n", - "2 N \n", - "3 N \n", - "4 N \n", - "... ... \n", - "9994 N \n", - "9995 N \n", - "9996 N \n", - "9997 N \n", - "9998 N \n", - "\n", - "[9999 rows x 13 columns]" - ], "text/html": [ "\n", "
\n", @@ -657,10 +575,66 @@ "
\n", " \n", " " + ], + "text/plain": [ + " spk_id full_name near_earth_object \\\n", + "0 2000001 1 Ceres N \n", + "1 2000002 2 Pallas N \n", + "2 2000003 3 Juno N \n", + "3 2000004 4 Vesta N \n", + "4 2000005 5 Astraea N \n", + "... ... ... ... \n", + "9994 2009995 9995 Alouette (4805 P-L) N \n", + "9995 2009996 9996 ANS (9070 P-L) N \n", + "9996 2009997 9997 COBE (1217 T-1) N \n", + "9997 2009998 9998 ISO (1293 T-1) N \n", + "9998 2009999 9999 Wiles (4196 T-2) N \n", + "\n", + " absolute_magnitude diameter albedo diameter_sigma eccentricity \\\n", + "0 3.40 939.400 0.0900 0.200 0.076009 \n", + "1 4.20 545.000 0.1010 18.000 0.229972 \n", + "2 5.33 246.596 0.2140 10.594 0.256936 \n", + "3 3.00 525.400 0.4228 0.200 0.088721 \n", + "4 6.90 106.699 0.2740 3.140 0.190913 \n", + "... ... ... ... ... ... \n", + "9994 15.10 2.564 0.2450 0.550 0.160610 \n", + "9995 13.60 8.978 0.1130 0.376 0.235174 \n", + "9996 14.30 NaN NaN NaN 0.113059 \n", + "9997 15.10 2.235 0.3880 0.373 0.093852 \n", + "9998 13.00 7.148 0.2620 0.065 0.071351 \n", + "\n", + " inclination moid_ld object_class semi_major_axis_au_unit \\\n", + "0 10.594067 620.640533 MBA 2.769165 \n", + "1 34.832932 480.348639 MBA 2.773841 \n", + "2 12.991043 402.514639 MBA 2.668285 \n", + "3 7.141771 443.451432 MBA 2.361418 \n", + "4 5.367427 426.433027 MBA 2.574037 \n", + "... ... ... ... ... \n", + "9994 2.311731 388.723233 MBA 2.390249 \n", + "9995 7.657713 444.194746 MBA 2.796605 \n", + "9996 2.459643 495.460110 MBA 2.545674 \n", + "9997 3.912263 373.848377 MBA 2.160961 \n", + "9998 3.198839 632.144398 MBA 2.839917 \n", + "\n", + " hazardous_flag \n", + "0 N \n", + "1 N \n", + "2 N \n", + "3 N \n", + "4 N \n", + "... ... 
\n", + "9994 N \n", + "9995 N \n", + "9996 N \n", + "9997 N \n", + "9998 N \n", + "\n", + "[9999 rows x 13 columns]" ] }, + "execution_count": 28, "metadata": {}, - "execution_count": 28 + "output_type": "execute_result" } ], "source": [ @@ -669,34 +643,29 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "8jV9odKhNyF2" + }, "source": [ "The datasets contain the following two types of columns:\n", "\n", "* **Numerical columns:** Use [normalization](https://developers.google.com/machine-learning/data-prep/transform/normalization) to transform these columns so that they can be used to train a machine learning model.\n", "\n", "* **Categorical columns:** Transform those columns with [one-hot encoding](https://developers.google.com/machine-learning/data-prep/transform/transform-categorical) to use them during training. \n" - ], - "metadata": { - "id": "8jV9odKhNyF2" - } + ] }, { "cell_type": "markdown", - "source": [ - "Use the standard pandas command `DataFrame.describe()` to generate descriptive statistics for the numerical columns, such as percentile, mean, std, and so on. " - ], "metadata": { "id": "MGAErO0lAYws" - } + }, + "source": [ + "Use the standard pandas command `DataFrame.describe()` to generate descriptive statistics for the numerical columns, such as percentile, mean, std, and so on. " + ] }, { "cell_type": "code", - "source": [ - "with dataframe.allow_non_parallel_operations():\n", - " beam_df_description = ib.collect(beam_df.describe())\n", - "\n", - "beam_df_description" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -705,14 +674,9 @@ "id": "Befv697VBGM7", "outputId": "bb465020-94e4-4b3c-fda6-6e43da199be1" }, - "execution_count": null, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -721,77 +685,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_98687cb0060a8077a8abab6e464e4a75\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " spk_id absolute_magnitude diameter albedo \\\n", - "count 9.999000e+03 9999.000000 8688.000000 8672.000000 \n", - "mean 2.005000e+06 12.675380 19.245446 0.197723 \n", - "std 2.886607e+03 1.639609 30.190191 0.138819 \n", - "min 2.000001e+06 3.000000 0.300000 0.008000 \n", - "25% 2.002500e+06 11.900000 5.614000 0.074000 \n", - "50% 2.005000e+06 12.900000 9.814000 0.187000 \n", - "75% 2.007500e+06 13.700000 19.156750 0.283000 \n", - "max 2.009999e+06 20.700000 939.400000 1.000000 \n", - "\n", - " diameter_sigma eccentricity inclination moid_ld \\\n", - "count 8591.000000 9999.000000 9999.000000 9999.000000 \n", - "mean 0.454072 0.148716 7.890742 509.805237 \n", - "std 1.093676 0.083803 6.336244 205.046582 \n", - "min 0.006000 0.001003 0.042716 0.131028 \n", - "25% 0.120000 0.093780 3.220137 377.829197 \n", - "50% 0.201000 0.140335 6.018836 470.650523 \n", - "75% 0.375000 0.187092 10.918176 636.010802 \n", - "max 39.297000 0.889831 68.018875 4241.524913 \n", - "\n", - " semi_major_axis_au_unit \n", - "count 9999.000000 \n", - "mean 2.689836 \n", - "std 0.607190 \n", - "min 0.832048 \n", - "25% 2.340816 \n", - "50% 2.614468 \n", - "75% 3.005449 \n", - "max 24.667968 " - ], "text/html": [ "\n", "
\n", @@ -1001,27 +911,65 @@ "
\n", " \n", " " - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D9uJtHLSSAMC" - }, - "source": [ - "Before running any transformations, verify that all of the columns need to be used for model training. Start by looking at the column description provided by the [JPL website](https://ssd.jpl.nasa.gov/sbdb_query.cgi):\n", - "\n", - "* **spk_id:** Object primary SPK-ID.\n", - "* **full_name:** Asteroid name.\n", - "* **near_earth_object:** Near-earth object flag.\n", - "* **absolute_magnitude:** The apparent magnitude an object would have if it were located at a distance of 10 parsecs.\n", - "* **diameter:** Object diameter (from equivalent sphere) km unit.\n", - "* **albedo:** A measure of the diffuse reflection of solar radiation out of the total solar radiation, measured on a scale from 0 to 1.\n", + ], + "text/plain": [ + " spk_id absolute_magnitude diameter albedo \\\n", + "count 9.999000e+03 9999.000000 8688.000000 8672.000000 \n", + "mean 2.005000e+06 12.675380 19.245446 0.197723 \n", + "std 2.886607e+03 1.639609 30.190191 0.138819 \n", + "min 2.000001e+06 3.000000 0.300000 0.008000 \n", + "25% 2.002500e+06 11.900000 5.614000 0.074000 \n", + "50% 2.005000e+06 12.900000 9.814000 0.187000 \n", + "75% 2.007500e+06 13.700000 19.156750 0.283000 \n", + "max 2.009999e+06 20.700000 939.400000 1.000000 \n", + "\n", + " diameter_sigma eccentricity inclination moid_ld \\\n", + "count 8591.000000 9999.000000 9999.000000 9999.000000 \n", + "mean 0.454072 0.148716 7.890742 509.805237 \n", + "std 1.093676 0.083803 6.336244 205.046582 \n", + "min 0.006000 0.001003 0.042716 0.131028 \n", + "25% 0.120000 0.093780 3.220137 377.829197 \n", + "50% 0.201000 0.140335 6.018836 470.650523 \n", + "75% 0.375000 0.187092 10.918176 636.010802 \n", + "max 39.297000 0.889831 68.018875 4241.524913 \n", + "\n", + " semi_major_axis_au_unit \n", + "count 9999.000000 \n", + "mean 2.689836 \n", + "std 0.607190 \n", + "min 
0.832048 \n", + "25% 2.340816 \n", + "50% 2.614468 \n", + "75% 3.005449 \n", + "max 24.667968 " + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with dataframe.allow_non_parallel_operations():\n", + " beam_df_description = ib.collect(beam_df.describe())\n", + "\n", + "beam_df_description" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D9uJtHLSSAMC" + }, + "source": [ + "Before running any transformations, verify that all of the columns need to be used for model training. Start by looking at the column description provided by the [JPL website](https://ssd.jpl.nasa.gov/sbdb_query.cgi):\n", + "\n", + "* **spk_id:** Object primary SPK-ID.\n", + "* **full_name:** Asteroid name.\n", + "* **near_earth_object:** Near-earth object flag.\n", + "* **absolute_magnitude:** The apparent magnitude an object would have if it were located at a distance of 10 parsecs.\n", + "* **diameter:** Object diameter (from equivalent sphere) km unit.\n", + "* **albedo:** A measure of the diffuse reflection of solar radiation out of the total solar radiation, measured on a scale from 0 to 1.\n", "* **diameter_sigma:** 1-sigma uncertainty in object diameter km unit.\n", "* **eccentricity:** A value between 0 and 1 that refers to how flat or round the asteroid is.\n", "* **inclination:** The angle with respect to the x-y ecliptic plane.\n", @@ -1073,19 +1021,15 @@ }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/content/beam/sdks/python/apache_beam/dataframe/frame_base.py:145: RuntimeWarning: invalid value encountered in long_scalars\n", " lambda left, right: getattr(left, op)(right), name=op, args=[other])\n" ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1094,45 +1038,22 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_868f8ad001ab00c7013b65472a513917\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { "text/plain": [ "near_earth_object 0.000000\n", @@ -1149,8 +1070,9 @@ "dtype: float64" ] }, + "execution_count": 30, "metadata": {}, - "execution_count": 30 + "output_type": "execute_result" } ], "source": [ @@ -1170,20 +1092,16 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "tHYeCHREwvyB", "colab": { "base_uri": "https://localhost:8080/", "height": 538 }, + "id": "tHYeCHREwvyB", "outputId": "3be686d0-f56a-4054-a71a-d3019bf379e8" }, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1192,75 +1110,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_f88b77f183371d1a45fa87bed4a545f6\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " near_earth_object absolute_magnitude eccentricity inclination \\\n", - "0 N 3.40 0.076009 10.594067 \n", - "1 N 4.20 0.229972 34.832932 \n", - "2 N 5.33 0.256936 12.991043 \n", - "3 N 3.00 0.088721 7.141771 \n", - "4 N 6.90 0.190913 5.367427 \n", - "... ... ... ... ... \n", - "9994 N 15.10 0.160610 2.311731 \n", - "9995 N 13.60 0.235174 7.657713 \n", - "9996 N 14.30 0.113059 2.459643 \n", - "9997 N 15.10 0.093852 3.912263 \n", - "9998 N 13.00 0.071351 3.198839 \n", - "\n", - " moid_ld object_class semi_major_axis_au_unit hazardous_flag \n", - "0 620.640533 MBA 2.769165 N \n", - "1 480.348639 MBA 2.773841 N \n", - "2 402.514639 MBA 2.668285 N \n", - "3 443.451432 MBA 2.361418 N \n", - "4 426.433027 MBA 2.574037 N \n", - "... ... ... ... ... \n", - "9994 388.723233 MBA 2.390249 N \n", - "9995 444.194746 MBA 2.796605 N \n", - "9996 495.460110 MBA 2.545674 N \n", - "9997 373.848377 MBA 2.160961 N \n", - "9998 632.144398 MBA 2.839917 N \n", - "\n", - "[9999 rows x 8 columns]" - ], "text/html": [ "\n", "
\n", @@ -1495,10 +1361,40 @@ "
\n", " \n", " " + ], + "text/plain": [ + " near_earth_object absolute_magnitude eccentricity inclination \\\n", + "0 N 3.40 0.076009 10.594067 \n", + "1 N 4.20 0.229972 34.832932 \n", + "2 N 5.33 0.256936 12.991043 \n", + "3 N 3.00 0.088721 7.141771 \n", + "4 N 6.90 0.190913 5.367427 \n", + "... ... ... ... ... \n", + "9994 N 15.10 0.160610 2.311731 \n", + "9995 N 13.60 0.235174 7.657713 \n", + "9996 N 14.30 0.113059 2.459643 \n", + "9997 N 15.10 0.093852 3.912263 \n", + "9998 N 13.00 0.071351 3.198839 \n", + "\n", + " moid_ld object_class semi_major_axis_au_unit hazardous_flag \n", + "0 620.640533 MBA 2.769165 N \n", + "1 480.348639 MBA 2.773841 N \n", + "2 402.514639 MBA 2.668285 N \n", + "3 443.451432 MBA 2.361418 N \n", + "4 426.433027 MBA 2.574037 N \n", + "... ... ... ... ... \n", + "9994 388.723233 MBA 2.390249 N \n", + "9995 444.194746 MBA 2.796605 N \n", + "9996 495.460110 MBA 2.545674 N \n", + "9997 373.848377 MBA 2.160961 N \n", + "9998 632.144398 MBA 2.839917 N \n", + "\n", + "[9999 rows x 8 columns]" ] }, + "execution_count": 31, "metadata": {}, - "execution_count": 31 + "output_type": "execute_result" } ], "source": [ @@ -1559,19 +1455,15 @@ }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/content/beam/sdks/python/apache_beam/dataframe/frame_base.py:145: RuntimeWarning: invalid value encountered in double_scalars\n", " lambda left, right: getattr(left, op)(right), name=op, args=[other])\n" ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1580,75 +1472,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_55302fa5950ce6ceb9f99ff9a168097a\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " absolute_magnitude eccentricity inclination moid_ld \\\n", - "306 -1.570727 -0.062543 -0.278518 0.373194 \n", - "310 -1.631718 -1.724526 -0.736389 1.087833 \n", - "546 -1.753698 1.028793 1.415303 -0.339489 \n", - "635 -1.875678 0.244869 0.005905 0.214107 \n", - "701 -3.278451 -1.570523 2.006145 1.542754 \n", - "... ... ... ... ... \n", - "9697 0.807888 -1.151809 -0.082944 -0.129556 \n", - "9813 1.722740 0.844551 -0.583247 -1.006447 \n", - "9868 0.807888 -0.207399 -0.784665 -0.462136 \n", - "9903 0.868878 0.460086 0.092258 -0.107597 \n", - "9956 0.746898 -0.234132 -0.161116 -0.601379 \n", - "\n", - " semi_major_axis_au_unit \n", - "306 0.357201 \n", - "310 0.344233 \n", - "546 0.139080 \n", - "635 0.367559 \n", - "701 0.829337 \n", - "... ... \n", - "9697 -0.533538 \n", - "9813 -0.677961 \n", - "9868 -0.539794 \n", - "9903 0.071794 \n", - "9956 -0.664887 \n", - "\n", - "[9999 rows x 5 columns]" - ], "text/html": [ "\n", "
\n", @@ -1847,10 +1687,40 @@ "
\n", " \n", " " + ], + "text/plain": [ + " absolute_magnitude eccentricity inclination moid_ld \\\n", + "306 -1.570727 -0.062543 -0.278518 0.373194 \n", + "310 -1.631718 -1.724526 -0.736389 1.087833 \n", + "546 -1.753698 1.028793 1.415303 -0.339489 \n", + "635 -1.875678 0.244869 0.005905 0.214107 \n", + "701 -3.278451 -1.570523 2.006145 1.542754 \n", + "... ... ... ... ... \n", + "9697 0.807888 -1.151809 -0.082944 -0.129556 \n", + "9813 1.722740 0.844551 -0.583247 -1.006447 \n", + "9868 0.807888 -0.207399 -0.784665 -0.462136 \n", + "9903 0.868878 0.460086 0.092258 -0.107597 \n", + "9956 0.746898 -0.234132 -0.161116 -0.601379 \n", + "\n", + " semi_major_axis_au_unit \n", + "306 0.357201 \n", + "310 0.344233 \n", + "546 0.139080 \n", + "635 0.367559 \n", + "701 0.829337 \n", + "... ... \n", + "9697 -0.533538 \n", + "9813 -0.677961 \n", + "9868 -0.539794 \n", + "9903 0.071794 \n", + "9956 -0.664887 \n", + "\n", + "[9999 rows x 5 columns]" ] }, + "execution_count": 33, "metadata": {}, - "execution_count": 33 + "output_type": "execute_result" } ], "source": [ @@ -1895,12 +1765,7 @@ }, { "cell_type": "code", - "source": [ - "for categorical_col in categorical_cols:\n", - " beam_df_categorical = get_one_hot_encoding(df=beam_df, categorical_col=categorical_col)\n", - " beam_df_numericals = beam_df_numericals.merge(beam_df_categorical, left_index = True, right_index = True)\n", - "ib.collect(beam_df_numericals)" - ], + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1909,14 +1774,9 @@ "id": "k9rvtWqHf6Qw", "outputId": "b8d8ae57-6dba-45b4-e7ae-e4b14084eede" }, - "execution_count": null, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1925,49 +1785,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6b2563c7f661bc0fc5729c2577d6f232\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -1976,49 +1810,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6fa896083b128ad99059af69a3d7fc7e\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2027,49 +1835,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_6339347de9805da541eba53abaee2d5e\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2078,127 +1860,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_1af5b908898a1e5949dcc20549f650eb\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " absolute_magnitude eccentricity inclination moid_ld \\\n", - "0 -5.657067 -0.867596 0.426645 0.540537 \n", - "12 -3.583402 -0.756931 1.364340 0.238610 \n", - "47 -3.400432 -0.912290 -0.211925 1.136060 \n", - "381 -2.363599 0.271412 -0.078826 0.535299 \n", - "515 -2.729540 1.469775 0.799915 -0.602881 \n", - "... ... ... ... ... \n", - "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", - "9657 1.478779 0.487849 -0.637779 -0.648240 \n", - "9704 0.380957 -0.238383 0.443053 0.670490 \n", - "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", - "9980 0.746898 -1.455992 -0.849144 0.592902 \n", - "\n", - " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", - "0 0.130649 1 0 \n", - "12 -0.187375 1 0 \n", - "47 0.691182 1 0 \n", - "381 0.712755 1 0 \n", - "515 -0.014654 1 0 \n", - "... ... ... ... 
\n", - "9146 -0.820638 1 0 \n", - "9657 -0.468778 1 0 \n", - "9704 0.587128 1 0 \n", - "9879 -0.662602 1 0 \n", - "9980 -0.022726 1 0 \n", - "\n", - " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", - "0 0 0 0 ... \n", - "12 0 0 0 ... \n", - "47 0 0 0 ... \n", - "381 0 0 0 ... \n", - "515 0 0 0 ... \n", - "... ... ... ... ... \n", - "9146 0 0 0 ... \n", - "9657 0 0 0 ... \n", - "9704 0 0 0 ... \n", - "9879 0 0 0 ... \n", - "9980 0 0 0 ... \n", - "\n", - " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", - "0 0 0 1 0 \n", - "12 0 0 1 0 \n", - "47 0 0 1 0 \n", - "381 0 0 1 0 \n", - "515 0 0 1 0 \n", - "... ... ... ... ... \n", - "9146 0 0 1 0 \n", - "9657 0 0 1 0 \n", - "9704 0 0 1 0 \n", - "9879 0 0 1 0 \n", - "9980 0 0 1 0 \n", - "\n", - " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", - "0 0 0 0 1 \n", - "12 0 0 0 1 \n", - "47 0 0 0 1 \n", - "381 0 0 0 1 \n", - "515 0 0 0 1 \n", - "... ... ... ... ... \n", - "9146 0 0 0 1 \n", - "9657 0 0 0 1 \n", - "9704 0 0 0 1 \n", - "9879 0 0 0 1 \n", - "9980 0 0 0 1 \n", - "\n", - " hazardous_flag_Y hazardous_flag_nan \n", - "0 0 0 \n", - "12 0 0 \n", - "47 0 0 \n", - "381 0 0 \n", - "515 0 0 \n", - "... ... ... \n", - "9146 0 0 \n", - "9657 0 0 \n", - "9704 0 0 \n", - "9879 0 0 \n", - "9980 0 0 \n", - "\n", - "[9999 rows x 22 columns]" - ], "text/html": [ "\n", "
\n", @@ -2589,11 +2267,99 @@ "
\n", " \n", " " + ], + "text/plain": [ + " absolute_magnitude eccentricity inclination moid_ld \\\n", + "0 -5.657067 -0.867596 0.426645 0.540537 \n", + "12 -3.583402 -0.756931 1.364340 0.238610 \n", + "47 -3.400432 -0.912290 -0.211925 1.136060 \n", + "381 -2.363599 0.271412 -0.078826 0.535299 \n", + "515 -2.729540 1.469775 0.799915 -0.602881 \n", + "... ... ... ... ... \n", + "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", + "9657 1.478779 0.487849 -0.637779 -0.648240 \n", + "9704 0.380957 -0.238383 0.443053 0.670490 \n", + "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", + "9980 0.746898 -1.455992 -0.849144 0.592902 \n", + "\n", + " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", + "0 0.130649 1 0 \n", + "12 -0.187375 1 0 \n", + "47 0.691182 1 0 \n", + "381 0.712755 1 0 \n", + "515 -0.014654 1 0 \n", + "... ... ... ... \n", + "9146 -0.820638 1 0 \n", + "9657 -0.468778 1 0 \n", + "9704 0.587128 1 0 \n", + "9879 -0.662602 1 0 \n", + "9980 -0.022726 1 0 \n", + "\n", + " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", + "0 0 0 0 ... \n", + "12 0 0 0 ... \n", + "47 0 0 0 ... \n", + "381 0 0 0 ... \n", + "515 0 0 0 ... \n", + "... ... ... ... ... \n", + "9146 0 0 0 ... \n", + "9657 0 0 0 ... \n", + "9704 0 0 0 ... \n", + "9879 0 0 0 ... \n", + "9980 0 0 0 ... \n", + "\n", + " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", + "0 0 0 1 0 \n", + "12 0 0 1 0 \n", + "47 0 0 1 0 \n", + "381 0 0 1 0 \n", + "515 0 0 1 0 \n", + "... ... ... ... ... \n", + "9146 0 0 1 0 \n", + "9657 0 0 1 0 \n", + "9704 0 0 1 0 \n", + "9879 0 0 1 0 \n", + "9980 0 0 1 0 \n", + "\n", + " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", + "0 0 0 0 1 \n", + "12 0 0 0 1 \n", + "47 0 0 0 1 \n", + "381 0 0 0 1 \n", + "515 0 0 0 1 \n", + "... ... ... ... ... 
\n", + "9146 0 0 0 1 \n", + "9657 0 0 0 1 \n", + "9704 0 0 0 1 \n", + "9879 0 0 0 1 \n", + "9980 0 0 0 1 \n", + "\n", + " hazardous_flag_Y hazardous_flag_nan \n", + "0 0 0 \n", + "12 0 0 \n", + "47 0 0 \n", + "381 0 0 \n", + "515 0 0 \n", + "... ... ... \n", + "9146 0 0 \n", + "9657 0 0 \n", + "9704 0 0 \n", + "9879 0 0 \n", + "9980 0 0 \n", + "\n", + "[9999 rows x 22 columns]" ] }, + "execution_count": 35, "metadata": {}, - "execution_count": 35 + "output_type": "execute_result" } + ], + "source": [ + "for categorical_col in categorical_cols:\n", + " beam_df_categorical = get_one_hot_encoding(df=beam_df, categorical_col=categorical_col)\n", + " beam_df_numericals = beam_df_numericals.merge(beam_df_categorical, left_index = True, right_index = True)\n", + "ib.collect(beam_df_numericals)" ] }, { @@ -2613,28 +2379,24 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "ndaSNond0v8Q", "colab": { "base_uri": "https://localhost:8080/", "height": 651 }, + "id": "ndaSNond0v8Q", "outputId": "b265e915-e649-44e4-a31a-95ac85c0ebf6" }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/content/beam/sdks/python/apache_beam/dataframe/frame_base.py:145: RuntimeWarning: invalid value encountered in double_scalars\n", " lambda left, right: getattr(left, op)(right), name=op, args=[other])\n" ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2643,49 +2405,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_cb06c945824aa1bb68aa31ad7e601b74\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2694,49 +2430,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fb923f80fecb72b4fa55e5cfdba16d23\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2745,49 +2455,23 @@ " Processing... 
collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n 
$(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_3f4b1a0f483cd017e004e11816a91d3b\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", " \n", @@ -2796,127 +2480,23 @@ " Processing... collect\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fce8902eccbfaa17e32ba0c7c242ccec\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " absolute_magnitude eccentricity inclination moid_ld \\\n", - "0 -5.657067 -0.867596 0.426645 0.540537 \n", - "12 -3.583402 -0.756931 1.364340 0.238610 \n", - "47 -3.400432 -0.912290 -0.211925 1.136060 \n", - "381 -2.363599 0.271412 -0.078826 0.535299 \n", - "515 -2.729540 1.469775 0.799915 -0.602881 \n", - "... ... ... ... ... \n", - "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", - "9657 1.478779 0.487849 -0.637779 -0.648240 \n", - "9704 0.380957 -0.238383 0.443053 0.670490 \n", - "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", - "9980 0.746898 -1.455992 -0.849144 0.592902 \n", - "\n", - " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", - "0 0.130649 1 0 \n", - "12 -0.187375 1 0 \n", - "47 0.691182 1 0 \n", - "381 0.712755 1 0 \n", - "515 -0.014654 1 0 \n", - "... ... ... ... 
\n", - "9146 -0.820638 1 0 \n", - "9657 -0.468778 1 0 \n", - "9704 0.587128 1 0 \n", - "9879 -0.662602 1 0 \n", - "9980 -0.022726 1 0 \n", - "\n", - " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", - "0 0 0 0 ... \n", - "12 0 0 0 ... \n", - "47 0 0 0 ... \n", - "381 0 0 0 ... \n", - "515 0 0 0 ... \n", - "... ... ... ... ... \n", - "9146 0 0 0 ... \n", - "9657 0 0 0 ... \n", - "9704 0 0 0 ... \n", - "9879 0 0 0 ... \n", - "9980 0 0 0 ... \n", - "\n", - " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", - "0 0 0 1 0 \n", - "12 0 0 1 0 \n", - "47 0 0 1 0 \n", - "381 0 0 1 0 \n", - "515 0 0 1 0 \n", - "... ... ... ... ... \n", - "9146 0 0 1 0 \n", - "9657 0 0 1 0 \n", - "9704 0 0 1 0 \n", - "9879 0 0 1 0 \n", - "9980 0 0 1 0 \n", - "\n", - " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", - "0 0 0 0 1 \n", - "12 0 0 0 1 \n", - "47 0 0 0 1 \n", - "381 0 0 0 1 \n", - "515 0 0 0 1 \n", - "... ... ... ... ... \n", - "9146 0 0 0 1 \n", - "9657 0 0 0 1 \n", - "9704 0 0 0 1 \n", - "9879 0 0 0 1 \n", - "9980 0 0 0 1 \n", - "\n", - " hazardous_flag_Y hazardous_flag_nan \n", - "0 0 0 \n", - "12 0 0 \n", - "47 0 0 \n", - "381 0 0 \n", - "515 0 0 \n", - "... ... ... \n", - "9146 0 0 \n", - "9657 0 0 \n", - "9704 0 0 \n", - "9879 0 0 \n", - "9980 0 0 \n", - "\n", - "[9999 rows x 22 columns]" - ], "text/html": [ "\n", "
\n", @@ -3307,10 +2887,92 @@ "
\n", " \n", " " + ], + "text/plain": [ + " absolute_magnitude eccentricity inclination moid_ld \\\n", + "0 -5.657067 -0.867596 0.426645 0.540537 \n", + "12 -3.583402 -0.756931 1.364340 0.238610 \n", + "47 -3.400432 -0.912290 -0.211925 1.136060 \n", + "381 -2.363599 0.271412 -0.078826 0.535299 \n", + "515 -2.729540 1.469775 0.799915 -0.602881 \n", + "... ... ... ... ... \n", + "9146 0.563927 -0.508757 -0.327512 -0.637391 \n", + "9657 1.478779 0.487849 -0.637779 -0.648240 \n", + "9704 0.380957 -0.238383 0.443053 0.670490 \n", + "9879 1.295809 -0.442966 -0.698505 -0.494818 \n", + "9980 0.746898 -1.455992 -0.849144 0.592902 \n", + "\n", + " semi_major_axis_au_unit near_earth_object_N near_earth_object_Y \\\n", + "0 0.130649 1 0 \n", + "12 -0.187375 1 0 \n", + "47 0.691182 1 0 \n", + "381 0.712755 1 0 \n", + "515 -0.014654 1 0 \n", + "... ... ... ... \n", + "9146 -0.820638 1 0 \n", + "9657 -0.468778 1 0 \n", + "9704 0.587128 1 0 \n", + "9879 -0.662602 1 0 \n", + "9980 -0.022726 1 0 \n", + "\n", + " near_earth_object_nan object_class_AMO object_class_APO ... \\\n", + "0 0 0 0 ... \n", + "12 0 0 0 ... \n", + "47 0 0 0 ... \n", + "381 0 0 0 ... \n", + "515 0 0 0 ... \n", + "... ... ... ... ... \n", + "9146 0 0 0 ... \n", + "9657 0 0 0 ... \n", + "9704 0 0 0 ... \n", + "9879 0 0 0 ... \n", + "9980 0 0 0 ... \n", + "\n", + " object_class_CEN object_class_IMB object_class_MBA object_class_MCA \\\n", + "0 0 0 1 0 \n", + "12 0 0 1 0 \n", + "47 0 0 1 0 \n", + "381 0 0 1 0 \n", + "515 0 0 1 0 \n", + "... ... ... ... ... \n", + "9146 0 0 1 0 \n", + "9657 0 0 1 0 \n", + "9704 0 0 1 0 \n", + "9879 0 0 1 0 \n", + "9980 0 0 1 0 \n", + "\n", + " object_class_OMB object_class_TJN object_class_nan hazardous_flag_N \\\n", + "0 0 0 0 1 \n", + "12 0 0 0 1 \n", + "47 0 0 0 1 \n", + "381 0 0 0 1 \n", + "515 0 0 0 1 \n", + "... ... ... ... ... 
\n", + "9146 0 0 0 1 \n", + "9657 0 0 0 1 \n", + "9704 0 0 0 1 \n", + "9879 0 0 0 1 \n", + "9980 0 0 0 1 \n", + "\n", + " hazardous_flag_Y hazardous_flag_nan \n", + "0 0 0 \n", + "12 0 0 \n", + "47 0 0 \n", + "381 0 0 \n", + "515 0 0 \n", + "... ... ... \n", + "9146 0 0 \n", + "9657 0 0 \n", + "9704 0 0 \n", + "9879 0 0 \n", + "9980 0 0 \n", + "\n", + "[9999 rows x 22 columns]" ] }, + "execution_count": 36, "metadata": {}, - "execution_count": 36 + "output_type": "execute_result" } ], "source": [ @@ -3356,31 +3018,36 @@ }, { "cell_type": "code", - "source": [ - "PROJECT_ID = \"\"\n", - "REGION = \"us-central1\"\n", - "TEMP_DIR = \"gs:///tmp\"\n", - "OUTPUT_DIR = \"gs:///dataframe-result\"" - ], + "execution_count": null, "metadata": { "id": "dDBYbMEWbL4t" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "REGION = \"us-central1\"\n", + "TEMP_DIR = \"gs:///tmp\" # @param {type:'string'}\n", + "OUTPUT_DIR = \"gs:///dataframe-result\" # @param {type:'string'}" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Qk1GaYoSc9-1" + }, "source": [ "These steps process the full dataset, `full.csv`, which contains approximately one million rows. To materialize the deferred DataFrame, these steps also write the results to a CSV file instead of using `ib.collect()`.\n", "\n", "To switch from an interactive runner to a distributed runner, update the pipeline options. The rest of the pipeline steps don't change." 
- ], - "metadata": { - "id": "Qk1GaYoSc9-1" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1XovR0gKbMlK" + }, + "outputs": [], "source": [ "# Specify the location of the source CSV file (the full dataset).\n", "source_csv_file = 'gs://apache-beam-samples/nasa_jpl_asteroid/full.csv'\n", @@ -3417,44 +3084,42 @@ "\n", "# Write the preprocessed dataset to a CSV file.\n", "beam_df_numericals.to_csv(os.path.join(OUTPUT_DIR, \"preprocessed_data.csv\"))" - ], - "metadata": { - "id": "1XovR0gKbMlK" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Submit and run the pipeline." - ], "metadata": { "id": "a789u4Yecs_g" - } + }, + "source": [ + "Submit and run the pipeline." + ] }, { "cell_type": "code", - "source": [ - "p.run().wait_until_finish()" - ], + "execution_count": null, "metadata": { "id": "pbUlC102bPaZ" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "p.run().wait_until_finish()" + ] }, { "cell_type": "markdown", - "source": [ - "Wait while the pipeline job runs." - ], "metadata": { "id": "dzdqmzKzTOng" - } + }, + "source": [ + "Wait while the pipeline job runs." + ] }, { "cell_type": "markdown", + "metadata": { + "id": "UOLr6YgOOSVQ" + }, "source": [ "## What's next \n", "\n", @@ -3464,13 +3129,13 @@ "[Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/).\n", "\n", "To continue learning, find another dataset to use with the Apache Beam DataFrames API processing. 
Think carefully about which features to include in your model and how to represent them.\n" - ], - "metadata": { - "id": "UOLr6YgOOSVQ" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "nG9WXXVcMCe_" + }, "source": [ "## Resources\n", "\n", @@ -3479,10 +3144,7 @@ "* [10 minutes to Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/10min.html) - A quickstart guide to the Pandas DataFrames.\n", "* [Pandas DataFrame API](https://pandas.pydata.org/pandas-docs/stable/reference/frame.html) - The API reference for the Pandas DataFrames.\n", "* [Data preparation and feature training in ML](https://developers.google.com/machine-learning/data-prep) - A guideline about data transformation for ML training." - ], - "metadata": { - "id": "nG9WXXVcMCe_" - } + ] } ], "metadata": { diff --git a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb index 686c19da7f66..1b20270f327a 100644 --- a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb +++ b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb @@ -367,7 +367,7 @@ "# options.view_as(WorkerOptions).disk_size_gb=200\n", "# options.view_as(GoogleCloudOptions).dataflow_service_options=[\"worker_accelerator=type:nvidia-l4;count:1;install-nvidia-driver\"]\n", "\n", - "topic_reviews=\"\"" + "topic_reviews=\"\" # @param {type:'string'}" ] }, { diff --git a/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb b/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb index f9a263e39030..6f5048e7e8ee 100644 --- a/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb +++ b/examples/notebooks/beam-ml/nlp_tensorflow_streaming.ipynb @@ -496,23 +496,23 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Epoch 1/10\n", "25/25 [==============================] - ETA: 0s - loss: 0.5931 - accuracy: 0.7650" ] }, { - "output_type": "stream", 
"name": "stderr", + "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as _update_step_xla, lstm_cell_7_layer_call_fn, lstm_cell_7_layer_call_and_return_conditional_losses, lstm_cell_8_layer_call_fn, lstm_cell_8_layer_call_and_return_conditional_losses while saving (showing 5 of 9). These functions will not be directly callable after loading.\n" ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r25/25 [==============================] - 60s 2s/step - loss: 0.5931 - accuracy: 0.7650 - val_loss: 0.3625 - val_accuracy: 0.8900\n", "Epoch 2/10\n", @@ -536,14 +536,14 @@ ] }, { - "output_type": "execute_result", "data": { "text/plain": [ "" ] }, + "execution_count": 57, "metadata": {}, - "execution_count": 57 + "output_type": "execute_result" } ], "source": [ @@ -604,14 +604,14 @@ }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "" ] }, + "execution_count": 59, "metadata": {}, - "execution_count": 59 + "output_type": "execute_result" } ], "source": [ @@ -641,8 +641,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as _update_step_xla, lstm_cell_7_layer_call_fn, lstm_cell_7_layer_call_and_return_conditional_losses, lstm_cell_8_layer_call_fn, lstm_cell_8_layer_call_and_return_conditional_losses while saving (showing 5 of 9). 
These functions will not be directly callable after loading.\n" ] @@ -706,8 +706,10 @@ "source": [ "import os\n", "from google.cloud import pubsub_v1\n", - "PROJECT_ID = '' # Add your project ID here\n", - "TOPIC = '' # Add your topic name here\n", + "# Add your project ID here\n", + "PROJECT_ID = '' # @param {type:'string'}\n", + "# Add your topic name here\n", + "TOPIC = '' # @param {type:'string'}\n", "publisher = pubsub_v1.PublisherClient()\n", "topic_name = 'projects/{project_id}/topics/{topic}'.format(\n", " project_id = PROJECT_ID,\n", @@ -739,8 +741,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Can’t wait to watch you guys grow . Harmonies are on point and the oversized early 90’s blazers are a great touch.\n", "Amazing performance! Such an inspiring group ❤\n", @@ -908,8 +910,8 @@ }, "outputs": [], "source": [ - "# path to the topic\n", - "TOPIC_PATH = '' # Add the path to your topic here" + "# Add the path to your topic here\n", + "TOPIC_PATH = '' # @param {type:'string'}" ] }, { @@ -920,18 +922,18 @@ }, "outputs": [], "source": [ - "# path to the subscription\n", - "SUBS_PATH = '' # Add the path to your subscription here" + "# Add the path to your subscription here\n", + "SUBS_PATH = '' # @param {type:'string'}" ] }, { "cell_type": "markdown", - "source": [ - "Importing InteractiveRunner" - ], "metadata": { "id": "UliBhojEfxhq" - } + }, + "source": [ + "Importing InteractiveRunner" + ] }, { "cell_type": "code", @@ -986,8 +988,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Can’t wait to watch you guys grow . Harmonies are on point and the oversized early 90’s blazers are a great touch.\n", "Amazing performance! 
Such an inspiring group ❤\n", @@ -1048,7 +1050,6 @@ }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "[[[0.852806806564331, 0.14719319343566895], 'positive'],\n", @@ -1059,8 +1060,9 @@ " [[0.8648154735565186, 0.13518451154232025], 'positive']]" ] }, + "execution_count": 38, "metadata": {}, - "execution_count": 38 + "output_type": "execute_result" } ], "source": [ diff --git a/examples/notebooks/beam-ml/run_inference_pytorch.ipynb b/examples/notebooks/beam-ml/run_inference_pytorch.ipynb index eaf46be16bbd..93dd12dd20ab 100644 --- a/examples/notebooks/beam-ml/run_inference_pytorch.ipynb +++ b/examples/notebooks/beam-ml/run_inference_pytorch.ipynb @@ -1,22 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", + "execution_count": 3, + "metadata": { + "cellView": "form", + "id": "C1rAsD2L-hSO" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -36,13 +27,7 @@ "# KIND, either express or implied. 
See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "C1rAsD2L-hSO" - }, - "execution_count": 3, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -95,23 +80,23 @@ }, { "cell_type": "code", - "source": [ - "!pip install apache_beam[gcp,dataframe] --quiet" - ], + "execution_count": null, "metadata": { "id": "loxD-rOVchRn" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip install apache_beam[gcp,dataframe] --quiet" + ] }, { "cell_type": "code", "execution_count": 39, "metadata": { - "id": "7f841596-f217-46d2-b64e-1952db4de4cb", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "7f841596-f217-46d2-b64e-1952db4de4cb", "outputId": "09e0026a-cf8e-455c-9580-bfaef44683ce" }, "outputs": [], @@ -151,15 +136,15 @@ }, { "cell_type": "code", - "source": [ - "from google.colab import auth\n", - "auth.authenticate_user()" - ], + "execution_count": 41, "metadata": { "id": "V0E35R5Ka2cE" }, - "execution_count": 41, - "outputs": [] + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user()" + ] }, { "cell_type": "code", @@ -170,8 +155,8 @@ "outputs": [], "source": [ "# Constants\n", - "project = \"\"\n", - "bucket = \"\"\n", + "project = \"\" # @param {type:'string'}\n", + "bucket = \"\" # @param {type:'string'}\n", "\n", "# To avoid warnings, set the project.\n", "os.environ['GOOGLE_CLOUD_PROJECT'] = project\n", @@ -183,8 +168,8 @@ { "cell_type": "markdown", "metadata": { - "tags": [], - "id": "b2b7cedc-79f5-4599-8178-e5da35dba032" + "id": "b2b7cedc-79f5-4599-8178-e5da35dba032", + "tags": [] }, "source": [ "## Create data and PyTorch models for the RunInference transform\n", @@ -294,16 +279,16 @@ "cell_type": "code", "execution_count": 46, "metadata": { - "id": "882bbada-4f6d-4370-a047-c5961e564ee8", "colab": { "base_uri": "https://localhost:8080/" }, + "id": 
"882bbada-4f6d-4370-a047-c5961e564ee8", "outputId": "ab7242a9-76eb-4760-d74e-c725261e2a34" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "True\n" ] @@ -384,16 +369,16 @@ "cell_type": "code", "execution_count": 49, "metadata": { - "id": "42b2ca0f-5d44-4d15-a313-f3d56ae7f675", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "42b2ca0f-5d44-4d15-a313-f3d56ae7f675", "outputId": "9cb2f268-a500-4ad5-a075-856c87b8e3be" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "True\n" ] @@ -430,16 +415,16 @@ "cell_type": "code", "execution_count": 50, "metadata": { - "id": "e488a821-3b70-4284-96f3-ddee4dcb9d71", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "e488a821-3b70-4284-96f3-ddee4dcb9d71", "outputId": "add9af31-1cc6-496f-a6e4-3fb185c0de25" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "PredictionResult(example=tensor([20.]), inference=tensor([102.0095], grad_fn=))\n", "PredictionResult(example=tensor([40.]), inference=tensor([201.2056], grad_fn=))\n", @@ -483,16 +468,16 @@ "cell_type": "code", "execution_count": 51, "metadata": { - "id": "96f38a5a-4db0-4c39-8ce7-80d9f9911b48", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "96f38a5a-4db0-4c39-8ce7-80d9f9911b48", "outputId": "b1d689a2-9336-40b2-a984-538bec888cc9" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "input is 20.0 output is 102.00947570800781\n", "input is 40.0 output is 201.20559692382812\n", @@ -576,7 +561,7 @@ " yield (f\"key: {key}, input: {input_value.item()} output: {output_value.item()}\" )" ] }, - { + { "cell_type": "markdown", "metadata": { "id": "f22da313-5bf8-4334-865b-bbfafc374e63" @@ -592,7 +577,7 @@ "id": "c9b0fb49-d605-4f26-931a-57f42b0ad253" }, "source": [ - "#### Use BigQuery as the source", + "#### Use BigQuery as the source\n", "Follow these steps to 
use BigQuery as your source." ] }, @@ -627,47 +612,47 @@ }, { "cell_type": "code", - "source": [ - "!gcloud config set project $project" - ], + "execution_count": 54, "metadata": { - "id": "7mgnryX-Zlfs", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "7mgnryX-Zlfs", "outputId": "6e608e98-8369-45aa-c983-e62296202c52" }, - "execution_count": 54, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Updated property [core/project].\n" ] } + ], + "source": [ + "!gcloud config set project $project" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { - "id": "a6a984cd-2e92-4c44-821b-9bf1dd52fb7d", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "a6a984cd-2e92-4c44-821b-9bf1dd52fb7d", "outputId": "a50ab0fd-4f4e-4493-b506-41d3f7f08966" }, "outputs": [ { - "output_type": "execute_result", "data": { "text/plain": [ "" ] }, + "execution_count": 55, "metadata": {}, - "execution_count": 55 + "output_type": "execute_result" } ], "source": [ @@ -715,16 +700,16 @@ "cell_type": "code", "execution_count": 56, "metadata": { - "id": "34331897-23f5-4850-8974-67e522e956dc", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "34331897-23f5-4850-8974-67e522e956dc", "outputId": "9d2b0ba5-97a2-46bf-c9d3-e023afbd3122" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: third_question, input: 1000.0 output: 4962.61962890625\n", "key: second_question, input: 108.0 output: 538.472412109375\n", @@ -761,7 +746,7 @@ "id": "53ee7f24-5625-475a-b8cc-9c031591f304" }, "source": [ - "#### Use a CSV file as the source", + "#### Use a CSV file as the source\n", "Follow these steps to use a CSV file as your source." 
] }, @@ -776,6 +761,11 @@ }, { "cell_type": "code", + "execution_count": 62, + "metadata": { + "id": "exAZjP7cYAFv" + }, + "outputs": [], "source": [ "# creates a CSV file with the values.\n", "csv_values = [(\"first_question\", 105.00),\n", @@ -791,27 +781,22 @@ " writer.writerow(row)\n", "\n", "assert os.path.exists(input_csv_file) == True" - ], - "metadata": { - "id": "exAZjP7cYAFv" - }, - "execution_count": 62, - "outputs": [] + ] }, { "cell_type": "code", "execution_count": 66, "metadata": { - "id": "9a054c2d-4d84-4b37-b067-1dda5347e776", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "9a054c2d-4d84-4b37-b067-1dda5347e776", "outputId": "2f2ea8b7-b425-48ae-e857-fe214c7eced2" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: first_question, input: 105.0 output: 523.5929565429688\n", "key: second_question, input: 108.0 output: 538.472412109375\n", @@ -890,16 +875,16 @@ "cell_type": "code", "execution_count": 68, "metadata": { - "id": "629d070e-9902-42c9-a1e7-56c3d1864f13", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "629d070e-9902-42c9-a1e7-56c3d1864f13", "outputId": "0b4d7f3c-4696-422f-b031-ee5a03e90e03" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: third_question * 10, input: 1000.0 output: 9889.59765625\n", "key: second_question * 10, input: 108.0 output: 1075.4891357421875\n", @@ -966,16 +951,16 @@ "cell_type": "code", "execution_count": 69, "metadata": { - "id": "8db9d649-5549-4b58-a9ad-7b8592c2bcbf", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "8db9d649-5549-4b58-a9ad-7b8592c2bcbf", "outputId": "328ba32b-40d4-445b-8b4e-5568258b8a26" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "key: original input is `third_question tensor([1000.])`, input: 4962.61962890625 output: 49045.37890625\n", "key: original input is `second_question tensor([108.])`, 
input: 538.472412109375 output: 5329.11083984375\n", @@ -1015,5 +1000,20 @@ " inference_result | beam.Map(print)" ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/run_inference_sklearn.ipynb b/examples/notebooks/beam-ml/run_inference_sklearn.ipynb index cf896a18981a..1b76f76df292 100644 --- a/examples/notebooks/beam-ml/run_inference_sklearn.ipynb +++ b/examples/notebooks/beam-ml/run_inference_sklearn.ipynb @@ -1,22 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "C1rAsD2L-hSO" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -36,13 +27,7 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "C1rAsD2L-hSO" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -87,24 +72,20 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "zzwnMzzgdyPB" + }, "source": [ "## Before you begin\n", "Complete the following setup steps:\n", "1. Install dependencies for Apache Beam.\n", "1. Authenticate with Google Cloud.\n", "1. Specify your project and bucket. You use the project and bucket to save and load models." 
- ], - "metadata": { - "id": "zzwnMzzgdyPB" - } + ] }, { "cell_type": "code", - "source": [ - "!pip install google-api-core --quiet\n", - "!pip install google-cloud-pubsub google-cloud-bigquery-storage --quiet\n", - "!pip install apache-beam[gcp,dataframe] --quiet" - ], + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -112,8 +93,12 @@ "id": "6vlKcT-Wev20", "outputId": "336e8afc-6716-41dd-a438-500353189c62" }, - "execution_count": 1, - "outputs": [] + "outputs": [], + "source": [ + "!pip install google-api-core --quiet\n", + "!pip install google-cloud-pubsub google-cloud-bigquery-storage --quiet\n", + "!pip install apache-beam[gcp,dataframe] --quiet" + ] }, { "cell_type": "markdown", @@ -128,15 +113,15 @@ }, { "cell_type": "code", - "source": [ - "from google.colab import auth\n", - "auth.authenticate_user()" - ], + "execution_count": 2, "metadata": { "id": "V0E35R5Ka2cE" }, - "execution_count": 2, - "outputs": [] + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user()" + ] }, { "cell_type": "code", @@ -174,8 +159,8 @@ "import os\n", "\n", "# Constants\n", - "project = \"\"\n", - "bucket = \"\" \n", + "project = \"\" # @param {type:'string'}\n", + "bucket = \"\" # @param {type:'string'}\n", "\n", "# To avoid warnings, set the project.\n", "os.environ['GOOGLE_CLOUD_PROJECT'] = project\n" @@ -240,20 +225,18 @@ }, { "cell_type": "code", - "source": [ - "%pip install --upgrade google-cloud-bigquery --quiet" - ], + "execution_count": 9, "metadata": { "id": "AEGaqpMVqgRP" }, - "execution_count": 9, - "outputs": [] + "outputs": [], + "source": [ + "%pip install --upgrade google-cloud-bigquery --quiet" + ] }, { "cell_type": "code", - "source": [ - "!gcloud config set project $project" - ], + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -261,19 +244,41 @@ "id": "xq5AKtRrqlUx", "outputId": "fba8fb42-4958-451a-8aaa-9a838052a2f8" }, - "execution_count": 
10, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Updated property [core/project].\n" ] } + ], + "source": [ + "!gcloud config set project $project" ] }, { "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QCIjN__rpoVF", + "outputId": "0ded224f-2272-482e-80f5-bb2d21b6f5d8" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Populated BigQuery table\n", "\n", @@ -306,42 +311,22 @@ "\n", "create_job = client.query(query)\n", "create_job.result()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "QCIjN__rpoVF", - "outputId": "0ded224f-2272-482e-80f5-bb2d21b6f5d8" - }, - "execution_count": 22, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 22 - } ] }, { "cell_type": "code", "execution_count": 23, "metadata": { - "id": "50a648a3-794a-4286-ab2b-fc0458db04ca", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "50a648a3-794a-4286-ab2b-fc0458db04ca", "outputId": "8eab34b4-dcc7-4df1-ec0e-8c86a34d31c6" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "PredictionResult(example=[1000.0], inference=array([5000.]))\n", "PredictionResult(example=[1013.0], inference=array([5065.]))\n", @@ -388,16 +373,16 @@ "cell_type": "code", "execution_count": 25, "metadata": { - "id": "c212916d-b517-4589-ad15-a3a1df926fb3", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "c212916d-b517-4589-ad15-a3a1df926fb3", "outputId": "61db2d76-4dfa-4b38-cf9a-645790b4c5aa" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "('third_example', PredictionResult(example=[1000.0], inference=array([5000.])))\n", 
"('fourth_example', PredictionResult(example=[1013.0], inference=array([5065.])))\n", @@ -424,17 +409,41 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "JQ4zvlwsRK1W" + }, "source": [ "## Run multiple models\n", "\n", "This code creates a pipeline that takes two RunInference transforms with different models and then combines the output." - ], - "metadata": { - "id": "JQ4zvlwsRK1W" - } + ] }, { "cell_type": "code", + "execution_count": 86, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0qMlX6SeR68D", + "outputId": "5e4a0852-3761-47da-aa08-0386fd524a78" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "key = third_example * 10, example = 1000.0 -> predictions 10000.0\n", + "key = fourth_example * 10, example = 1013.0 -> predictions 10130.0\n", + "key = second_example * 10, example = 108.0 -> predictions 1080.0\n", + "key = first_example * 10, example = 105.0 -> predictions 1050.0\n", + "key = third_example * 5, example = 1000.0 -> predictions 5000.0\n", + "key = fourth_example * 5, example = 1013.0 -> predictions 5065.0\n", + "key = second_example * 5, example = 108.0 -> predictions 540.0\n", + "key = first_example * 5, example = 105.0 -> predictions 525.0\n" + ] + } + ], "source": [ "from typing import Tuple\n", "\n", @@ -464,31 +473,22 @@ " _ = ((five_times, ten_times) | \"Flattened\" >> beam.Flatten()\n", " | \"format output\" >> beam.Map(format_output)\n", " | \"Print\" >> beam.Map(print))\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0qMlX6SeR68D", - "outputId": "5e4a0852-3761-47da-aa08-0386fd524a78" - }, - "execution_count": 86, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "key = third_example * 10, example = 1000.0 -> predictions 10000.0\n", - "key = fourth_example * 10, example = 1013.0 -> predictions 10130.0\n", - "key = second_example * 10, example = 108.0 -> predictions 1080.0\n", - "key = 
first_example * 10, example = 105.0 -> predictions 1050.0\n", - "key = third_example * 5, example = 1000.0 -> predictions 5000.0\n", - "key = fourth_example * 5, example = 1013.0 -> predictions 5065.0\n", - "key = second_example * 5, example = 108.0 -> predictions 540.0\n", - "key = first_example * 5, example = 105.0 -> predictions 525.0\n" - ] - } ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb b/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb index ad5bb671cce2..c15e9b21ecf9 100644 --- a/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb +++ b/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb @@ -168,8 +168,8 @@ "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n", "from apache_beam.options.pipeline_options import PipelineOptions\n", "\n", - "project = \"PROJECT_ID\"\n", - "bucket = \"BUCKET_NAME\"\n", + "project = \"PROJECT_ID\" # @param {type:'string'}\n", + "bucket = \"BUCKET_NAME\" # @param {type:'string'}\n", "\n", "save_model_dir_multiply = f'gs://{bucket}/tf-inference/model/multiply_five/v1/'\n", "save_weights_dir_multiply = f'gs://{bucket}/tf-inference/weights/multiply_five/v1/'\n" diff --git a/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb b/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb index 9a9c6f5d6e92..2c2f6460651b 100644 --- a/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb +++ b/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb @@ -1,32 +1,13 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [ - "X80jy3FqHjK4", - "40qtP6zJuMXm", - "YzvZWEv-1oiK", - "rIwD_qEpX7Gu", - 
"O_a0-4Gb19cy", - "G-sAu3cf31f3", - "r4dpR6dQ4JwX", - "P2UMmbNW4YQV" - ] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "fFjof1NgAJwu" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -46,13 +27,7 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "fFjof1NgAJwu" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -74,6 +49,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "HrCtxslBGK8Z" + }, "source": [ "This notebook demonstrates how to use the Apache Beam [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform with TensorFlow and [TFX Basic Shared Libraries](https://github.com/tensorflow/tfx-bsl) (`tfx-bsl`).\n", "\n", @@ -89,69 +67,69 @@ "- Use the `tfx-bsl` model handler with the example data, and get a prediction inside an Apache Beam pipeline.\n", "\n", "For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation." - ], - "metadata": { - "id": "HrCtxslBGK8Z" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "HrCtxslBGK8A" + }, "source": [ "## Before you begin\n", "Set up your environment and download dependencies." 
- ], - "metadata": { - "id": "HrCtxslBGK8A" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "HrCtxslBGK8A" + }, "source": [ "### Import `tfx-bsl`\n", "First, import `tfx-bsl`.\n", "Creating a model handler is supported in `tfx-bsl` versions 1.10 and later." - ], - "metadata": { - "id": "HrCtxslBGK8A" - } + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "jBakpNZnAhqk" }, + "outputs": [], "source": [ "!pip install tfx_bsl==1.10.0 --quiet\n", "!pip install protobuf --quiet\n", "!pip install apache_beam --quiet" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "X80jy3FqHjK4" + }, "source": [ "### Authenticate with Google Cloud\n", "This notebook relies on saving your model to Google Cloud. To use your Google Cloud account, authenticate this notebook." - ], - "metadata": { - "id": "X80jy3FqHjK4" - } + ] }, { "cell_type": "code", + "execution_count": 2, "metadata": { "id": "Kz9sccyGBqz3" }, + "outputs": [], "source": [ "from google.colab import auth\n", "auth.authenticate_user()" - ], - "execution_count": 2, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "40qtP6zJuMXm" + }, "source": [ "### Import dependencies and set up your bucket\n", "Use the following code to import dependencies and to set up your Google Cloud Storage bucket.\n", @@ -159,16 +137,15 @@ "Replace `PROJECT_ID` and `BUCKET_NAME` with the ID of your project and the name of your bucket.\n", "\n", "**Important**: If an error occurs, restart your runtime." 
- ], - "metadata": { - "id": "40qtP6zJuMXm" - } + ] }, { "cell_type": "code", + "execution_count": 12, "metadata": { "id": "eEle839_Akqx" }, + "outputs": [], "source": [ "import argparse\n", "\n", @@ -190,24 +167,22 @@ "\n", "from apache_beam.options.pipeline_options import PipelineOptions\n", "\n", - "project = \"PROJECT_ID\"\n", - "bucket = \"BUCKET_NAME\"\n", + "project = \"PROJECT_ID\" # @param {type:'string'}\n", + "bucket = \"BUCKET_NAME\" # @param {type:'string'}\n", "\n", "save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n" - ], - "execution_count": 12, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "YzvZWEv-1oiK" + }, "source": [ "## Create and test a simple model\n", "\n", "This section creates and tests a model that predicts the 5 times multiplication table." - ], - "metadata": { - "id": "YzvZWEv-1oiK" - } + ] }, { "cell_type": "markdown", @@ -221,6 +196,7 @@ }, { "cell_type": "code", + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -228,26 +204,10 @@ "id": "SH7iq3zeBBJ-", "outputId": "c5adb7ec-285b-401e-f9be-1e9b83c6d0ba" }, - "source": [ - "# Create training data that represents the 5 times multiplication table for the numbers 0 to 99.\n", - "# x is the data and y is the labels.\n", - "x = numpy.arange(0, 100) # Examples\n", - "y = x * 5 # Labels\n", - "\n", - "# Build a simple linear regression model.\n", - "# Note that the model has a shape of (1) for its input layer and expects a single int64 value.\n", - "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n", - "output_layer= keras.layers.Dense(1)(input_layer)\n", - "\n", - "model = keras.Model(input_layer, output_layer)\n", - "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n", - "model.summary()" - ], - "execution_count": 4, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Model: \"model\"\n", 
"_________________________________________________________________\n", @@ -264,21 +224,37 @@ "_________________________________________________________________\n" ] } + ], + "source": [ + "# Create training data that represents the 5 times multiplication table for the numbers 0 to 99.\n", + "# x is the data and y is the labels.\n", + "x = numpy.arange(0, 100) # Examples\n", + "y = x * 5 # Labels\n", + "\n", + "# Build a simple linear regression model.\n", + "# Note that the model has a shape of (1) for its input layer and expects a single int64 value.\n", + "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n", + "output_layer= keras.layers.Dense(1)(input_layer)\n", + "\n", + "model = keras.Model(input_layer, output_layer)\n", + "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n", + "model.summary()" ] }, { "cell_type": "markdown", + "metadata": { + "id": "O_a0-4Gb19cy" + }, "source": [ "### Test the model\n", "\n", "This step tests the model that you created." 
- ], - "metadata": { - "id": "O_a0-4Gb19cy" - } + ] }, { "cell_type": "code", + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -286,20 +262,10 @@ "id": "5XkIYXhJBFmS", "outputId": "e3bb5079-5cb8-4fe4-eb8d-d3d13d5f9f0c" }, - "source": [ - "model.fit(x, y, epochs=500, verbose=0)\n", - "test_examples =[20, 40, 60, 90]\n", - "value_to_predict = numpy.array(test_examples, dtype=numpy.float32)\n", - "predictions = model.predict(value_to_predict)\n", - "\n", - "print('Test Examples ' + str(test_examples))\n", - "print('Predictions ' + str(predictions))" - ], - "execution_count": 6, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "1/1 [==============================] - 0s 94ms/step\n", "Test Examples [20, 40, 60, 90]\n", @@ -309,10 +275,22 @@ " [34.41496 ]]\n" ] } + ], + "source": [ + "model.fit(x, y, epochs=500, verbose=0)\n", + "test_examples =[20, 40, 60, 90]\n", + "value_to_predict = numpy.array(test_examples, dtype=numpy.float32)\n", + "predictions = model.predict(value_to_predict)\n", + "\n", + "print('Test Examples ' + str(test_examples))\n", + "print('Predictions ' + str(predictions))" ] }, { "cell_type": "markdown", + "metadata": { + "id": "dEmleqiH3t71" + }, "source": [ "## RunInference with Tensorflow using `tfx-bsl`\n", "In versions 1.10.0 and later of `tfx-bsl`, you can\n", @@ -321,16 +299,15 @@ "### Populate the data in a TensorFlow proto\n", "\n", "Tensorflow data uses protos. If you are loading from a file, helpers exist for this step. Because this example uses generated data, this code populates a proto." 
- ], - "metadata": { - "id": "dEmleqiH3t71" - } + ] }, { "cell_type": "code", + "execution_count": 7, "metadata": { "id": "XvKc9kQilPjx" }, + "outputs": [], "source": [ "# This example shows a proto that converts the samples and labels into\n", "# tensors usable by TensorFlow.\n", @@ -371,23 +348,22 @@ " for i in value_to_predict:\n", " example = ExampleProcessor().create_example(feature=i)\n", " writer.write(example.SerializeToString())" - ], - "execution_count": 7, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "G-sAu3cf31f3" + }, "source": [ "### Fit the model\n", "\n", "This step builds a model. Because RunInference requires pretrained models, this segment builds a usable model." - ], - "metadata": { - "id": "G-sAu3cf31f3" - } + ] }, { "cell_type": "code", + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -395,6 +371,18 @@ "id": "AnbrxXPKeAOQ", "outputId": "42439aac-3a10-4e86-829f-44332aad6173" }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "RAW_DATA_TRAIN_SPEC = {\n", "'x': tf.io.FixedLenFeature([], tf.float32),\n", @@ -408,37 +396,26 @@ "dataset = dataset.repeat()\n", "\n", "model.fit(dataset, epochs=5000, steps_per_epoch=1, verbose=0)" - ], - "execution_count": 8, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 8 - } ] }, { "cell_type": "markdown", + "metadata": { + "id": "r4dpR6dQ4JwX" + }, "source": [ "### Save the model\n", "\n", "This step shows how to save your model." 
- ], - "metadata": { - "id": "r4dpR6dQ4JwX" - } + ] }, { "cell_type": "code", + "execution_count": 9, "metadata": { "id": "fYvrIYO3qiJx" }, + "outputs": [], "source": [ "RAW_DATA_PREDICT_SPEC = {\n", "'x': tf.io.FixedLenFeature([], tf.float32),\n", @@ -461,25 +438,24 @@ "# programs that consume SavedModels, such as serving APIs.\n", "# See https://www.tensorflow.org/api_docs/python/tf/saved_model/save\n", "tf.keras.models.save_model(model, save_model_dir_multiply, signatures=signature)" - ], - "execution_count": 9, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "P2UMmbNW4YQV" + }, "source": [ "## Run the pipeline\n", "Use the following code to run the pipeline.\n", "\n", "* `FormatOutput` demonstrates how to extract values from the output protos.\n", "* `CreateModelHandler` demonstrates the model handler that needs to be passed into the Apache Beam RunInference API." - ], - "metadata": { - "id": "P2UMmbNW4YQV" - } + ] }, { "cell_type": "code", + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -488,72 +464,24 @@ "id": "PzjmXM_KvqHY", "outputId": "0aa60bef-52a0-4ce2-d228-3fac977d59e0" }, - "source": [ - "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n", - "\n", - "class FormatOutput(beam.DoFn):\n", - " def process(self, element: prediction_log_pb2.PredictionLog):\n", - " predict_log = element.predict_log\n", - " input_value = tf.train.Example.FromString(predict_log.request.inputs['examples'].string_val[0])\n", - " input_float_value = input_value.features.feature['x'].float_list.value[0]\n", - " output_value = predict_log.response.outputs\n", - " output_float_value = output_value['output_0'].float_val[0]\n", - " yield (f\"example is {input_float_value:.2f} prediction is {output_float_value:.2f}\")\n", - "\n", - "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n", - "saved_model_spec = 
model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n", - "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n", - "model_handler = CreateModelHandler(inference_spec_type)\n", - "with beam.Pipeline() as p:\n", - " _ = (p | tfexample_beam_record.RawRecordBeamSource()\n", - " | RunInference(model_handler)\n", - " | beam.ParDo(FormatOutput())\n", - " | beam.Map(print)\n", - " )" - ], - "execution_count": 10, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n" ] }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = 
document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "WARNING:tensorflow:From /usr/local/lib/python3.9/dist-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", @@ -562,8 +490,8 @@ ] }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "example is 20.00 prediction is 104.36\n", "example is 40.00 prediction is 202.62\n", @@ -571,10 +499,36 @@ "example is 90.00 prediction is 448.26\n" ] } + ], + "source": [ + "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n", + "\n", + "class FormatOutput(beam.DoFn):\n", + " def process(self, element: prediction_log_pb2.PredictionLog):\n", + " predict_log = element.predict_log\n", + " input_value = tf.train.Example.FromString(predict_log.request.inputs['examples'].string_val[0])\n", + " input_float_value = input_value.features.feature['x'].float_list.value[0]\n", + " output_value = predict_log.response.outputs\n", + " output_float_value = output_value['output_0'].float_val[0]\n", + " yield (f\"example is {input_float_value:.2f} 
prediction is {output_float_value:.2f}\")\n", + "\n", + "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n", + "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n", + "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n", + "model_handler = CreateModelHandler(inference_spec_type)\n", + "with beam.Pipeline() as p:\n", + " _ = (p | tfexample_beam_record.RawRecordBeamSource()\n", + " | RunInference(model_handler)\n", + " | beam.ParDo(FormatOutput())\n", + " | beam.Map(print)\n", + " )" ] }, { "cell_type": "markdown", + "metadata": { + "id": "IXikjkGdHm9n" + }, "source": [ "## Use `KeyedModelHandler` with `tfx-bsl`\n", "\n", @@ -584,13 +538,30 @@ "* If you don't know whether keys are associated with your examples, use `beam.MaybeKeyedModelHandler`.\n", "\n", "In addition to demonstrating how to use a keyed model handler, this step demonstrates how to use `tfx-bsl` examples." 
- ], - "metadata": { - "id": "IXikjkGdHm9n" - } + ] }, { "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KPtE3fmdJQry", + "outputId": "c33558fc-fb12-4c20-b828-b5520721f279" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "key 5.0 : example is 5.00 prediction is 30.67\n", + "key 50.0 : example is 50.00 prediction is 251.75\n", + "key 40.0 : example is 40.00 prediction is 202.62\n", + "key 100.0 : example is 100.00 prediction is 497.38\n" + ] + } + ], "source": [ "from apache_beam.ml.inference.base import KeyedModelHandler\n", "from google.protobuf import text_format\n", @@ -632,27 +603,32 @@ " | beam.ParDo(FormatOutputKeyed())\n", " | beam.Map(print)\n", " )" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KPtE3fmdJQry", - "outputId": "c33558fc-fb12-4c20-b828-b5520721f279" - }, - "execution_count": 11, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "key 5.0 : example is 5.00 prediction is 30.67\n", - "key 50.0 : example is 50.00 prediction is 251.75\n", - "key 40.0 : example is 40.00 prediction is 202.62\n", - "key 100.0 : example is 100.00 prediction is 497.38\n" - ] - } ] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "X80jy3FqHjK4", + "40qtP6zJuMXm", + "YzvZWEv-1oiK", + "rIwD_qEpX7Gu", + "O_a0-4Gb19cy", + "G-sAu3cf31f3", + "r4dpR6dQ4JwX", + "P2UMmbNW4YQV" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb b/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb index 46bfc0f2fc00..2ab45e0491a7 100644 --- a/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb +++ 
b/examples/notebooks/beam-ml/run_inference_vertex_ai.ipynb @@ -151,6 +151,17 @@ "Replace `PROJECT_ID`, `LOCATION_NAME`, and `ENDPOINT_ID` with the ID of your project, the GCP region where your model is deployed, and the ID of your Vertex AI endpoint." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "LOCATION_NAME = \"\" # @param {type:'string'}\n", + "ENDPOINT_ID = \"> beam.Create([IMG_URL])\n", " | beam.Map(lambda img_name: (img_name, download_image(img_name)))\n", diff --git a/examples/notebooks/beam-ml/run_inference_vllm.ipynb b/examples/notebooks/beam-ml/run_inference_vllm.ipynb index 13b4a915c087..40eff1af5155 100644 --- a/examples/notebooks/beam-ml/run_inference_vllm.ipynb +++ b/examples/notebooks/beam-ml/run_inference_vllm.ipynb @@ -1,24 +1,12 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "gpuType": "T4", - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OsFaZscKSPvo" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -38,15 +26,13 @@ "# KIND, either express or implied. 
See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "id": "OsFaZscKSPvo" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "NrHRIznKp3nS" + }, "source": [ "# Run ML inference by using vLLM on GPUs\n", "\n", @@ -58,13 +44,13 @@ " View source on GitHub\n", " \n", "" - ], - "metadata": { - "id": "NrHRIznKp3nS" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "H0ZFs9rDvtJm" + }, "source": [ "[vLLM](https://github.com/vllm-project/vllm) is a fast and user-friendly library for LLM inference and serving. vLLM optimizes LLM inference with mechanisms like PagedAttention for memory management and continuous batching for increasing throughput. For popular models, vLLM has been shown to increase throughput by a multiple of 2 to 4. With Apache Beam, you can serve models with vLLM and scale that serving with just a few lines of code.\n", "\n", @@ -75,13 +61,13 @@ "* remotely with the Dataflow runner\n", "\n", "It also shows how to swap in a different model without modifying your pipeline structure by changing the configuration." - ], - "metadata": { - "id": "H0ZFs9rDvtJm" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "6x41tnbTvQM1" + }, "source": [ "## Requirements\n", "\n", @@ -91,21 +77,18 @@ "\n", "- a computer with Docker installed\n", "- a [Google Cloud](https://cloud.google.com/) account" - ], - "metadata": { - "id": "6x41tnbTvQM1" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "8PSjyDIavRcn" + }, "source": [ "## Install dependencies\n", "\n", "Before creating your pipeline, download and install the dependencies required to develop with Apache Beam and vLLM. vLLM is supported in Apache Beam versions 2.60.0 and later." 
- ], - "metadata": { - "id": "8PSjyDIavRcn" - } + ] }, { "cell_type": "code", @@ -117,36 +100,65 @@ "source": [ "!pip install openai>=1.52.2\n", "!pip install vllm>=0.6.3\n", - "!pip install apache-beam[gcp]==2.60.0\n", + "!pip install triton>=3.1.0\n", + "!pip install apache-beam[gcp]==2.61.0\n", + "!pip install nest_asyncio # only needed in colab\n", "!pip check" ] }, { "cell_type": "markdown", "source": [ - "## Run locally without Apache Beam\n", - "\n", - "In this section, you run a vLLM server without using Apache Beam. Use the `facebook/opt-125m` model. This model is small enough to fit in Colab memory and doesn't require any extra authentication.\n", + "## Colab only: allow nested asyncio\n", "\n", - "First, start the vLLM server. This step might take a minute or two, because the model needs to download before vLLM starts running inference." + "The vLLM model handler logic below uses asyncio to feed vLLM records. This only works if we are not already in an asyncio event loop. Most of the time, this is fine, but colab already operates in an event loop. To work around this, we can use nest_asyncio to make things work smoothly in colab. Do not include this step outside of colab." ], "metadata": { - "id": "3xz8zuA7vcS4" + "id": "3xz8zuA7vcS3" } }, { "cell_type": "code", "source": [ - "! python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m" + "# This should not be necessary outside of colab.\n", + "import nest_asyncio\n", + "nest_asyncio.apply()\n" ], "metadata": { - "id": "GbJGzINNt5sG" + "id": "sUqjOzw3wpI3" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", + "metadata": { + "id": "sUqjOzw3wpI4" + }, + "source": [ + "## Run locally without Apache Beam\n", + "\n", + "In this section, you run a vLLM server without using Apache Beam. Use the `facebook/opt-125m` model. This model is small enough to fit in Colab memory and doesn't require any extra authentication.\n", + "\n", + "First, start the vLLM server. 
This step might take a minute or two, because the model needs to download before vLLM starts running inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GbJGzINNt5sG" + }, + "outputs": [], + "source": [ + "! python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n35LXTS3uzIC" + }, "source": [ "Next, while the vLLM server is running, open a separate terminal to communicate with the vLLM serving process. To open a terminal in Colab, in the sidebar, click **Terminal**. In the terminal, run the following commands.\n", "\n", @@ -169,26 +181,28 @@ "```\n", "\n", "This code runs against the server running in the cell. You can experiment with different prompts." - ], - "metadata": { - "id": "n35LXTS3uzIC" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Hbxi83BfwbBa" + }, "source": [ "## Run locally with Apache Beam\n", "\n", "In this section, you set up an Apache Beam pipeline to run a job with an embedded vLLM instance.\n", "\n", "First, define the `VllmCompletionsModelHandler` object. This configuration object gives Apache Beam the information that it needs to create a dedicated vLLM process in the middle of the pipeline. Apache Beam then provides examples to the pipeline. No additional code is needed." 
- ], - "metadata": { - "id": "Hbxi83BfwbBa" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sUqjOzw3wpI4" + }, + "outputs": [], "source": [ "from apache_beam.ml.inference.base import RunInference\n", "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n", @@ -196,24 +210,24 @@ "import apache_beam as beam\n", "\n", "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')" - ], - "metadata": { - "id": "sUqjOzw3wpI4" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Next, define examples to run inference against, and define a helper function to print out the inference results." - ], "metadata": { "id": "N06lXRKRxCz5" - } + }, + "source": [ + "Next, define examples to run inference against, and define a helper function to print out the inference results." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3a1PznmtxNR_" + }, + "outputs": [], "source": [ "class FormatOutput(beam.DoFn):\n", " def process(self, element, *args, **kwargs):\n", @@ -226,26 +240,26 @@ " \"The future of AI is\",\n", " \"Emperor penguins are\",\n", "]" - ], - "metadata": { - "id": "3a1PznmtxNR_" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Njl0QfrLxQ0m" + }, "source": [ "Finally, run the pipeline.\n", "\n", "This step might take a minute or two, because the model needs to download before Apache Beam can start running inference." 
- ], - "metadata": { - "id": "Njl0QfrLxQ0m" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9yXbzV0ZmZcJ" + }, + "outputs": [], "source": [ "with beam.Pipeline() as p:\n", " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", @@ -253,26 +267,24 @@ " | beam.ParDo(FormatOutput()) # Format the output.\n", " | beam.Map(print) # Print the formatted output.\n", " )" - ], - "metadata": { - "id": "9yXbzV0ZmZcJ" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Jv7be6Pk9Hlx" + }, "source": [ "## Run remotely on Dataflow\n", "\n", "After you validate that the pipeline can run against a vLLM locally, you can productionalize the workflow on a remote runner. This notebook runs the pipeline on the Dataflow runner." - ], - "metadata": { - "id": "Jv7be6Pk9Hlx" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "J1LMrl1Yy6QB" + }, "source": [ "### Build a Docker image\n", "\n", @@ -284,24 +296,26 @@ "\n", "- The Python version in the following cell matches the Python version defined in the Dockerfile.\n", "- The Apache Beam version defined in your dependencies matches the Apache Beam version defined in the Dockerfile." 
- ], - "metadata": { - "id": "J1LMrl1Yy6QB" - } + ] }, { "cell_type": "code", - "source": [ - "!python --version" - ], + "execution_count": null, "metadata": { "id": "jCQ6-D55gqfl" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!python --version" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7QyNq_gygHLO" + }, + "outputs": [], "source": [ "cell_str='''\n", "FROM nvidia/cuda:12.4.1-devel-ubuntu22.04\n", @@ -327,7 +341,7 @@ "COPY --from=apache/beam_python3.10_sdk:2.60.0 /opt/apache/beam /opt/apache/beam\n", "\n", "RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.60.0\n", - "RUN pip install openai>=1.52.2 vllm>=0.6.3\n", + "RUN pip install openai>=1.52.2 vllm>=0.6.3 triton>=3.1.0\n", "\n", "RUN apt install libcairo2-dev pkg-config python3-dev -y\n", "RUN pip install pycairo\n", @@ -338,15 +352,13 @@ "\n", "with open('VllmDockerfile', 'w') as f:\n", " f.write(cell_str)" - ], - "metadata": { - "id": "7QyNq_gygHLO" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "zWma0YetiEn5" + }, "source": [ "After you save the Dockerfile, build and push your Docker image. Because Docker is not accessible from Colab, you need to complete this step in a separate environment.\n", "\n", @@ -358,13 +370,13 @@ " docker build -t \":\" -f VllmDockerfile ./\n", " docker image push \":\"\n", " ```" - ], - "metadata": { - "id": "zWma0YetiEn5" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "NjZyRjte0g0Q" + }, "source": [ "### Define and run the pipeline\n", "\n", @@ -378,13 +390,15 @@ "- ``: the name of the Google Cloud project that you created your bucket and Artifact Registry repository in.\n", "\n", "This workflow uses the following Dataflow service option: `worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx`. 
When you use this service option, Dataflow to installs a T4 GPU that uses a `5xx` series Nvidia driver on each worker machine. The 5xx driver is required to run vLLM jobs." - ], - "metadata": { - "id": "NjZyRjte0g0Q" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kXy9FRYVCSjq" + }, + "outputs": [], "source": [ "\n", "from apache_beam.options.pipeline_options import GoogleCloudOptions\n", @@ -396,9 +410,12 @@ "\n", "options = PipelineOptions()\n", "\n", - "BUCKET_NAME = '' # Replace with your bucket name.\n", - "CONTAINER_IMAGE = ':' # Replace with the image repository and tag from the previous step.\n", - "PROJECT_NAME = '' # Replace with your GCP project\n", + "# Replace with your bucket name.\n", + "BUCKET_NAME = '' # @param {type:'string'}\n", + "# Replace with the image repository and tag from the previous step.\n", + "CONTAINER_IMAGE = ':' # @param {type:'string'}\n", + "# Replace with your GCP project\n", + "PROJECT_NAME = '' # @param {type:'string'}\n", "\n", "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n", "\n", @@ -430,50 +447,50 @@ "options.view_as(WorkerOptions).machine_type = \"n1-standard-4\"\n", "\n", "options.view_as(WorkerOptions).sdk_container_image = CONTAINER_IMAGE" - ], - "metadata": { - "id": "kXy9FRYVCSjq" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Next, authenticate Colab so that it can to submit a job on your behalf." - ], "metadata": { "id": "xPhe597P1-QJ" - } + }, + "source": [ + "Next, authenticate Colab so that it can submit a job on your behalf."
+ ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Xkf6yIVlFB8-" + }, + "outputs": [], "source": [ "def auth_to_colab():\n", " from google.colab import auth\n", " auth.authenticate_user()\n", "\n", "auth_to_colab()" - ], - "metadata": { - "id": "Xkf6yIVlFB8-" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "MJtEI6Ux2eza" + }, "source": [ "Finally, run the pipeline on Dataflow. The pipeline definition is almost exactly the same as the definition used for local execution. The pipeline options are the only change to the pipeline.\n", "\n", "The following code creates a Dataflow job in your project. You can view the results in Colab or in the Google Cloud console. Creating a Dataflow job and downloading the model might take a few minutes. After the job starts performing inference, it quickly runs through the inputs." - ], - "metadata": { - "id": "MJtEI6Ux2eza" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8gjDdru_9Dii" + }, + "outputs": [], "source": [ "import logging\n", "from apache_beam.ml.inference.base import RunInference\n", @@ -503,15 +520,13 @@ " | beam.ParDo(FormatOutput()) # Format the output.\n", " | beam.Map(logging.info) # Print the formatted output.\n", " )" - ], - "metadata": { - "id": "8gjDdru_9Dii" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "22cEHPCc28fH" + }, "source": [ "## Run vLLM with a Gemma model\n", "\n", @@ -526,55 +541,57 @@ "When you complete these steps, the following message appears on the model card page: `You have been granted access to this model`.\n", "\n", "Next, sign in to your account from this notebook by running the following code and then following the prompts." - ], - "metadata": { - "id": "22cEHPCc28fH" - } + ] }, { "cell_type": "code", - "source": [ - "! 
huggingface-cli login" - ], + "execution_count": null, "metadata": { "id": "JHwIsFI9kd9j" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "! huggingface-cli login" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "IjX2If8rnCol" + }, "source": [ "Verify that the notebook can now access the Gemma model. Run the following code, which starts a vLLM server to serve the Gemma 2b model. Because the default T4 Colab runtime doesn't support the full data type precision needed to run Gemma models, the `--dtype=half` parameter is required.\n", "\n", "When successful, the following cell runs indefinitely. After it starts the server process, you can shut it down. When the server process starts, the Gemma 2b model is successfully downloaded, and the server is ready to serve traffic." - ], - "metadata": { - "id": "IjX2If8rnCol" - } + ] }, { "cell_type": "code", - "source": [ - "! python -m vllm.entrypoints.openai.api_server --model google/gemma-2b --dtype=half" - ], + "execution_count": null, "metadata": { "id": "LH_oCFWMiwFs" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "! python -m vllm.entrypoints.openai.api_server --model google/gemma-2b --dtype=half" + ] }, { "cell_type": "markdown", - "source": [ - "To run the pipeline in Apache Beam, run the following code. Update the `VLLMCompletionsModelHandler` object with the new parameters, which match the command from the previous cell. Reuse all of the pipeline logic from the previous pipelines." - ], "metadata": { "id": "31BmdDUAn-SW" - } + }, + "source": [ + "To run the pipeline in Apache Beam, run the following code. Update the `VLLMCompletionsModelHandler` object with the new parameters, which match the command from the previous cell. Reuse all of the pipeline logic from the previous pipelines." 
+ ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DyC2ikXg237p" + }, + "outputs": [], "source": [ "model_handler = VLLMCompletionsModelHandler('google/gemma-2b', vllm_server_kwargs={'dtype': 'half'})\n", "\n", @@ -584,15 +601,13 @@ " | beam.ParDo(FormatOutput()) # Format the output.\n", " | beam.Map(print) # Print the formatted output.\n", " )" - ], - "metadata": { - "id": "DyC2ikXg237p" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "C6OYfub6ovFK" + }, "source": [ "### Run Gemma on Dataflow\n", "\n", @@ -607,10 +622,24 @@ "2. Set pipeline options. You can reuse the options defined in this notebook. Replace the Docker image location with your new Docker image.\n", "3. Run the pipeline. Copy the pipeline that you ran on Dataflow, and replace the pipeline options with the pipeline options that you just defined.\n", "\n" - ], - "metadata": { - "id": "C6OYfub6ovFK" - } + ] } - ] + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb b/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb index ebfcca34b94c..03feb96cbf68 100644 --- a/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb +++ b/examples/notebooks/beam-ml/vertex_ai_feature_store_enrichment.ipynb @@ -197,8 +197,8 @@ }, "outputs": [], "source": [ - "PROJECT_ID = \"\"\n", - "LOCATION = \"\"" + "PROJECT_ID = \"\" # @param {type:'string'}\n", + "LOCATION = \"\" # @param {type:'string'}" ] }, { @@ -1790,10 +1790,10 @@ "outputs": [], "source": [ "# Replace with the name of your Pub/Sub topic.\n", - "TOPIC = \" \"\n", + "TOPIC = \"\" # @param {type:'string'}\n", "\n", "# Replace with the 
subscription path for your topic.\n", - "SUBSCRIPTION = \"\"" + "SUBSCRIPTION = \"\" # @param {type:'string'}" ] }, { diff --git a/examples/notebooks/healthcare/beam_nlp.ipynb b/examples/notebooks/healthcare/beam_nlp.ipynb index c2061bc4d75f..bbcbb6254024 100644 --- a/examples/notebooks/healthcare/beam_nlp.ipynb +++ b/examples/notebooks/healthcare/beam_nlp.ipynb @@ -1,25 +1,10 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" @@ -27,6 +12,12 @@ }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "lBuUTzxD2mvJ" + }, + "outputs": [], "source": [ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", "\n", @@ -46,16 +37,13 @@ "# KIND, either express or implied. See the License for the\n", "# specific language governing permissions and limitations\n", "# under the License" - ], - "metadata": { - "id": "lBuUTzxD2mvJ", - "cellView": "form" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "nEUAYCTx4Ijj" + }, "source": [ "# **Natural Language Processing Pipeline**\n", "\n", @@ -70,101 +58,103 @@ "For details about Apache Beam pipelines, including PTransforms and PCollections, visit the [Beam Programming Guide](https://beam.apache.org/documentation/programming-guide/).\n", "\n", "You'll be able to use this notebook to explore the data in each PCollection." - ], - "metadata": { - "id": "nEUAYCTx4Ijj" - } + ] }, { "cell_type": "markdown", - "source": [ - "First, lets install the necessary packages." 
- ], "metadata": { "id": "ZLBB0PTG5CHw" - } + }, + "source": [ + "First, lets install the necessary packages." + ] }, { "cell_type": "code", - "source": [ - "!pip install apache-beam[gcp]" - ], + "execution_count": null, "metadata": { "id": "O7hq2sse8K4u" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip install apache-beam[gcp]" + ] }, { "cell_type": "markdown", - "source": [ - " **GCP Setup**" - ], "metadata": { "id": "5vQDhIv0E-LR" - } + }, + "source": [ + " **GCP Setup**" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "DGYiBYfxsSCw" + }, "source": [ "1. Authenticate your notebook by `gcloud auth application-default login` in the Colab terminal.\n", "\n", "2. Run `gcloud config set project `" - ], - "metadata": { - "id": "DGYiBYfxsSCw" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "D7lJqW2PRFcN" + }, "source": [ "Set the variables in the next cell based upon your project and preferences. The files referred to in this notebook nlpsample*.csv are in the format with one\n", "blurb of clinical note.\n", "\n", "Note that below, **us-central1** is hardcoded as the location. This is because of the limited number of [locations](https://cloud.google.com/healthcare-api/docs/how-tos/nlp) the API currently supports." 
- ], - "metadata": { - "id": "D7lJqW2PRFcN" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s9lhe5CZ5F3o" + }, + "outputs": [], "source": [ - "DATASET=\"\"\n", - "TEMP_LOCATION=\"\"\n", - "PROJECT=''\n", + "DATASET=\"\" # @param {type:'string'}\n", + "TEMP_LOCATION=\"\" # @param {type:'string'}\n", + "PROJECT=''# @param {type:'string'}\n", "LOCATION='us-central1'\n", "URL=f'https://healthcare.googleapis.com/v1/projects/{PROJECT}/locations/{LOCATION}/services/nlp:analyzeEntities'\n", "NLP_SERVICE=f'projects/{PROJECT}/locations/{LOCATION}/services/nlp'" - ], - "metadata": { - "id": "s9lhe5CZ5F3o" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Then, download [this raw CSV file](https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading." - ], "metadata": { "id": "1IArtEm8QuCR" - } + }, + "source": [ + "Then, download [this raw CSV file](https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading." + ] }, { "cell_type": "markdown", + "metadata": { + "id": "DI_Qkyn75LO-" + }, "source": [ "**BigQuery Setup**\n", "\n", "We will be using BigQuery to warehouse the structured data revealed in the output of the Healthcare NLP API. For this purpose, we create 3 tables to organize the data. Specifically, these will be table entities, table relations, and table entity mentions, which are all outputs of interest from the Healthcare NLP API." 
- ], - "metadata": { - "id": "DI_Qkyn75LO-" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bZDqtFVE5Wd_" + }, + "outputs": [], "source": [ "from google.cloud import bigquery\n", "\n", @@ -198,15 +188,15 @@ "print(\n", " \"Created table {}.{}.{}\".format(table.project, table.dataset_id, table.table_id)\n", ")" - ], - "metadata": { - "id": "bZDqtFVE5Wd_" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YK-G7uV5APuP" + }, + "outputs": [], "source": [ "from google.cloud import bigquery\n", "\n", @@ -240,15 +230,15 @@ ")\n", "\n", "\n" - ], - "metadata": { - "id": "YK-G7uV5APuP" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R9IHgZKoAQWj" + }, + "outputs": [], "source": [ "from google.cloud import bigquery\n", "\n", @@ -324,26 +314,26 @@ "print(\n", " \"Created table {}.{}.{}\".format(table.project, table.dataset_id, table.table_id)\n", ")" - ], - "metadata": { - "id": "R9IHgZKoAQWj" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "jc_iS_BP5aS4" + }, "source": [ "**Pipeline Setup**\n", "\n", "We will use InteractiveRunner in this notebook." - ], - "metadata": { - "id": "jc_iS_BP5aS4" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "07ct6kf55ihP" + }, + "outputs": [], "source": [ "# Python's regular expression library\n", "import re\n", @@ -365,24 +355,24 @@ " job_name=\"my-healthcare-nlp-job\",\n", " temp_location=TEMP_LOCATION,\n", " region=LOCATION)" - ], - "metadata": { - "id": "07ct6kf55ihP" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "The following defines a `PTransform` named `ReadLinesFromText`, that extracts lines from a file." 
- ], "metadata": { "id": "dO1A9_WK5lb4" - } + }, + "source": [ + "The following defines a `PTransform` named `ReadLinesFromText`, that extracts lines from a file." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t5iDRKMK5n_B" + }, + "outputs": [], "source": [ "class ReadLinesFromText(beam.PTransform):\n", "\n", @@ -392,74 +382,73 @@ " def expand(self, pcoll):\n", " return (pcoll.pipeline\n", " | beam.io.ReadFromText(self._file_pattern))" - ], - "metadata": { - "id": "t5iDRKMK5n_B" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "The following sets up an Apache Beam pipeline with the *Interactive Runner*. The *Interactive Runner* is the runner suitable for running in notebooks. A runner is an execution engine for Apache Beam pipelines." - ], "metadata": { "id": "HI_HVB185sMQ" - } + }, + "source": [ + "The following sets up an Apache Beam pipeline with the *Interactive Runner*. The *Interactive Runner* is the runner suitable for running in notebooks. A runner is an execution engine for Apache Beam pipelines." + ] }, { "cell_type": "code", - "source": [ - "p = beam.Pipeline(options = options)" - ], + "execution_count": null, "metadata": { "id": "7osCZ1om5ql0" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "p = beam.Pipeline(options = options)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "EaF8NfC_521y" + }, "source": [ "The following sets up a PTransform that extracts words from a Google Cloud Storage file that contains lines with each line containing a In our example, each line is a medical notes excerpt that will be passed through the Healthcare NLP API\n", "\n", "**\"|\"** is an overloaded operator that applies a PTransform to a PCollection to produce a new PCollection. 
Together with |, >> allows you to optionally name a PTransform.\n", "\n", "Usage:[PCollection] | [PTransform], **or** [PCollection] | [name] >> [PTransform]" - ], - "metadata": { - "id": "EaF8NfC_521y" - } + ] }, { "cell_type": "code", - "source": [ - "lines = p | 'read' >> ReadLinesFromText(\"test.csv\")" - ], + "execution_count": null, "metadata": { - "id": "2APAh6XQ6NYd", "colab": { "base_uri": "https://localhost:8080/", "height": 72 }, + "id": "2APAh6XQ6NYd", "outputId": "033c5110-fd5a-4da0-b59b-801a1ce9d3b1" }, - "execution_count": null, - "outputs": [ + "outputs": [], + "source": [ + "lines = p | 'read' >> ReadLinesFromText(\"test.csv\")" ] }, { "cell_type": "markdown", - "source": [ - "We then write a **DoFn** that will invoke the [NLP API](https://cloud.google.com/healthcare-api/docs/how-tos/nlp)." - ], "metadata": { "id": "vM_FbhkbGI-E" - } + }, + "source": [ + "We then write a **DoFn** that will invoke the [NLP API](https://cloud.google.com/healthcare-api/docs/how-tos/nlp)." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3ZJ-0dex9WE5" + }, + "outputs": [], "source": [ "class InvokeNLP(beam.DoFn):\n", "\n", @@ -486,24 +475,24 @@ " pcoll\n", " | \"Invoke NLP API\" >> beam.ParDo(InvokeNLP())\n", " )" - ], - "metadata": { - "id": "3ZJ-0dex9WE5" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "From our elements, being processed, we will get the entity mentions, relationships, and entities respectively." - ], "metadata": { "id": "TeYxIlNgGdK0" - } + }, + "source": [ + "From our elements, being processed, we will get the entity mentions, relationships, and entities respectively." 
+ ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3KZgUv3d6haf" + }, + "outputs": [], "source": [ "import json\n", "from apache_beam import pvalue\n", @@ -529,15 +518,15 @@ " for e in element['entityMentions']:\n", " e['id'] = element['id']\n", " yield e\n" - ], - "metadata": { - "id": "3KZgUv3d6haf" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OkxgB2a-6iYN" + }, + "outputs": [], "source": [ "from apache_beam.io.gcp.internal.clients import bigquery\n", "\n", @@ -550,24 +539,24 @@ "nlp_annotations = (lines\n", " | \"Analyze\" >> AnalyzeLines()\n", " )\n" - ], - "metadata": { - "id": "OkxgB2a-6iYN" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "We then write these results to [BigQuery](https://cloud.google.com/bigquery), a cloud data warehouse." - ], "metadata": { "id": "iTh65CXIGoQn" - } + }, + "source": [ + "We then write these results to [BigQuery](https://cloud.google.com/bigquery), a cloud data warehouse." 
+ ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q9GIyLeS6oAe" + }, + "outputs": [], "source": [ "resultsEntities = ( nlp_annotations\n", " | \"Break\" >> beam.ParDo(breakUpEntities())\n", @@ -576,15 +565,15 @@ " write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,\n", " create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER)\n", " )" - ], - "metadata": { - "id": "Q9GIyLeS6oAe" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yOlHfkcT6s4y" + }, + "outputs": [], "source": [ "table_spec = bigquery.TableReference(\n", " projectId=PROJECT,\n", @@ -598,15 +587,15 @@ " write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,\n", " create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER)\n", " )" - ], - "metadata": { - "id": "yOlHfkcT6s4y" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a6QxxnY890Za" + }, + "outputs": [], "source": [ "table_spec = bigquery.TableReference(\n", " projectId=PROJECT,\n", @@ -620,43 +609,31 @@ " write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,\n", " create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER)\n", " )" - ], - "metadata": { - "id": "a6QxxnY890Za" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "You can see the job graph for the pipeline by doing:" - ], "metadata": { "id": "6rP2nO6Z60bt" - } + }, + "source": [ + "You can see the job graph for the pipeline by doing:" + ] }, { "cell_type": "code", - "source": [ - "ib.show_graph(p)" - ], + "execution_count": null, "metadata": { - "id": "zQB5h1Zq6x8d", "colab": { "base_uri": "https://localhost:8080/", "height": 806 }, + "id": "zQB5h1Zq6x8d", "outputId": "7885e493-fee8-402e-baf2-cbbf406a3eb9" }, - "execution_count": null, "outputs": [ { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], 
"text/html": [ "\n", " \n", @@ -665,16 +642,16 @@ " Processing... show_graph\n", " \n", " " + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ "\n", "\n", "\n", "\n" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "application/javascript": [ - "\n", - " if (typeof window.interactive_beam_jquery == 'undefined') {\n", - " var jqueryScript = document.createElement('script');\n", - " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", - " jqueryScript.type = 'text/javascript';\n", - " jqueryScript.onload = function() {\n", - " var datatableScript = document.createElement('script');\n", - " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", - " datatableScript.type = 'text/javascript';\n", - " datatableScript.onload = function() {\n", - " window.interactive_beam_jquery = jQuery.noConflict(true);\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n", - " });\n", - " }\n", - " document.head.appendChild(datatableScript);\n", - " };\n", - " document.head.appendChild(jqueryScript);\n", - " } else {\n", - " window.interactive_beam_jquery(document).ready(function($){\n", - " \n", - " $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n", - " });\n", - " }" - ] + "application/javascript": "\n if (typeof window.interactive_beam_jquery == 'undefined') {\n var jqueryScript = document.createElement('script');\n jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n jqueryScript.type = 'text/javascript';\n jqueryScript.onload = function() {\n var datatableScript = document.createElement('script');\n datatableScript.src = 
'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n datatableScript.type = 'text/javascript';\n datatableScript.onload = function() {\n window.interactive_beam_jquery = jQuery.noConflict(true);\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n });\n }\n document.head.appendChild(datatableScript);\n };\n document.head.appendChild(jqueryScript);\n } else {\n window.interactive_beam_jquery(document).ready(function($){\n \n $(\"#progress_indicator_fa6997b180fa86966dd888a7d59a34f7\").remove();\n });\n }" }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "ib.show_graph(p)" ] } - ] + ], + "metadata": { + "colab": { + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 4906d9cf9cb8..811a3c15f836 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -16,6 +16,8 @@ * limitations under the License. 
*/ +import static org.apache.beam.gradle.BeamModulePlugin.getSupportedJavaVersion + import groovy.json.JsonOutput plugins { id 'org.apache.beam.module' } @@ -273,12 +275,75 @@ def createRunnerV2ValidatesRunnerTest = { Map args -> } } +tasks.register('examplesJavaRunnerV2IntegrationTestDistroless', Test.class) { + group = "verification" + dependsOn 'buildAndPushDistrolessContainerImage' + def javaVer = project.findProperty('testJavaVersion') + def repository = "us.gcr.io/apache-beam-testing/${System.getenv('USER')}" + def tag = project.findProperty('dockerTag') + def imageURL = "${repository}/beam_${javaVer}_sdk_distroless:${tag}" + def pipelineOptions = [ + "--runner=TestDataflowRunner", + "--project=${gcpProject}", + "--region=${gcpRegion}", + "--tempRoot=${dataflowValidatesTempRoot}", + "--sdkContainerImage=${imageURL}", + "--experiments=use_unified_worker,use_runner_v2", + "--firestoreDb=${firestoreDb}", + ] + systemProperty "beamTestPipelineOptions", JsonOutput.toJson(pipelineOptions) + + include '**/*IT.class' + + maxParallelForks 4 + classpath = configurations.examplesJavaIntegrationTest + testClassesDirs = files(project(":examples:java").sourceSets.test.output.classesDirs) + useJUnit { } +} + +tasks.register('buildAndPushDistrolessContainerImage', Task.class) { + // Only Java 17 and 21 are supported. + // See https://github.com/GoogleContainerTools/distroless/tree/main/java#image-contents. 
+ def allowed = ["java17", "java21"] + doLast { + def javaVer = project.findProperty('testJavaVersion') + if (!allowed.contains(javaVer)) { + throw new GradleException("testJavaVersion must be one of ${allowed}, got: ${javaVer}") + } + if (!project.hasProperty('dockerTag')) { + throw new GradleException("dockerTag is missing but required") + } + def repository = "us.gcr.io/apache-beam-testing/${System.getenv('USER')}" + def tag = project.findProperty('dockerTag') + def imageURL = "${repository}/beam_${javaVer}_sdk_distroless:${tag}" + exec { + executable 'docker' + workingDir rootDir + args = [ + 'buildx', + 'build', + '-t', + imageURL, + '-f', + 'sdks/java/container/Dockerfile-distroless', + "--build-arg=BEAM_BASE=gcr.io/apache-beam-testing/beam-sdk/beam_${javaVer}_sdk", + "--build-arg=DISTROLESS_BASE=gcr.io/distroless/${javaVer}-debian12", + '.' + ] + } + exec { + executable 'docker' + args = ['push', imageURL] + } + } +} + // Push docker images to a container registry for use within tests. // NB: Tasks which consume docker images from the registry should depend on this // task directly ('dependsOn buildAndPushDockerJavaContainer'). This ensures the correct // task ordering such that the registry doesn't get cleaned up prior to task completion. 
def buildAndPushDockerJavaContainer = tasks.register("buildAndPushDockerJavaContainer") { - def javaVer = "java8" + def javaVer = getSupportedJavaVersion() if(project.hasProperty('testJavaVersion')) { javaVer = "java${project.getProperty('testJavaVersion')}" } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java index 6a0208f1447f..61c38dde2b42 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/options/DataflowStreamingPipelineOptions.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.DefaultValueFactory; import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.Hidden; import org.apache.beam.sdk.options.PipelineOptions; import org.joda.time.Duration; @@ -219,10 +220,8 @@ public interface DataflowStreamingPipelineOptions extends PipelineOptions { void setWindmillServiceStreamMaxBackoffMillis(int value); - @Description( - "If true, Dataflow streaming pipeline will be running in direct path mode." - + " VMs must have IPv6 enabled for this to work.") - @Default.Boolean(false) + @Description("Enables direct path mode for streaming engine.") + @Default.InstanceFactory(EnableWindmillServiceDirectPathFactory.class) boolean getIsWindmillServiceDirectPathEnabled(); void setIsWindmillServiceDirectPathEnabled(boolean isWindmillServiceDirectPathEnabled); @@ -300,4 +299,12 @@ public Integer create(PipelineOptions options) { return streamingOptions.isEnableStreamingEngine() ? 
Integer.MAX_VALUE : 1; } } + + /** Defaults to true only if the enable_windmill_service_direct_path experiment is set. */ + class EnableWindmillServiceDirectPathFactory implements DefaultValueFactory { + @Override + public Boolean create(PipelineOptions options) { + return ExperimentalOptions.hasExperiment(options, "enable_windmill_service_direct_path"); + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 6ce60283735f..088a28e9b2db 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -17,11 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory.remoteChannel; -import com.google.api.services.dataflow.model.CounterUpdate; import com.google.api.services.dataflow.model.MapTask; import com.google.auto.value.AutoValue; +import java.io.PrintWriter; import java.util.List; import java.util.Map; import java.util.Optional; @@ -33,24 +33,28 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import javax.annotation.Nullable; import org.apache.beam.runners.core.metrics.MetricsLogger; import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; -import 
org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; import org.apache.beam.runners.dataflow.worker.status.DebugCapture; import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ComputationStateCache; import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor; import org.apache.beam.runners.dataflow.worker.streaming.config.ComputationConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.FixedGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingApplianceComputationConfigFetcher; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingEngineComputationConfigFetcher; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; +import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandleImpl; +import org.apache.beam.runners.dataflow.worker.streaming.harness.FanOutStreamingEngineWorkerHarness; import org.apache.beam.runners.dataflow.worker.streaming.harness.SingleSourceWorkerHarness; import org.apache.beam.runners.dataflow.worker.streaming.harness.SingleSourceWorkerHarness.GetWorkSender; import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; @@ -59,12 +63,15 @@ import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingWorkerStatusReporter; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; +import org.apache.beam.runners.dataflow.worker.windmill.ApplianceWindmillClient; 
import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; +import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; @@ -78,8 +85,16 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCache; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingRemoteStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.IsolationChannel; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributors; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.StreamingApplianceFailureTracker; @@ -89,6 +104,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.ApplianceHeartbeatSender; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.StreamPoolHeartbeatSender; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.sdk.fn.IdGenerators; import org.apache.beam.sdk.fn.JvmInitializers; @@ -98,18 +114,25 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** Implements a Streaming Dataflow worker. */ +/** + * For internal use only. + * + *

Implements a Streaming Dataflow worker. + */ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) +@Internal public final class StreamingDataflowWorker { /** @@ -142,6 +165,10 @@ public final class StreamingDataflowWorker { private static final int DEFAULT_STATUS_PORT = 8081; private static final Random CLIENT_ID_GENERATOR = new Random(); private static final String CHANNELZ_PATH = "/channelz"; + private static final String BEAM_FN_API_EXPERIMENT = "beam_fn_api"; + private static final String ENABLE_IPV6_EXPERIMENT = "enable_private_ipv6_google_access"; + private static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL_EXPERIMENT = + "streaming_engine_use_job_settings_for_heartbeat_pool"; private final WindmillStateCache stateCache; private final StreamingWorkerStatusPages statusPages; @@ -155,9 +182,8 @@ public final class StreamingDataflowWorker { private final ReaderCache readerCache; private final DataflowExecutionStateSampler sampler = DataflowExecutionStateSampler.instance(); private final ActiveWorkRefresher activeWorkRefresher; - private final WorkCommitter workCommitter; private final StreamingWorkerStatusReporter workerStatusReporter; - private final StreamingCounters streamingCounters; + private final int numCommitThreads; private StreamingDataflowWorker( WindmillServerStub windmillServer, @@ -170,17 +196,17 @@ private StreamingDataflowWorker( DataflowWorkerHarnessOptions options, HotKeyLogger hotKeyLogger, Supplier clock, - StreamingWorkerStatusReporter workerStatusReporter, + StreamingWorkerStatusReporterFactory streamingWorkerStatusReporterFactory, FailureTracker failureTracker, WorkFailureProcessor workFailureProcessor, StreamingCounters streamingCounters, MemoryMonitor memoryMonitor, GrpcWindmillStreamFactory windmillStreamFactory, ScheduledExecutorService activeWorkRefreshExecutorFn, - ConcurrentMap stageInfoMap) { + ConcurrentMap stageInfoMap, + @Nullable GrpcDispatcherClient dispatcherClient) { // 
Register standard file systems. FileSystems.setDefaultPipelineOptions(options); - this.configFetcher = configFetcher; this.computationStateCache = computationStateCache; this.stateCache = windmillStateCache; @@ -189,35 +215,13 @@ private StreamingDataflowWorker( Duration.standardSeconds(options.getReaderCacheTimeoutSec()), Executors.newCachedThreadPool()); this.options = options; - - boolean windmillServiceEnabled = options.isEnableStreamingEngine(); - - int numCommitThreads = 1; - if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() > 0) { - numCommitThreads = options.getWindmillServiceCommitThreads(); - } - - this.workCommitter = - windmillServiceEnabled - ? StreamingEngineWorkCommitter.builder() - .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) - .setCommitWorkStreamFactory( - WindmillStreamPool.create( - numCommitThreads, - COMMIT_STREAM_TIMEOUT, - windmillServer::commitWorkStream) - ::getCloseableStream) - .setNumCommitSenders(numCommitThreads) - .setOnCommitComplete(this::onCompleteCommit) - .build() - : StreamingApplianceWorkCommitter.create( - windmillServer::commitWork, this::onCompleteCommit); - this.workUnitExecutor = workUnitExecutor; - - this.workerStatusReporter = workerStatusReporter; - this.streamingCounters = streamingCounters; this.memoryMonitor = BackgroundMemoryMonitor.create(memoryMonitor); + this.numCommitThreads = + options.isEnableStreamingEngine() + ? 
Math.max(options.getWindmillServiceCommitThreads(), 1) + : 1; + StreamingWorkScheduler streamingWorkScheduler = StreamingWorkScheduler.create( options, @@ -234,107 +238,200 @@ private StreamingDataflowWorker( ID_GENERATOR, configFetcher.getGlobalConfigHandle(), stageInfoMap); - ThrottlingGetDataMetricTracker getDataMetricTracker = new ThrottlingGetDataMetricTracker(memoryMonitor); - WorkerStatusPages workerStatusPages = - WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor); - StreamingWorkerStatusPages.Builder statusPagesBuilder = StreamingWorkerStatusPages.builder(); - int stuckCommitDurationMillis; - GetDataClient getDataClient; - HeartbeatSender heartbeatSender; - if (windmillServiceEnabled) { - WindmillStreamPool getDataStreamPool = - WindmillStreamPool.create( - Math.max(1, options.getWindmillGetDataStreamCount()), - GET_DATA_STREAM_TIMEOUT, - windmillServer::getDataStream); - getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); - if (options.getUseSeparateWindmillHeartbeatStreams() != null) { + // Status page members. Different implementations on whether the harness is streaming engine + // direct path, streaming engine cloud path, or streaming appliance. 
+ @Nullable ChannelzServlet channelzServlet = null; + Consumer getDataStatusProvider; + Supplier currentActiveCommitBytesProvider; + if (isDirectPathPipeline(options)) { + WeightedSemaphore maxCommitByteSemaphore = Commits.maxCommitByteSemaphore(); + FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkerHarness = + FanOutStreamingEngineWorkerHarness.create( + createJobHeader(options, clientId), + GetWorkBudget.builder() + .setItems(chooseMaxBundlesOutstanding(options)) + .setBytes(MAX_GET_WORK_FETCH_BYTES) + .build(), + windmillStreamFactory, + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> + computationStateCache + .get(processingContext.computationId()) + .ifPresent( + computationState -> { + memoryMonitor.waitForResources("GetWork"); + streamingWorkScheduler.scheduleWork( + computationState, + workItem, + watermarks, + processingContext, + getWorkStreamLatencies); + }), + createFanOutStubFactory(options), + GetWorkBudgetDistributors.distributeEvenly(), + Preconditions.checkNotNull(dispatcherClient), + commitWorkStream -> + StreamingEngineWorkCommitter.builder() + // Share the commitByteSemaphore across all created workCommitters. + .setCommitByteSemaphore(maxCommitByteSemaphore) + .setBackendWorkerToken(commitWorkStream.backendWorkerToken()) + .setOnCommitComplete(this::onCompleteCommit) + .setNumCommitSenders(Math.max(options.getWindmillServiceCommitThreads(), 1)) + .setCommitWorkStreamFactory( + () -> CloseableStream.create(commitWorkStream, () -> {})) + .build(), + getDataMetricTracker); + getDataStatusProvider = getDataMetricTracker::printHtml; + currentActiveCommitBytesProvider = + fanOutStreamingEngineWorkerHarness::currentActiveCommitBytes; + channelzServlet = + createChannelzServlet( + options, fanOutStreamingEngineWorkerHarness::currentWindmillEndpoints); + this.streamingWorkerHarness = fanOutStreamingEngineWorkerHarness; + } else { + // Non-direct path pipelines. 
+ Windmill.GetWorkRequest request = + Windmill.GetWorkRequest.newBuilder() + .setClientId(clientId) + .setMaxItems(chooseMaxBundlesOutstanding(options)) + .setMaxBytes(MAX_GET_WORK_FETCH_BYTES) + .build(); + GetDataClient getDataClient; + HeartbeatSender heartbeatSender; + WorkCommitter workCommitter; + GetWorkSender getWorkSender; + if (options.isEnableStreamingEngine()) { + WindmillStreamPool getDataStreamPool = + WindmillStreamPool.create( + Math.max(1, options.getWindmillGetDataStreamCount()), + GET_DATA_STREAM_TIMEOUT, + windmillServer::getDataStream); + getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); heartbeatSender = - StreamPoolHeartbeatSender.Create( - Boolean.TRUE.equals(options.getUseSeparateWindmillHeartbeatStreams()) - ? separateHeartbeatPool(windmillServer) - : getDataStreamPool); - + createStreamingEngineHeartbeatSender( + options, windmillServer, getDataStreamPool, configFetcher.getGlobalConfigHandle()); + channelzServlet = + createChannelzServlet(options, windmillServer::getWindmillServiceEndpoints); + workCommitter = + StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory( + WindmillStreamPool.create( + numCommitThreads, + COMMIT_STREAM_TIMEOUT, + windmillServer::commitWorkStream) + ::getCloseableStream) + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) + .setNumCommitSenders(numCommitThreads) + .setOnCommitComplete(this::onCompleteCommit) + .build(); + getWorkSender = + GetWorkSender.forStreamingEngine( + receiver -> windmillServer.getWorkStream(request, receiver)); } else { - heartbeatSender = - StreamPoolHeartbeatSender.Create( - separateHeartbeatPool(windmillServer), - getDataStreamPool, - configFetcher.getGlobalConfigHandle()); + getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker); + heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData); + workCommitter = + StreamingApplianceWorkCommitter.create( + windmillServer::commitWork, 
this::onCompleteCommit); + getWorkSender = GetWorkSender.forAppliance(() -> windmillServer.getWork(request)); } - stuckCommitDurationMillis = - options.getStuckCommitDurationMillis() > 0 ? options.getStuckCommitDurationMillis() : 0; - statusPagesBuilder - .setDebugCapture( - new DebugCapture.Manager(options, workerStatusPages.getDebugCapturePages())) - .setChannelzServlet( - new ChannelzServlet( - CHANNELZ_PATH, options, windmillServer::getWindmillServiceEndpoints)) - .setWindmillStreamFactory(windmillStreamFactory); - } else { - getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker); - heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData); - stuckCommitDurationMillis = 0; + getDataStatusProvider = getDataClient::printHtml; + currentActiveCommitBytesProvider = workCommitter::currentActiveCommitBytes; + + this.streamingWorkerHarness = + SingleSourceWorkerHarness.builder() + .setStreamingWorkScheduler(streamingWorkScheduler) + .setWorkCommitter(workCommitter) + .setGetDataClient(getDataClient) + .setComputationStateFetcher(this.computationStateCache::get) + .setWaitForResources(() -> memoryMonitor.waitForResources("GetWork")) + .setHeartbeatSender(heartbeatSender) + .setThrottledTimeTracker(windmillServer::getAndResetThrottleTime) + .setGetWorkSender(getWorkSender) + .build(); } + this.workerStatusReporter = + streamingWorkerStatusReporterFactory.createStatusReporter(streamingWorkerHarness); this.activeWorkRefresher = new ActiveWorkRefresher( clock, options.getActiveWorkRefreshPeriodMillis(), - stuckCommitDurationMillis, + options.isEnableStreamingEngine() + ? 
Math.max(options.getStuckCommitDurationMillis(), 0) + : 0, computationStateCache::getAllPresentComputations, sampler, activeWorkRefreshExecutorFn, getDataMetricTracker::trackHeartbeats); this.statusPages = - statusPagesBuilder + createStatusPageBuilder(options, windmillStreamFactory, memoryMonitor) .setClock(clock) .setClientId(clientId) .setIsRunning(running) - .setStatusPages(workerStatusPages) .setStateCache(stateCache) .setComputationStateCache(this.computationStateCache) - .setCurrentActiveCommitBytes(workCommitter::currentActiveCommitBytes) - .setGetDataStatusProvider(getDataClient::printHtml) .setWorkUnitExecutor(workUnitExecutor) .setGlobalConfigHandle(configFetcher.getGlobalConfigHandle()) + .setChannelzServlet(channelzServlet) + .setGetDataStatusProvider(getDataStatusProvider) + .setCurrentActiveCommitBytes(currentActiveCommitBytesProvider) .build(); - Windmill.GetWorkRequest request = - Windmill.GetWorkRequest.newBuilder() - .setClientId(clientId) - .setMaxItems(chooseMaximumBundlesOutstanding()) - .setMaxBytes(MAX_GET_WORK_FETCH_BYTES) - .build(); - - this.streamingWorkerHarness = - SingleSourceWorkerHarness.builder() - .setStreamingWorkScheduler(streamingWorkScheduler) - .setWorkCommitter(workCommitter) - .setGetDataClient(getDataClient) - .setComputationStateFetcher(this.computationStateCache::get) - .setWaitForResources(() -> memoryMonitor.waitForResources("GetWork")) - .setHeartbeatSender(heartbeatSender) - .setGetWorkSender( - windmillServiceEnabled - ? 
GetWorkSender.forStreamingEngine( - receiver -> windmillServer.getWorkStream(request, receiver)) - : GetWorkSender.forAppliance(() -> windmillServer.getWork(request))) - .build(); - - LOG.debug("windmillServiceEnabled: {}", windmillServiceEnabled); + LOG.debug("isDirectPathEnabled: {}", options.getIsWindmillServiceDirectPathEnabled()); + LOG.debug("windmillServiceEnabled: {}", options.isEnableStreamingEngine()); LOG.debug("WindmillServiceEndpoint: {}", options.getWindmillServiceEndpoint()); LOG.debug("WindmillServicePort: {}", options.getWindmillServicePort()); LOG.debug("LocalWindmillHostport: {}", options.getLocalWindmillHostport()); } - private static WindmillStreamPool separateHeartbeatPool( - WindmillServerStub windmillServer) { - return WindmillStreamPool.create(1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream); + private static StreamingWorkerStatusPages.Builder createStatusPageBuilder( + DataflowWorkerHarnessOptions options, + GrpcWindmillStreamFactory windmillStreamFactory, + MemoryMonitor memoryMonitor) { + WorkerStatusPages workerStatusPages = + WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor); + + StreamingWorkerStatusPages.Builder streamingStatusPages = + StreamingWorkerStatusPages.builder().setStatusPages(workerStatusPages); + + return options.isEnableStreamingEngine() + ? 
streamingStatusPages + .setDebugCapture( + new DebugCapture.Manager(options, workerStatusPages.getDebugCapturePages())) + .setWindmillStreamFactory(windmillStreamFactory) + : streamingStatusPages; + } + + private static ChannelzServlet createChannelzServlet( + DataflowWorkerHarnessOptions options, + Supplier> windmillEndpointProvider) { + return new ChannelzServlet(CHANNELZ_PATH, options, windmillEndpointProvider); + } + + private static HeartbeatSender createStreamingEngineHeartbeatSender( + DataflowWorkerHarnessOptions options, + WindmillServerStub windmillClient, + WindmillStreamPool getDataStreamPool, + StreamingGlobalConfigHandle globalConfigHandle) { + // Experiment gates the logic till backend changes are rollback safe + if (!DataflowRunner.hasExperiment( + options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL_EXPERIMENT) + || options.getUseSeparateWindmillHeartbeatStreams() != null) { + return StreamPoolHeartbeatSender.create( + Boolean.TRUE.equals(options.getUseSeparateWindmillHeartbeatStreams()) + ? 
WindmillStreamPool.create(1, GET_DATA_STREAM_TIMEOUT, windmillClient::getDataStream) + : getDataStreamPool); + + } else { + return StreamPoolHeartbeatSender.create( + WindmillStreamPool.create(1, GET_DATA_STREAM_TIMEOUT, windmillClient::getDataStream), + getDataStreamPool, + globalConfigHandle); + } } public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions options) { @@ -387,17 +484,21 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o failureTracker, () -> Optional.ofNullable(memoryMonitor.tryToDumpHeap()), clock); - StreamingWorkerStatusReporter workerStatusReporter = - StreamingWorkerStatusReporter.create( - dataflowServiceClient, - windmillServer::getAndResetThrottleTime, - stageInfo::values, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - options.getWindmillHarnessUpdateReportingPeriod().getMillis(), - options.getPerWorkerMetricsUpdateReportingPeriodMillis()); + StreamingWorkerStatusReporterFactory workerStatusReporterFactory = + throttleTimeSupplier -> + StreamingWorkerStatusReporter.builder() + .setDataflowServiceClient(dataflowServiceClient) + .setWindmillQuotaThrottleTime(throttleTimeSupplier) + .setAllStageInfo(stageInfo::values) + .setFailureTracker(failureTracker) + .setStreamingCounters(streamingCounters) + .setMemoryMonitor(memoryMonitor) + .setWorkExecutor(workExecutor) + .setWindmillHarnessUpdateReportingPeriodMillis( + options.getWindmillHarnessUpdateReportingPeriod().getMillis()) + .setPerWorkerMetricsUpdateReportingPeriodMillis( + options.getPerWorkerMetricsUpdateReportingPeriodMillis()) + .build(); return new StreamingDataflowWorker( windmillServer, @@ -410,7 +511,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o options, new HotKeyLogger(), clock, - workerStatusReporter, + workerStatusReporterFactory, failureTracker, workFailureProcessor, streamingCounters, @@ -418,7 +519,8 @@ public static StreamingDataflowWorker 
fromOptions(DataflowWorkerHarnessOptions o configFetcherComputationStateCacheAndWindmillClient.windmillStreamFactory(), Executors.newSingleThreadScheduledExecutor( new ThreadFactoryBuilder().setNameFormat("RefreshWork").build()), - stageInfo); + stageInfo, + configFetcherComputationStateCacheAndWindmillClient.windmillDispatcherClient()); } /** @@ -433,53 +535,121 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o WorkUnitClient dataflowServiceClient, GrpcWindmillStreamFactory.Builder windmillStreamFactoryBuilder, Function computationStateCacheFactory) { - ComputationConfig.Fetcher configFetcher; - WindmillServerStub windmillServer; - ComputationStateCache computationStateCache; - GrpcWindmillStreamFactory windmillStreamFactory; if (options.isEnableStreamingEngine()) { GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); - configFetcher = + ComputationConfig.Fetcher configFetcher = StreamingEngineComputationConfigFetcher.create( options.getGlobalConfigRefreshPeriod().getMillis(), dataflowServiceClient); configFetcher.getGlobalConfigHandle().registerConfigObserver(dispatcherClient::onJobConfig); - computationStateCache = computationStateCacheFactory.apply(configFetcher); - windmillStreamFactory = + ComputationStateCache computationStateCache = + computationStateCacheFactory.apply(configFetcher); + GrpcWindmillStreamFactory windmillStreamFactory = windmillStreamFactoryBuilder .setProcessHeartbeatResponses( new WorkHeartbeatResponseProcessor(computationStateCache::get)) .setHealthCheckIntervalMillis( options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) .build(); - windmillServer = GrpcWindmillServer.create(options, windmillStreamFactory, dispatcherClient); - } else { - if (options.getWindmillServiceEndpoint() != null - || options.getLocalWindmillHostport().startsWith("grpc:")) { - windmillStreamFactory = - windmillStreamFactoryBuilder - 
.setHealthCheckIntervalMillis( - options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) - .build(); - windmillServer = - GrpcWindmillServer.create( - options, - windmillStreamFactory, - GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options))); - } else { - windmillStreamFactory = windmillStreamFactoryBuilder.build(); - windmillServer = new JniWindmillApplianceServer(options.getLocalWindmillHostport()); + return ConfigFetcherComputationStateCacheAndWindmillClient.builder() + .setWindmillDispatcherClient(dispatcherClient) + .setConfigFetcher(configFetcher) + .setComputationStateCache(computationStateCache) + .setWindmillStreamFactory(windmillStreamFactory) + .setWindmillServer( + GrpcWindmillServer.create(options, windmillStreamFactory, dispatcherClient)) + .build(); + } + + // Build with local Windmill client. + if (options.getWindmillServiceEndpoint() != null + || options.getLocalWindmillHostport().startsWith("grpc:")) { + GrpcDispatcherClient dispatcherClient = + GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); + GrpcWindmillStreamFactory windmillStreamFactory = + windmillStreamFactoryBuilder + .setHealthCheckIntervalMillis( + options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) + .build(); + GrpcWindmillServer windmillServer = + GrpcWindmillServer.create(options, windmillStreamFactory, dispatcherClient); + ComputationConfig.Fetcher configFetcher = + createApplianceComputationConfigFetcher(windmillServer); + return ConfigFetcherComputationStateCacheAndWindmillClient.builder() + .setWindmillDispatcherClient(dispatcherClient) + .setWindmillServer(windmillServer) + .setWindmillStreamFactory(windmillStreamFactory) + .setConfigFetcher(configFetcher) + .setComputationStateCache(computationStateCacheFactory.apply(configFetcher)) + .build(); + } + + WindmillServerStub windmillServer = + new JniWindmillApplianceServer(options.getLocalWindmillHostport()); + ComputationConfig.Fetcher configFetcher = + 
createApplianceComputationConfigFetcher(windmillServer); + return ConfigFetcherComputationStateCacheAndWindmillClient.builder() + .setWindmillStreamFactory(windmillStreamFactoryBuilder.build()) + .setWindmillServer(windmillServer) + .setConfigFetcher(configFetcher) + .setComputationStateCache(computationStateCacheFactory.apply(configFetcher)) + .build(); + } + + private static StreamingApplianceComputationConfigFetcher createApplianceComputationConfigFetcher( + ApplianceWindmillClient windmillClient) { + return new StreamingApplianceComputationConfigFetcher( + windmillClient::getConfig, + new FixedGlobalConfigHandle(StreamingGlobalConfig.builder().build())); + } + + private static boolean isDirectPathPipeline(DataflowWorkerHarnessOptions options) { + if (options.isEnableStreamingEngine() && options.getIsWindmillServiceDirectPathEnabled()) { + boolean isIpV6Enabled = + Optional.ofNullable(options.getDataflowServiceOptions()) + .map(serviceOptions -> serviceOptions.contains(ENABLE_IPV6_EXPERIMENT)) + .orElse(false); + + if (isIpV6Enabled) { + return true; } - configFetcher = - new StreamingApplianceComputationConfigFetcher( - windmillServer::getConfig, - new FixedGlobalConfigHandle(StreamingGlobalConfig.builder().build())); - computationStateCache = computationStateCacheFactory.apply(configFetcher); + LOG.warn( + "DirectPath is currently only supported with IPv6 networking stack. This requires setting " + + "\"enable_private_ipv6_google_access\" in experimental pipeline options. " + + "For information on how to set experimental pipeline options see " + + "https://cloud.google.com/dataflow/docs/guides/setting-pipeline-options#experimental. 
" + + "Defaulting to CloudPath."); } - return ConfigFetcherComputationStateCacheAndWindmillClient.create( - configFetcher, computationStateCache, windmillServer, windmillStreamFactory); + return false; + } + + private static void validateWorkerOptions(DataflowWorkerHarnessOptions options) { + Preconditions.checkArgument( + options.isStreaming(), + "%s instantiated with options indicating batch use", + StreamingDataflowWorker.class.getName()); + + Preconditions.checkArgument( + !DataflowRunner.hasExperiment(options, BEAM_FN_API_EXPERIMENT), + "%s cannot be main() class with beam_fn_api enabled", + StreamingDataflowWorker.class.getSimpleName()); + } + + private static ChannelCachingStubFactory createFanOutStubFactory( + DataflowWorkerHarnessOptions workerOptions) { + return ChannelCachingRemoteStubFactory.create( + workerOptions.getGcpCredential(), + ChannelCache.create( + serviceAddress -> + // IsolationChannel will create and manage separate RPC channels to the same + // serviceAddress. + IsolationChannel.create( + () -> + remoteChannel( + serviceAddress, + workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec())))); } @VisibleForTesting @@ -495,7 +665,9 @@ static StreamingDataflowWorker forTesting( Supplier clock, Function executorSupplier, StreamingGlobalConfigHandleImpl globalConfigHandle, - int localRetryTimeoutMs) { + int localRetryTimeoutMs, + StreamingCounters streamingCounters, + WindmillStubFactoryFactory stubFactory) { ConcurrentMap stageInfo = new ConcurrentHashMap<>(); BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); WindmillStateCache stateCache = @@ -538,7 +710,6 @@ static StreamingDataflowWorker forTesting( stateNameMap, stateCache.forComputation(mapTask.getStageName()))); MemoryMonitor memoryMonitor = MemoryMonitor.fromOptions(options); - StreamingCounters streamingCounters = StreamingCounters.create(); FailureTracker failureTracker = options.isEnableStreamingEngine() ? 
StreamingEngineFailureTracker.create( @@ -554,19 +725,23 @@ static StreamingDataflowWorker forTesting( () -> Optional.ofNullable(memoryMonitor.tryToDumpHeap()), clock, localRetryTimeoutMs); - StreamingWorkerStatusReporter workerStatusReporter = - StreamingWorkerStatusReporter.forTesting( - publishCounters, - workUnitClient, - windmillServer::getAndResetThrottleTime, - stageInfo::values, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - executorSupplier, - options.getWindmillHarnessUpdateReportingPeriod().getMillis(), - options.getPerWorkerMetricsUpdateReportingPeriodMillis()); + StreamingWorkerStatusReporterFactory workerStatusReporterFactory = + throttleTimeSupplier -> + StreamingWorkerStatusReporter.builder() + .setPublishCounters(publishCounters) + .setDataflowServiceClient(workUnitClient) + .setWindmillQuotaThrottleTime(throttleTimeSupplier) + .setAllStageInfo(stageInfo::values) + .setFailureTracker(failureTracker) + .setStreamingCounters(streamingCounters) + .setMemoryMonitor(memoryMonitor) + .setWorkExecutor(workExecutor) + .setExecutorFactory(executorSupplier) + .setWindmillHarnessUpdateReportingPeriodMillis( + options.getWindmillHarnessUpdateReportingPeriod().getMillis()) + .setPerWorkerMetricsUpdateReportingPeriodMillis( + options.getPerWorkerMetricsUpdateReportingPeriodMillis()) + .build(); GrpcWindmillStreamFactory.Builder windmillStreamFactory = createGrpcwindmillStreamFactoryBuilder(options, 1) @@ -584,7 +759,7 @@ static StreamingDataflowWorker forTesting( options, hotKeyLogger, clock, - workerStatusReporter, + workerStatusReporterFactory, failureTracker, workFailureProcessor, streamingCounters, @@ -596,7 +771,8 @@ static StreamingDataflowWorker forTesting( .build() : windmillStreamFactory.build(), executorSupplier.apply("RefreshWork"), - stageInfo); + stageInfo, + GrpcDispatcherClient.create(options, stubFactory)); } private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactoryBuilder( @@ -605,13 +781,7 @@ 
private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactory !options.isEnableStreamingEngine() && options.getLocalWindmillHostport() != null ? GrpcWindmillServer.LOCALHOST_MAX_BACKOFF : Duration.millis(options.getWindmillServiceStreamMaxBackoffMillis()); - return GrpcWindmillStreamFactory.of( - JobHeader.newBuilder() - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .setClientId(clientId) - .build()) + return GrpcWindmillStreamFactory.of(createJobHeader(options, clientId)) .setWindmillMessagesBetweenIsReadyChecks(options.getWindmillMessagesBetweenIsReadyChecks()) .setMaxBackOffSupplier(() -> maxBackoff) .setLogEveryNStreamFailures(options.getWindmillServiceStreamingLogEveryNStreamFailures()) @@ -622,6 +792,15 @@ private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactory options, "streaming_engine_disable_new_heartbeat_requests")); } + private static JobHeader createJobHeader(DataflowWorkerHarnessOptions options, long clientId) { + return JobHeader.newBuilder() + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .setClientId(clientId) + .build(); + } + private static BoundedQueueExecutor createWorkUnitExecutor(DataflowWorkerHarnessOptions options) { return new BoundedQueueExecutor( chooseMaxThreads(options), @@ -640,15 +819,7 @@ public static void main(String[] args) throws Exception { DataflowWorkerHarnessHelper.initializeGlobalStateAndPipelineOptions( StreamingDataflowWorker.class, DataflowWorkerHarnessOptions.class); DataflowWorkerHarnessHelper.configureLogging(options); - checkArgument( - options.isStreaming(), - "%s instantiated with options indicating batch use", - StreamingDataflowWorker.class.getName()); - - checkArgument( - !DataflowRunner.hasExperiment(options, "beam_fn_api"), - "%s cannot be main() class with beam_fn_api enabled", - StreamingDataflowWorker.class.getSimpleName()); + 
validateWorkerOptions(options); CoderTranslation.verifyModelCodersRegistered(); @@ -705,21 +876,6 @@ void reportPeriodicWorkerUpdatesForTest() { workerStatusReporter.reportPeriodicWorkerUpdates(); } - private int chooseMaximumNumberOfThreads() { - if (options.getNumberOfWorkerHarnessThreads() != 0) { - return options.getNumberOfWorkerHarnessThreads(); - } - return MAX_PROCESSING_THREADS; - } - - private int chooseMaximumBundlesOutstanding() { - int maxBundles = options.getMaxBundlesFromWindmillOutstanding(); - if (maxBundles > 0) { - return maxBundles; - } - return chooseMaximumNumberOfThreads() + 100; - } - @VisibleForTesting public boolean workExecutorIsEmpty() { return workUnitExecutor.executorQueueIsEmpty(); @@ -727,7 +883,7 @@ public boolean workExecutorIsEmpty() { @VisibleForTesting int numCommitThreads() { - return workCommitter.parallelism(); + return numCommitThreads; } @VisibleForTesting @@ -740,7 +896,6 @@ ComputationStateCache getComputationStateCache() { return computationStateCache; } - @SuppressWarnings("FutureReturnValueIgnored") public void start() { running.set(true); configFetcher.start(); @@ -791,27 +946,17 @@ private void onCompleteCommit(CompleteCommit completeCommit) { completeCommit.shardedKey(), completeCommit.workId())); } - @VisibleForTesting - public Iterable buildCounters() { - return Iterables.concat( - streamingCounters - .pendingDeltaCounters() - .extractModifiedDeltaUpdates(DataflowCounterUpdateExtractor.INSTANCE), - streamingCounters - .pendingCumulativeCounters() - .extractUpdates(false, DataflowCounterUpdateExtractor.INSTANCE)); + @FunctionalInterface + private interface StreamingWorkerStatusReporterFactory { + StreamingWorkerStatusReporter createStatusReporter(ThrottledTimeTracker throttledTimeTracker); } @AutoValue abstract static class ConfigFetcherComputationStateCacheAndWindmillClient { - private static ConfigFetcherComputationStateCacheAndWindmillClient create( - ComputationConfig.Fetcher configFetcher, - 
ComputationStateCache computationStateCache, - WindmillServerStub windmillServer, - GrpcWindmillStreamFactory windmillStreamFactory) { - return new AutoValue_StreamingDataflowWorker_ConfigFetcherComputationStateCacheAndWindmillClient( - configFetcher, computationStateCache, windmillServer, windmillStreamFactory); + private static Builder builder() { + return new AutoValue_StreamingDataflowWorker_ConfigFetcherComputationStateCacheAndWindmillClient + .Builder(); } abstract ComputationConfig.Fetcher configFetcher(); @@ -821,6 +966,23 @@ private static ConfigFetcherComputationStateCacheAndWindmillClient create( abstract WindmillServerStub windmillServer(); abstract GrpcWindmillStreamFactory windmillStreamFactory(); + + abstract @Nullable GrpcDispatcherClient windmillDispatcherClient(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setConfigFetcher(ComputationConfig.Fetcher value); + + abstract Builder setComputationStateCache(ComputationStateCache value); + + abstract Builder setWindmillServer(WindmillServerStub value); + + abstract Builder setWindmillStreamFactory(GrpcWindmillStreamFactory value); + + abstract Builder setWindmillDispatcherClient(GrpcDispatcherClient value); + + abstract ConfigFetcherComputationStateCacheAndWindmillClient build(); + } } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 1ef1691b0817..21aaa23d3f85 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -68,6 
+68,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -125,7 +126,8 @@ private FanOutStreamingEngineWorkerHarness( GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, Function workCommitterFactory, - ThrottlingGetDataMetricTracker getDataMetricTracker) { + ThrottlingGetDataMetricTracker getDataMetricTracker, + ExecutorService workerMetadataConsumer) { this.jobHeader = jobHeader; this.getDataMetricTracker = getDataMetricTracker; this.started = false; @@ -138,9 +140,7 @@ private FanOutStreamingEngineWorkerHarness( this.windmillStreamManager = Executors.newCachedThreadPool( new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build()); - this.workerMetadataConsumer = - Executors.newSingleThreadExecutor( - new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build()); + this.workerMetadataConsumer = workerMetadataConsumer; this.getWorkBudgetDistributor = getWorkBudgetDistributor; this.totalGetWorkBudget = totalGetWorkBudget; this.activeMetadataVersion = Long.MIN_VALUE; @@ -171,7 +171,11 @@ public static FanOutStreamingEngineWorkerHarness create( getWorkBudgetDistributor, dispatcherClient, workCommitterFactory, - getDataMetricTracker); + getDataMetricTracker, + Executors.newSingleThreadExecutor( + new ThreadFactoryBuilder() + .setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME) + .build())); } @VisibleForTesting @@ -195,7 +199,13 @@ static FanOutStreamingEngineWorkerHarness forTesting( getWorkBudgetDistributor, dispatcherClient, 
workCommitterFactory, - getDataMetricTracker); + getDataMetricTracker, + // Run the workerMetadataConsumer on the direct calling thread to remove waiting and + // make unit tests more deterministic as we do not have to worry about network IO being + // blocked by the consumeWorkerMetadata() task. Test suites run in different + // environments and non-determinism has led to past flakiness. See + // https://github.com/apache/beam/issues/28957. + MoreExecutors.newDirectExecutorService()); fanOutStreamingEngineWorkProvider.start(); return fanOutStreamingEngineWorkProvider; } @@ -371,6 +381,7 @@ private void closeStreamSender(Endpoint endpoint, StreamSender sender) { } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. */ + @Override public long getAndResetThrottleTime() { return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getAndResetThrottleTime) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java index 9716b834cac4..65203288e169 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java @@ -37,6 +37,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import
org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemReceiver; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -66,6 +67,7 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { private final Function> computationStateFetcher; private final ExecutorService workProviderExecutor; private final GetWorkSender getWorkSender; + private final ThrottledTimeTracker throttledTimeTracker; SingleSourceWorkerHarness( WorkCommitter workCommitter, @@ -74,7 +76,8 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { StreamingWorkScheduler streamingWorkScheduler, Runnable waitForResources, Function> computationStateFetcher, - GetWorkSender getWorkSender) { + GetWorkSender getWorkSender, + ThrottledTimeTracker throttledTimeTracker) { this.workCommitter = workCommitter; this.getDataClient = getDataClient; this.heartbeatSender = heartbeatSender; @@ -90,6 +93,7 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { .build()); this.isRunning = new AtomicBoolean(false); this.getWorkSender = getWorkSender; + this.throttledTimeTracker = throttledTimeTracker; } public static SingleSourceWorkerHarness.Builder builder() { @@ -144,6 +148,11 @@ public void shutdown() { workCommitter.stop(); } + @Override + public long getAndResetThrottleTime() { + return throttledTimeTracker.getAndResetThrottleTime(); + } + private void streamingEngineDispatchLoop( Function getWorkStreamFactory) { while (isRunning.get()) { @@ -254,6 +263,8 @@ Builder setComputationStateFetcher( Builder setGetWorkSender(GetWorkSender getWorkSender); + Builder setThrottledTimeTracker(ThrottledTimeTracker throttledTimeTracker); + SingleSourceWorkerHarness build(); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java index c1b4570e2260..731a5a4b1b51 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerHarness.java @@ -17,11 +17,12 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import org.apache.beam.sdk.annotations.Internal; /** Provides an interface to start streaming worker processing. */ @Internal -public interface StreamingWorkerHarness { +public interface StreamingWorkerHarness extends ThrottledTimeTracker { void start(); void shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java index 6981312eff1d..ddfc6809231a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusPages.java @@ -258,7 +258,7 @@ public interface Builder { Builder setDebugCapture(DebugCapture.Manager debugCapture); - Builder setChannelzServlet(ChannelzServlet channelzServlet); + Builder setChannelzServlet(@Nullable ChannelzServlet channelzServlet); Builder 
setStateCache(WindmillStateCache stateCache); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java index ba77d8e1ce26..3557f0d193c5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporter.java @@ -27,6 +27,7 @@ import com.google.api.services.dataflow.model.WorkItemStatus; import com.google.api.services.dataflow.model.WorkerMessage; import com.google.api.services.dataflow.model.WorkerMessageResponse; +import com.google.auto.value.AutoBuilder; import java.io.IOException; import java.math.RoundingMode; import java.util.ArrayList; @@ -51,6 +52,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.StageInfo; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottledTimeTracker; import org.apache.beam.runners.dataflow.worker.windmill.work.processing.failures.FailureTracker; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -78,7 +80,7 @@ public final class StreamingWorkerStatusReporter { private final int initialMaxThreadCount; private final int initialMaxBundlesOutstanding; private final WorkUnitClient dataflowServiceClient; - private final Supplier windmillQuotaThrottleTime; + private final ThrottledTimeTracker windmillQuotaThrottleTime; private final Supplier> allStageInfo; private final 
FailureTracker failureTracker; private final StreamingCounters streamingCounters; @@ -97,10 +99,10 @@ public final class StreamingWorkerStatusReporter { // Used to track the number of WorkerMessages that have been sent without PerWorkerMetrics. private final AtomicLong workerMessagesIndex; - private StreamingWorkerStatusReporter( + StreamingWorkerStatusReporter( boolean publishCounters, WorkUnitClient dataflowServiceClient, - Supplier windmillQuotaThrottleTime, + ThrottledTimeTracker windmillQuotaThrottleTime, Supplier> allStageInfo, FailureTracker failureTracker, StreamingCounters streamingCounters, @@ -131,57 +133,13 @@ private StreamingWorkerStatusReporter( this.workerMessagesIndex = new AtomicLong(); } - public static StreamingWorkerStatusReporter create( - WorkUnitClient workUnitClient, - Supplier windmillQuotaThrottleTime, - Supplier> allStageInfo, - FailureTracker failureTracker, - StreamingCounters streamingCounters, - MemoryMonitor memoryMonitor, - BoundedQueueExecutor workExecutor, - long windmillHarnessUpdateReportingPeriodMillis, - long perWorkerMetricsUpdateReportingPeriodMillis) { - return new StreamingWorkerStatusReporter( - /* publishCounters= */ true, - workUnitClient, - windmillQuotaThrottleTime, - allStageInfo, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - threadName -> - Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()), - windmillHarnessUpdateReportingPeriodMillis, - perWorkerMetricsUpdateReportingPeriodMillis); - } - - @VisibleForTesting - public static StreamingWorkerStatusReporter forTesting( - boolean publishCounters, - WorkUnitClient workUnitClient, - Supplier windmillQuotaThrottleTime, - Supplier> allStageInfo, - FailureTracker failureTracker, - StreamingCounters streamingCounters, - MemoryMonitor memoryMonitor, - BoundedQueueExecutor workExecutor, - Function executorFactory, - long windmillHarnessUpdateReportingPeriodMillis, - long 
perWorkerMetricsUpdateReportingPeriodMillis) { - return new StreamingWorkerStatusReporter( - publishCounters, - workUnitClient, - windmillQuotaThrottleTime, - allStageInfo, - failureTracker, - streamingCounters, - memoryMonitor, - workExecutor, - executorFactory, - windmillHarnessUpdateReportingPeriodMillis, - perWorkerMetricsUpdateReportingPeriodMillis); + public static Builder builder() { + return new AutoBuilder_StreamingWorkerStatusReporter_Builder() + .setPublishCounters(true) + .setExecutorFactory( + threadName -> + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat(threadName).build())); } /** @@ -228,6 +186,22 @@ private static void shutdownExecutor(ScheduledExecutorService executor) { } } + // Calculates the PerWorkerMetrics reporting frequency, ensuring alignment with the + // WorkerMessages RPC schedule. The desired reporting period + // (perWorkerMetricsUpdateReportingPeriodMillis) is adjusted to the nearest multiple + // of the RPC interval (windmillHarnessUpdateReportingPeriodMillis). + private static long getPerWorkerMetricsUpdateFrequency( + long windmillHarnessUpdateReportingPeriodMillis, + long perWorkerMetricsUpdateReportingPeriodMillis) { + if (windmillHarnessUpdateReportingPeriodMillis == 0) { + return 0; + } + return LongMath.divide( + perWorkerMetricsUpdateReportingPeriodMillis, + windmillHarnessUpdateReportingPeriodMillis, + RoundingMode.CEILING); + } + @SuppressWarnings("FutureReturnValueIgnored") public void start() { reportHarnessStartup(); @@ -276,27 +250,13 @@ private void reportHarnessStartup() { } } - // Calculates the PerWorkerMetrics reporting frequency, ensuring alignment with the - // WorkerMessages RPC schedule. The desired reporting period - // (perWorkerMetricsUpdateReportingPeriodMillis) is adjusted to the nearest multiple - // of the RPC interval (windmillHarnessUpdateReportingPeriodMillis). 
- private static long getPerWorkerMetricsUpdateFrequency( - long windmillHarnessUpdateReportingPeriodMillis, - long perWorkerMetricsUpdateReportingPeriodMillis) { - if (windmillHarnessUpdateReportingPeriodMillis == 0) { - return 0; - } - return LongMath.divide( - perWorkerMetricsUpdateReportingPeriodMillis, - windmillHarnessUpdateReportingPeriodMillis, - RoundingMode.CEILING); - } - /** Sends counter updates to Dataflow backend. */ private void sendWorkerUpdatesToDataflowService( CounterSet deltaCounters, CounterSet cumulativeCounters) throws IOException { // Throttle time is tracked by the windmillServer but is reported to DFE here. - streamingCounters.windmillQuotaThrottling().addValue(windmillQuotaThrottleTime.get()); + streamingCounters + .windmillQuotaThrottling() + .addValue(windmillQuotaThrottleTime.getAndResetThrottleTime()); if (memoryMonitor.isThrashing()) { streamingCounters.memoryThrashing().addValue(1); } @@ -496,4 +456,33 @@ private void updateThreadMetrics() { .maxOutstandingBundles() .addValue((long) workExecutor.maximumElementsOutstanding()); } + + @AutoBuilder + public interface Builder { + Builder setPublishCounters(boolean publishCounters); + + Builder setDataflowServiceClient(WorkUnitClient dataflowServiceClient); + + Builder setWindmillQuotaThrottleTime(ThrottledTimeTracker windmillQuotaThrottledTimeTracker); + + Builder setAllStageInfo(Supplier> allStageInfo); + + Builder setFailureTracker(FailureTracker failureTracker); + + Builder setStreamingCounters(StreamingCounters streamingCounters); + + Builder setMemoryMonitor(MemoryMonitor memoryMonitor); + + Builder setWorkExecutor(BoundedQueueExecutor workExecutor); + + Builder setExecutorFactory(Function executorFactory); + + Builder setWindmillHarnessUpdateReportingPeriodMillis( + long windmillHarnessUpdateReportingPeriodMillis); + + Builder setPerWorkerMetricsUpdateReportingPeriodMillis( + long perWorkerMetricsUpdateReportingPeriodMillis); + + StreamingWorkerStatusReporter build(); + } } diff 
--git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index 13b3ea954198..dd7fdd45ab08 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -88,17 +88,23 @@ private static Optional parseDirectEndpoint( .map(address -> AuthenticatedGcpServiceAddress.create(authenticatingService, address)) .map(WindmillServiceAddress::create); - return directEndpointIpV6Address.isPresent() - ? directEndpointIpV6Address - : tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint()) - .map(WindmillServiceAddress::create); + Optional windmillServiceAddress = + directEndpointIpV6Address.isPresent() + ? directEndpointIpV6Address + : tryParseEndpointIntoHostAndPort(endpointProto.getDirectEndpoint()) + .map(WindmillServiceAddress::create); + + if (!windmillServiceAddress.isPresent()) { + LOG.warn("Endpoint {} could not be parsed into a WindmillServiceAddress.", endpointProto); + } + + return windmillServiceAddress; } private static Optional tryParseEndpointIntoHostAndPort(String directEndpoint) { try { return Optional.of(HostAndPort.fromString(directEndpoint)); } catch (IllegalArgumentException e) { - LOG.warn("{} cannot be parsed into a gcpServiceAddress", directEndpoint); return Optional.empty(); } } @@ -113,19 +119,12 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( try { directEndpointAddress = Inet6Address.getByName(endpointProto.getDirectEndpoint()); } catch (UnknownHostException e) { - LOG.warn( - "Error occurred trying to parse direct_endpoint={} into IPv6 address. 
Exception={}", - endpointProto.getDirectEndpoint(), - e.toString()); return Optional.empty(); } // Inet6Address.getByAddress returns either an IPv4 or an IPv6 address depending on the format // of the direct_endpoint string. if (!(directEndpointAddress instanceof Inet6Address)) { - LOG.warn( - "{} is not an IPv6 address. Direct endpoints are expected to be in IPv6 format.", - endpointProto.getDirectEndpoint()); return Optional.empty(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index 6bae84483d16..234888831779 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -150,6 +150,8 @@ public CloudWindmillMetadataServiceV1Alpha1Stub getWindmillMetadataServiceStubBl } } + LOG.info("Windmill Service endpoint initialized after {} seconds.", secondsWaited); + ImmutableList windmillMetadataServiceStubs = dispatcherStubs.get().windmillMetadataServiceStubs(); @@ -190,7 +192,7 @@ public void onJobConfig(StreamingGlobalConfig config) { public synchronized void consumeWindmillDispatcherEndpoints( ImmutableSet dispatcherEndpoints) { - consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /*forceRecreateStubs=*/ false); + consumeWindmillDispatcherEndpoints(dispatcherEndpoints, /* forceRecreateStubs= */ false); } private synchronized void consumeWindmillDispatcherEndpoints( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java index db012c6bb412..c03459ee732e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCache.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs; import java.io.PrintWriter; +import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.function.Function; @@ -31,6 +32,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListener; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalListeners; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,14 +52,11 @@ public final class ChannelCache implements StatusDataProvider { private ChannelCache( Function channelFactory, - RemovalListener onChannelRemoved) { + RemovalListener onChannelRemoved, + Executor channelCloser) { this.channelCache = CacheBuilder.newBuilder() - .removalListener( - RemovalListeners.asynchronous( - onChannelRemoved, - Executors.newCachedThreadPool( - new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build()))) + .removalListener(RemovalListeners.asynchronous(onChannelRemoved, channelCloser)) .build( new CacheLoader() { @Override @@ -72,11 +71,13 @@ public static ChannelCache create( return new ChannelCache( channelFactory, // Shutdown the channels as 
they get removed from the cache, so they do not leak. - notification -> shutdownChannel(notification.getValue())); + notification -> shutdownChannel(notification.getValue()), + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat("GrpcChannelCloser").build())); } @VisibleForTesting - static ChannelCache forTesting( + public static ChannelCache forTesting( Function channelFactory, Runnable onChannelShutdown) { return new ChannelCache( channelFactory, @@ -85,7 +86,11 @@ static ChannelCache forTesting( notification -> { shutdownChannel(notification.getValue()); onChannelShutdown.run(); - }); + }, + // Run the removal synchronously on the calling thread to prevent waiting on asynchronous + // tasks to run and make unit tests deterministic. In testing, we verify that things are + // removed from the cache. + MoreExecutors.directExecutor()); } private static void shutdownChannel(ManagedChannel channel) { @@ -108,6 +113,7 @@ public void remove(WindmillServiceAddress windmillServiceAddress) { public void clear() { channelCache.invalidateAll(); + channelCache.cleanUp(); } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java index f660112721ba..fdcb0339d23d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottleTimer.java @@ -25,7 +25,7 @@ * CommitWork are both blocked for x, totalTime will be 2x. However, if 2 GetWork streams are both * blocked for x totalTime will be x. All methods are thread safe. 
*/ -public final class ThrottleTimer { +public final class ThrottleTimer implements ThrottledTimeTracker { // This is -1 if not currently being throttled or the time in // milliseconds when throttling for this type started. private long startTime = -1; @@ -56,6 +56,7 @@ public synchronized boolean throttled() { } /** Returns the combined total of all throttle times and resets those times to 0. */ + @Override public synchronized long getAndResetThrottleTime() { if (throttled()) { stop(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottledTimeTracker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottledTimeTracker.java new file mode 100644 index 000000000000..9bb8fb0a7b5f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/throttling/ThrottledTimeTracker.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.throttling; + +import org.apache.beam.sdk.annotations.Internal; + +/** + * Tracks time spent in a throttled state due to {@code Status.RESOURCE_EXHAUSTED} errors returned + * from gRPC calls. + */ +@Internal +@FunctionalInterface +public interface ThrottledTimeTracker { + + /** Returns the combined total of all throttle times and resets those times to 0. */ + long getAndResetThrottleTime(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java index fa36b11ffe55..f54091dc2b95 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSender.java @@ -42,7 +42,7 @@ private StreamPoolHeartbeatSender( this.heartbeatStreamPool.set(heartbeatStreamPool); } - public static StreamPoolHeartbeatSender Create( + public static StreamPoolHeartbeatSender create( @Nonnull WindmillStreamPool heartbeatStreamPool) { return new StreamPoolHeartbeatSender(heartbeatStreamPool); } @@ -55,7 +55,7 @@ public static StreamPoolHeartbeatSender Create( * enabled. * @param getDataPool stream to use when using separate streams for heartbeat is disabled. 
*/ - public static StreamPoolHeartbeatSender Create( + public static StreamPoolHeartbeatSender create( @Nonnull WindmillStreamPool dedicatedHeartbeatPool, @Nonnull WindmillStreamPool getDataPool, @Nonnull StreamingGlobalConfigHandle configHandle) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index dadf02171235..6eeb7bd6bbfc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -96,6 +96,7 @@ import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.runners.dataflow.util.Structs; +import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.ComputationStateCache; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; @@ -104,6 +105,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfigHandleImpl; +import org.apache.beam.runners.dataflow.worker.streaming.harness.StreamingCounters; import org.apache.beam.runners.dataflow.worker.testing.RestoreDataflowLoggingMDC; import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; @@ -129,6 +131,9 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WatermarkHold; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; +import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; +import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactoryFactory; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; @@ -178,6 +183,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.UnsignedLong; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -285,6 +291,7 @@ public Long get() { private final FakeWindmillServer server = new FakeWindmillServer( errorCollector, computationId -> computationStateCache.get(computationId)); + private StreamingCounters streamingCounters; public StreamingDataflowWorkerTest(Boolean streamingEngine) { this.streamingEngine = streamingEngine; @@ -304,9 +311,20 @@ private static CounterUpdate getCounter(Iterable counters, String return null; } + private Iterable buildCounters() { + return Iterables.concat( + streamingCounters + .pendingDeltaCounters() + .extractModifiedDeltaUpdates(DataflowCounterUpdateExtractor.INSTANCE), + streamingCounters + 
.pendingCumulativeCounters() + .extractUpdates(false, DataflowCounterUpdateExtractor.INSTANCE)); + } + @Before public void setUp() { server.clearCommitsReceived(); + streamingCounters = StreamingCounters.create(); } @After @@ -856,7 +874,13 @@ private StreamingDataflowWorker makeWorker( streamingDataflowWorkerTestParams.clock(), streamingDataflowWorkerTestParams.executorSupplier(), mockGlobalConfigHandle, - streamingDataflowWorkerTestParams.localRetryTimeoutMs()); + streamingDataflowWorkerTestParams.localRetryTimeoutMs(), + streamingCounters, + new FakeWindmillStubFactoryFactory( + new FakeWindmillStubFactory( + () -> + WindmillChannelFactory.inProcessChannel( + "StreamingDataflowWorkerTestChannel")))); this.computationStateCache = worker.getComputationStateCache(); return worker; } @@ -1715,7 +1739,7 @@ public void testMergeWindows() throws Exception { intervalWindowBytes(WINDOW_AT_ZERO))); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); // These tags and data are opaque strings and this is a change detector test. // The "/u" indicates the user's namespace, versus "/s" for system namespace @@ -1836,7 +1860,7 @@ public void testMergeWindows() throws Exception { expectedBytesRead += dataBuilder.build().getSerializedSize(); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = buildCounters(); actualOutput = result.get(2L); assertEquals(1, actualOutput.getOutputMessagesCount()); @@ -2004,7 +2028,7 @@ public void testMergeWindowsCaching() throws Exception { intervalWindowBytes(WINDOW_AT_ZERO))); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); // These tags and data are opaque strings and this is a change detector test. 
// The "/u" indicates the user's namespace, versus "/s" for system namespace @@ -2125,7 +2149,7 @@ public void testMergeWindowsCaching() throws Exception { expectedBytesRead += dataBuilder.build().getSerializedSize(); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = buildCounters(); actualOutput = result.get(2L); assertEquals(1, actualOutput.getOutputMessagesCount()); @@ -2430,7 +2454,7 @@ public void testUnboundedSources() throws Exception { null)); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); Windmill.WorkItemCommitRequest commit = result.get(1L); UnsignedLong finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); @@ -2492,7 +2516,7 @@ public void testUnboundedSources() throws Exception { null)); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = buildCounters(); commit = result.get(2L); finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); @@ -2540,7 +2564,7 @@ public void testUnboundedSources() throws Exception { null)); result = server.waitForAndGetCommits(1); - counters = worker.buildCounters(); + counters = buildCounters(); commit = result.get(3L); finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); @@ -2710,7 +2734,7 @@ public void testUnboundedSourceWorkRetry() throws Exception { server.whenGetWorkCalled().thenReturn(work); Map result = server.waitForAndGetCommits(1); - Iterable counters = worker.buildCounters(); + Iterable counters = buildCounters(); Windmill.WorkItemCommitRequest commit = result.get(1L); UnsignedLong finalizeId = UnsignedLong.fromLongBits(commit.getSourceStateUpdates().getFinalizeIds(0)); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index bba6cad5529a..606d2b9dbdbc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -33,8 +33,6 @@ import java.util.HashSet; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; @@ -65,7 +63,6 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessSocketAddress; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -103,7 +100,6 @@ public class FanOutStreamingEngineWorkerHarnessTest { .build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); private final GrpcWindmillStreamFactory streamFactory = spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build()); private final ChannelCachingStubFactory stubFactory = @@ -146,22 +142,21 @@ private static WorkerMetadataResponse.Endpoint metadataResponseEndpoint(String w @Before public void 
setUp() throws IOException { - stubFactory.shutdown(); + getWorkerMetadataReady = new CountDownLatch(1); + fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady); fakeStreamingEngineServer = - grpcCleanup.register( - InProcessServerBuilder.forName(CHANNEL_NAME) - .fallbackHandlerRegistry(serviceRegistry) - .executor(Executors.newFixedThreadPool(1)) - .build()); + grpcCleanup + .register( + InProcessServerBuilder.forName(CHANNEL_NAME) + .directExecutor() + .addService(fakeGetWorkerMetadataStub) + .addService(new WindmillServiceFakeStub()) + .build()) + .start(); - fakeStreamingEngineServer.start(); dispatcherClient.consumeWindmillDispatcherEndpoints( ImmutableSet.of( HostAndPort.fromString(new InProcessSocketAddress(CHANNEL_NAME).toString()))); - getWorkerMetadataReady = new CountDownLatch(1); - fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady); - serviceRegistry.addService(fakeGetWorkerMetadataStub); - serviceRegistry.addService(new WindmillServiceFakeStub()); } @After @@ -174,27 +169,29 @@ public void cleanUp() { private FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( GetWorkBudget getWorkBudget, GetWorkBudgetDistributor getWorkBudgetDistributor, - WorkItemScheduler workItemScheduler) { - return FanOutStreamingEngineWorkerHarness.forTesting( - JOB_HEADER, - getWorkBudget, - streamFactory, - workItemScheduler, - stubFactory, - getWorkBudgetDistributor, - dispatcherClient, - ignored -> mock(WorkCommitter.class), - new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); + WorkItemScheduler workItemScheduler) + throws InterruptedException { + FanOutStreamingEngineWorkerHarness harness = + FanOutStreamingEngineWorkerHarness.forTesting( + JOB_HEADER, + getWorkBudget, + streamFactory, + workItemScheduler, + stubFactory, + getWorkBudgetDistributor, + dispatcherClient, + ignored -> mock(WorkCommitter.class), + new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); + 
getWorkerMetadataReady.await(); + return harness; } @Test public void testStreamsStartCorrectly() throws InterruptedException { long items = 10L; long bytes = 10L; - int numBudgetDistributionsExpected = 1; - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected)); + TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor()); fanOutStreamingEngineWorkProvider = newFanOutStreamingEngineWorkerHarness( @@ -205,17 +202,13 @@ public void testStreamsStartCorrectly() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - WorkerMetadataResponse firstWorkerMetadata = + fakeGetWorkerMetadataStub.injectWorkerMetadata( WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) .addWorkEndpoints(metadataResponseEndpoint(workerToken)) .addWorkEndpoints(metadataResponseEndpoint(workerToken2)) .putAllGlobalDataEndpoints(DEFAULT) - .build(); - - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + .build()); StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); @@ -249,8 +242,7 @@ public void testStreamsStartCorrectly() throws InterruptedException { @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(1)); + TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor()); fanOutStreamingEngineWorkProvider = newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), @@ -283,12 +275,8 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() .build()) .build(); - getWorkerMetadataReady.await(); 
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - getWorkBudgetDistributor.expectNumDistributions(1); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); assertEquals(1, currentBackends.windmillStreams().size()); Set workerTokens = @@ -325,21 +313,15 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .putAllGlobalDataEndpoints(DEFAULT) .build(); - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(1)); + TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor()); fanOutStreamingEngineWorkProvider = newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - getWorkBudgetDistributor.expectNumDistributions(1); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } @@ -354,10 +336,14 @@ public StreamObserver getDataStream( public void onNext(Windmill.StreamingGetDataRequest getDataRequest) {} @Override - public void onError(Throwable throwable) {} + public void onError(Throwable throwable) { + responseObserver.onError(throwable); + } @Override - public void onCompleted() {} + public void onCompleted() { + responseObserver.onCompleted(); + } }; } @@ -369,10 +355,14 @@ public StreamObserver getWorkStream( public void onNext(Windmill.StreamingGetWorkRequest getWorkRequest) {} 
@Override - public void onError(Throwable throwable) {} + public void onError(Throwable throwable) { + responseObserver.onError(throwable); + } @Override - public void onCompleted() {} + public void onCompleted() { + responseObserver.onCompleted(); + } }; } @@ -384,10 +374,14 @@ public StreamObserver commitWorkStream( public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} @Override - public void onError(Throwable throwable) {} + public void onError(Throwable throwable) { + responseObserver.onError(throwable); + } @Override - public void onCompleted() {} + public void onCompleted() { + responseObserver.onCompleted(); + } }; } } @@ -422,7 +416,11 @@ public void onError(Throwable throwable) { } @Override - public void onCompleted() {} + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } }; } @@ -434,25 +432,10 @@ private void injectWorkerMetadata(WorkerMetadataResponse response) { } private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { - private CountDownLatch getWorkBudgetDistributorTriggered; - - private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { - this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); - } - - private boolean waitForBudgetDistribution() throws InterruptedException { - return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); - } - - private void expectNumDistributions(int numBudgetDistributionsExpected) { - this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); - } - @Override public void distributeBudget( ImmutableCollection streams, GetWorkBudget getWorkBudget) { streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes())); - getWorkBudgetDistributorTriggered.countDown(); } } } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java index 5a2df4baae61..4df3bf7cd823 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java @@ -60,6 +60,8 @@ private SingleSourceWorkerHarness createWorkerHarness( .setWaitForResources(waitForResources) .setStreamingWorkScheduler(streamingWorkScheduler) .setComputationStateFetcher(computationStateFetcher) + // no-op throttle time supplier. + .setThrottledTimeTracker(() -> 0L) .setGetWorkSender(getWorkSender) .build(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java index 7e65a495638f..f348e4cf1bdb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingWorkerStatusReporterTest.java @@ -39,14 +39,15 @@ @RunWith(JUnit4.class) public class StreamingWorkerStatusReporterTest { - private final long DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME = 1000; - private final long DEFAULT_HARNESS_REPORTING_PERIOD = 10000; - private final long DEFAULT_PER_WORKER_METRICS_PERIOD = 30000; + private static final long 
DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME = 1000; + private static final long DEFAULT_HARNESS_REPORTING_PERIOD = 10000; + private static final long DEFAULT_PER_WORKER_METRICS_PERIOD = 30000; private BoundedQueueExecutor mockExecutor; private WorkUnitClient mockWorkUnitClient; private FailureTracker mockFailureTracker; private MemoryMonitor mockMemoryMonitor; + private StreamingWorkerStatusReporter reporter; @Before public void setUp() { @@ -54,23 +55,11 @@ public void setUp() { this.mockWorkUnitClient = mock(WorkUnitClient.class); this.mockFailureTracker = mock(FailureTracker.class); this.mockMemoryMonitor = mock(MemoryMonitor.class); + this.reporter = buildWorkerStatusReporterForTest(); } @Test public void testOverrideMaximumThreadCount() throws Exception { - StreamingWorkerStatusReporter reporter = - StreamingWorkerStatusReporter.forTesting( - true, - mockWorkUnitClient, - () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME, - () -> Collections.emptyList(), - mockFailureTracker, - StreamingCounters.create(), - mockMemoryMonitor, - mockExecutor, - (threadName) -> Executors.newSingleThreadScheduledExecutor(), - DEFAULT_HARNESS_REPORTING_PERIOD, - DEFAULT_PER_WORKER_METRICS_PERIOD); StreamingScalingReportResponse streamingScalingReportResponse = new StreamingScalingReportResponse().setMaximumThreadCount(10); WorkerMessageResponse workerMessageResponse = @@ -84,23 +73,25 @@ public void testOverrideMaximumThreadCount() throws Exception { @Test public void testHandleEmptyWorkerMessageResponse() throws Exception { - StreamingWorkerStatusReporter reporter = - StreamingWorkerStatusReporter.forTesting( - true, - mockWorkUnitClient, - () -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME, - () -> Collections.emptyList(), - mockFailureTracker, - StreamingCounters.create(), - mockMemoryMonitor, - mockExecutor, - (threadName) -> Executors.newSingleThreadScheduledExecutor(), - DEFAULT_HARNESS_REPORTING_PERIOD, - DEFAULT_PER_WORKER_METRICS_PERIOD); - WorkerMessageResponse workerMessageResponse = new 
WorkerMessageResponse(); when(mockWorkUnitClient.reportWorkerMessage(any())) - .thenReturn(Collections.singletonList(workerMessageResponse)); + .thenReturn(Collections.singletonList(new WorkerMessageResponse())); reporter.reportPeriodicWorkerMessage(); verify(mockExecutor, Mockito.times(0)).setMaximumPoolSize(anyInt(), anyInt()); } + + private StreamingWorkerStatusReporter buildWorkerStatusReporterForTest() { + return StreamingWorkerStatusReporter.builder() + .setPublishCounters(true) + .setDataflowServiceClient(mockWorkUnitClient) + .setWindmillQuotaThrottleTime(() -> DEFAULT_WINDMILL_QUOTA_THROTTLE_TIME) + .setAllStageInfo(Collections::emptyList) + .setFailureTracker(mockFailureTracker) + .setStreamingCounters(StreamingCounters.create()) + .setMemoryMonitor(mockMemoryMonitor) + .setWorkExecutor(mockExecutor) + .setExecutorFactory((threadName) -> Executors.newSingleThreadScheduledExecutor()) + .setWindmillHarnessUpdateReportingPeriodMillis(DEFAULT_HARNESS_REPORTING_PERIOD) + .setPerWorkerMetricsUpdateReportingPeriodMillis(DEFAULT_PER_WORKER_METRICS_PERIOD) + .build(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java index 9f8a901cb629..1781261e3400 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/ChannelCacheTest.java @@ -105,19 +105,10 @@ public ManagedChannel apply(WindmillServiceAddress windmillServiceAddress) { @Test public void testRemoveAndClose() throws InterruptedException { String channelName = "existingChannel"; - CountDownLatch verifyRemovalListenerAsync = new 
CountDownLatch(1); CountDownLatch notifyWhenChannelClosed = new CountDownLatch(1); cache = ChannelCache.forTesting( - ignored -> newChannel(channelName), - () -> { - try { - verifyRemovalListenerAsync.await(); - notifyWhenChannelClosed.countDown(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - }); + ignored -> newChannel(channelName), notifyWhenChannelClosed::countDown); WindmillServiceAddress someAddress = mock(WindmillServiceAddress.class); ManagedChannel cachedChannel = cache.get(someAddress); @@ -125,7 +116,6 @@ public void testRemoveAndClose() throws InterruptedException { // Assert that the removal happened before we check to see if the shutdowns happen to confirm // that removals are async. assertTrue(cache.isEmpty()); - verifyRemovalListenerAsync.countDown(); // Assert that the channel gets shutdown. notifyWhenChannelClosed.await(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java index af3a3e8295bb..19e05efb50c6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/testing/FakeWindmillStubFactory.java @@ -31,7 +31,7 @@ public final class FakeWindmillStubFactory implements ChannelCachingStubFactory private final ChannelCache channelCache; public FakeWindmillStubFactory(Supplier channelFactory) { - this.channelCache = ChannelCache.create(ignored -> channelFactory.get()); + this.channelCache = ChannelCache.forTesting(ignored -> channelFactory.get(), () -> {}); } @Override diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java index ed915088d0a6..acbb3aebbcf5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/StreamPoolHeartbeatSenderTest.java @@ -39,7 +39,7 @@ public class StreamPoolHeartbeatSenderTest { public void sendsHeartbeatsOnStream() { FakeWindmillServer server = new FakeWindmillServer(new ErrorCollector(), c -> Optional.empty()); StreamPoolHeartbeatSender heartbeatSender = - StreamPoolHeartbeatSender.Create( + StreamPoolHeartbeatSender.create( WindmillStreamPool.create(1, Duration.standardSeconds(10), server::getDataStream)); Heartbeats.Builder heartbeatsBuilder = Heartbeats.builder(); heartbeatsBuilder @@ -59,7 +59,7 @@ public void sendsHeartbeatsOnDedicatedStream() { FakeGlobalConfigHandle configHandle = new FakeGlobalConfigHandle(getGlobalConfig(/*useSeparateHeartbeatStreams=*/ true)); StreamPoolHeartbeatSender heartbeatSender = - StreamPoolHeartbeatSender.Create( + StreamPoolHeartbeatSender.create( WindmillStreamPool.create( 1, Duration.standardSeconds(10), dedicatedServer::getDataStream), WindmillStreamPool.create( @@ -104,7 +104,7 @@ public void sendsHeartbeatsOnGetDataStream() { FakeGlobalConfigHandle configHandle = new FakeGlobalConfigHandle(getGlobalConfig(/*useSeparateHeartbeatStreams=*/ false)); StreamPoolHeartbeatSender heartbeatSender = - StreamPoolHeartbeatSender.Create( + StreamPoolHeartbeatSender.create( WindmillStreamPool.create( 1, Duration.standardSeconds(10), dedicatedServer::getDataStream), WindmillStreamPool.create( 
diff --git a/sdks/go.mod b/sdks/go.mod index 42f099d747bb..ef966f0dcef4 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -28,14 +28,14 @@ require ( cloud.google.com/go/datastore v1.20.0 cloud.google.com/go/profiler v0.4.1 cloud.google.com/go/pubsub v1.45.1 - cloud.google.com/go/spanner v1.70.0 + cloud.google.com/go/spanner v1.73.0 cloud.google.com/go/storage v1.45.0 - github.com/aws/aws-sdk-go-v2 v1.32.4 + github.com/aws/aws-sdk-go-v2 v1.32.6 github.com/aws/aws-sdk-go-v2/config v1.28.4 github.com/aws/aws-sdk-go-v2/credentials v1.17.45 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.38 github.com/aws/aws-sdk-go-v2/service/s3 v1.67.0 - github.com/aws/smithy-go v1.22.0 + github.com/aws/smithy-go v1.22.1 github.com/docker/go-connections v0.5.0 github.com/dustin/go-humanize v1.0.1 github.com/go-sql-driver/mysql v1.8.1 @@ -53,11 +53,11 @@ require ( github.com/xitongsys/parquet-go v1.6.2 github.com/xitongsys/parquet-go-source v0.0.0-20220315005136-aec0fe3e777c go.mongodb.org/mongo-driver v1.17.1 - golang.org/x/net v0.30.0 - golang.org/x/oauth2 v0.23.0 - golang.org/x/sync v0.8.0 + golang.org/x/net v0.31.0 + golang.org/x/oauth2 v0.24.0 + golang.org/x/sync v0.9.0 golang.org/x/sys v0.27.0 - golang.org/x/text v0.19.0 + golang.org/x/text v0.20.0 google.golang.org/api v0.203.0 google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 google.golang.org/grpc v1.67.1 @@ -190,7 +190,7 @@ require ( github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect go.opencensus.io v0.24.0 // indirect - golang.org/x/crypto v0.28.0 // indirect + golang.org/x/crypto v0.29.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/tools v0.24.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 5e42448c7d50..5527a59f4c52 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -542,8 +542,8 @@ cloud.google.com/go/shell v1.6.0/go.mod 
h1:oHO8QACS90luWgxP3N9iZVuEiSF84zNyLytb+ cloud.google.com/go/spanner v1.41.0/go.mod h1:MLYDBJR/dY4Wt7ZaMIQ7rXOTLjYrmxLE/5ve9vFfWos= cloud.google.com/go/spanner v1.44.0/go.mod h1:G8XIgYdOK+Fbcpbs7p2fiprDw4CaZX63whnSMLVBxjk= cloud.google.com/go/spanner v1.45.0/go.mod h1:FIws5LowYz8YAE1J8fOS7DJup8ff7xJeetWEo5REA2M= -cloud.google.com/go/spanner v1.70.0 h1:nj6p/GJTgMDiSQ1gQ034ItsKuJgHiMOjtOlONOg8PSo= -cloud.google.com/go/spanner v1.70.0/go.mod h1:X5T0XftydYp0K1adeJQDJtdWpbrOeJ7wHecM4tK6FiE= +cloud.google.com/go/spanner v1.73.0 h1:0bab8QDn6MNj9lNK6XyGAVFhMlhMU2waePPa6GZNoi8= +cloud.google.com/go/spanner v1.73.0/go.mod h1:mw98ua5ggQXVWwp83yjwggqEmW9t8rjs9Po1ohcUGW4= cloud.google.com/go/speech v1.6.0/go.mod h1:79tcr4FHCimOp56lwC01xnt/WPJZc4v3gzyT7FoBkCM= cloud.google.com/go/speech v1.7.0/go.mod h1:KptqL+BAQIhMsj1kOP2la5DSEEerPDuOP/2mmkhHhZQ= cloud.google.com/go/speech v1.8.0/go.mod h1:9bYIl1/tjsAnMgKGHKmBZzXKEkGgtU+MpdDPTE9f7y0= @@ -689,8 +689,8 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= -github.com/aws/aws-sdk-go-v2 v1.32.4 h1:S13INUiTxgrPueTmrm5DZ+MiAo99zYzHEFh1UNkOxNE= -github.com/aws/aws-sdk-go-v2 v1.32.4/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= +github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= +github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= @@ -737,8 +737,8 @@ 
github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BV github.com/aws/aws-sdk-go-v2/service/sts v1.33.0 h1:s7LRgBqhwLaxcocnAniBJp7gaAB+4I4vHzqUqjH18yc= github.com/aws/aws-sdk-go-v2/service/sts v1.33.0/go.mod h1:9XEUty5v5UAsMiFOBJrNibZgwCeOma73jgGwwhgffa8= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= -github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= -github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= @@ -1264,8 +1264,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp 
v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1386,8 +1386,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1417,8 +1417,8 @@ golang.org/x/oauth2 v0.4.0/go.mod h1:RznEsdpjGAINPTOF0UH/t+xJ75L18YO3Ho6Pyn+uRec golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= -golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= -golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1435,8 +1435,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1534,8 +1534,8 @@ golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU= +golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text 
v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1552,8 +1552,8 @@ golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/sdks/java/container/Dockerfile-distroless b/sdks/java/container/Dockerfile-distroless new file mode 100644 index 000000000000..328c4dc6a7b3 --- /dev/null +++ b/sdks/java/container/Dockerfile-distroless @@ -0,0 +1,42 @@ +############################################################################### +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################### + +# ARG BEAM_BASE is the Beam SDK container image built using sdks/java/container/Dockerfile. +ARG BEAM_BASE + +# ARG DISTROLESS_BASE is the distroless container image URL. For available distroless Java images, +# see https://github.com/GoogleContainerTools/distroless/tree/main?tab=readme-ov-file#what-images-are-available. +# Only Java versions 17 and 21 are supported. +ARG DISTROLESS_BASE +FROM ${BEAM_BASE} AS base +ARG TARGETARCH +ENV LANG C.UTF-8 + +LABEL Author="Apache Beam " + +RUN if [ -z "${TARGETARCH}" ]; then echo "fatal: TARGETARCH not set; run as docker buildx build or use --build-arg=TARGETARCH=amd64|arm64" >&2; exit 1; fi + +FROM ${DISTROLESS_BASE}:latest-${TARGETARCH} AS distroless + +COPY --from=base /opt /opt + +# Along with the LANG environment variable above, prevents internally discovered failing bugs related to Dataflow Flex +# template character encodings.
+COPY --from=base /usr/lib/locale /usr/lib/locale + +ENTRYPOINT ["/opt/apache/beam/boot"] diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java index 6e5843f533db..78ea34503e54 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java @@ -20,6 +20,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.fasterxml.jackson.annotation.JsonCreator; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -386,4 +387,21 @@ static List getConfiguredLoggerFromOptions(SdkHarnessOptions loggingOpti } return configuredLoggers; } + + @Hidden + @Description( + "Timeout used for cache of bundle processors. Defaults to a minute for batch and an hour for streaming.") + @Default.InstanceFactory(BundleProcessorCacheTimeoutFactory.class) + Duration getBundleProcessorCacheTimeout(); + + void setBundleProcessorCacheTimeout(Duration duration); + + class BundleProcessorCacheTimeoutFactory implements DefaultValueFactory { + @Override + public Duration create(PipelineOptions options) { + return options.as(StreamingOptions.class).isStreaming() + ? 
Duration.ofHours(1) + : Duration.ofMinutes(1); + } + } } diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java index 18d637254115..1285b88663e7 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/options/GcsOptions.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.extensions.gcp.options; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.google.cloud.hadoop.gcsio.GoogleCloudStorageReadOptions; import com.google.cloud.hadoop.util.AsyncWriteChannelOptions; import java.util.concurrent.ExecutorService; import org.apache.beam.sdk.extensions.gcp.storage.GcsPathValidator; @@ -44,6 +45,15 @@ public interface GcsOptions extends ApplicationNameOptions, GcpOptions, Pipeline void setGcsUtil(GcsUtil value); + @JsonIgnore + @Description( + "The GoogleCloudStorageReadOptions instance that should be used to read from Google Cloud Storage.") + @Default.InstanceFactory(GcsUtil.GcsReadOptionsFactory.class) + @Hidden + GoogleCloudStorageReadOptions getGoogleCloudStorageReadOptions(); + + void setGoogleCloudStorageReadOptions(GoogleCloudStorageReadOptions value); + /** * The ExecutorService instance to use to create threads, can be overridden to specify an * ExecutorService that is compatible with the user's environment. 
If unset, the default is to use diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java index 8d3596f17b3b..d58154132a72 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtil.java @@ -123,6 +123,14 @@ public static GcsCountersOptions create( } } + public static class GcsReadOptionsFactory + implements DefaultValueFactory { + @Override + public GoogleCloudStorageReadOptions create(PipelineOptions options) { + return GoogleCloudStorageReadOptions.DEFAULT; + } + } + /** * This is a {@link DefaultValueFactory} able to create a {@link GcsUtil} using any transport * flags specified on the {@link PipelineOptions}. @@ -153,7 +161,8 @@ public GcsUtil create(PipelineOptions options) { : null, gcsOptions.getEnableBucketWriteMetricCounter() ? gcsOptions.getGcsWriteCounterPrefix() - : null)); + : null), + gcsOptions.getGoogleCloudStorageReadOptions()); } /** Returns an instance of {@link GcsUtil} based on the given parameters. 
*/ @@ -164,7 +173,8 @@ public static GcsUtil create( ExecutorService executorService, Credentials credentials, @Nullable Integer uploadBufferSizeBytes, - GcsCountersOptions gcsCountersOptions) { + GcsCountersOptions gcsCountersOptions, + GoogleCloudStorageReadOptions gcsReadOptions) { return new GcsUtil( storageClient, httpRequestInitializer, @@ -173,7 +183,8 @@ public static GcsUtil create( credentials, uploadBufferSizeBytes, null, - gcsCountersOptions); + gcsCountersOptions, + gcsReadOptions); } } @@ -249,7 +260,8 @@ public static boolean isWildcard(GcsPath spec) { Credentials credentials, @Nullable Integer uploadBufferSizeBytes, @Nullable Integer rewriteDataOpBatchLimit, - GcsCountersOptions gcsCountersOptions) { + GcsCountersOptions gcsCountersOptions, + GoogleCloudStorageReadOptions gcsReadOptions) { this.storageClient = storageClient; this.httpRequestInitializer = httpRequestInitializer; this.uploadBufferSizeBytes = uploadBufferSizeBytes; @@ -260,6 +272,7 @@ public static boolean isWildcard(GcsPath spec) { googleCloudStorageOptions = GoogleCloudStorageOptions.builder() .setAppName("Beam") + .setReadChannelOptions(gcsReadOptions) .setGrpcEnabled(shouldUseGrpc) .build(); googleCloudStorage = @@ -565,7 +578,9 @@ private SeekableByteChannel wrapInCounting( public SeekableByteChannel open(GcsPath path) throws IOException { String bucket = path.getBucket(); SeekableByteChannel channel = - googleCloudStorage.open(new StorageResourceId(path.getBucket(), path.getObject())); + googleCloudStorage.open( + new StorageResourceId(path.getBucket(), path.getObject()), + this.googleCloudStorageOptions.getReadChannelOptions()); return wrapInCounting(channel, bucket); } diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java index f5075a3f2c55..26d98125a3af 100644 --- 
a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/GcpCoreApiSurfaceTest.java @@ -55,6 +55,8 @@ public void testGcpCoreApiSurface() throws Exception { classesInPackage("com.google.api.services.storage"), classesInPackage("com.google.auth"), classesInPackage("com.fasterxml.jackson.annotation"), + classesInPackage("com.google.cloud.hadoop.gcsio"), + classesInPackage("com.google.common.collect"), // Via gcs-connector ReadOptions builder classesInPackage("java"), classesInPackage("javax"), classesInPackage("org.apache.beam.sdk"), diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java index bd7f46ec8951..97082572ce41 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilTest.java @@ -177,6 +177,32 @@ public void testCreationWithGcsUtilProvided() { assertSame(gcsUtil, pipelineOptions.getGcsUtil()); } + @Test + public void testCreationWithExplicitGoogleCloudStorageReadOptions() throws Exception { + GoogleCloudStorageReadOptions readOptions = + GoogleCloudStorageReadOptions.builder() + .setFadvise(GoogleCloudStorageReadOptions.Fadvise.AUTO) + .setSupportGzipEncoding(true) + .setFastFailOnNotFound(false) + .build(); + + GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); + pipelineOptions.setGoogleCloudStorageReadOptions(readOptions); + + GcsUtil gcsUtil = pipelineOptions.getGcsUtil(); + GoogleCloudStorage googleCloudStorageMock = Mockito.spy(GoogleCloudStorage.class); + 
Mockito.when(googleCloudStorageMock.open(Mockito.any(), Mockito.any())) + .thenReturn(Mockito.mock(SeekableByteChannel.class)); + gcsUtil.setCloudStorageImpl(googleCloudStorageMock); + + assertEquals(readOptions, pipelineOptions.getGoogleCloudStorageReadOptions()); + + // Assert read options are passed to GCS calls + pipelineOptions.getGcsUtil().open(GcsPath.fromUri("gs://bucket/path")); + Mockito.verify(googleCloudStorageMock, Mockito.times(1)) + .open(StorageResourceId.fromStringPath("gs://bucket/path"), readOptions); + } + @Test public void testMultipleThreadsCanCompleteOutOfOrderWithDefaultThreadPool() throws Exception { GcsOptions pipelineOptions = PipelineOptionsFactory.as(GcsOptions.class); @@ -1630,7 +1656,8 @@ public static GcsUtilMock createMock(PipelineOptions options) { : null, gcsOptions.getEnableBucketWriteMetricCounter() ? gcsOptions.getGcsWriteCounterPrefix() - : null)); + : null), + gcsOptions.getGoogleCloudStorageReadOptions()); } private GcsUtilMock( @@ -1641,7 +1668,8 @@ private GcsUtilMock( Credentials credentials, @Nullable Integer uploadBufferSizeBytes, @Nullable Integer rewriteDataOpBatchLimit, - GcsCountersOptions gcsCountersOptions) { + GcsCountersOptions gcsCountersOptions, + GoogleCloudStorageReadOptions gcsReadOptions) { super( storageClient, httpRequestInitializer, @@ -1650,7 +1678,8 @@ private GcsUtilMock( credentials, uploadBufferSizeBytes, rewriteDataOpBatchLimit, - gcsCountersOptions); + gcsCountersOptions, + gcsReadOptions); } @Override diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 0d517503b12d..300796ac6f12 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -83,6 +83,7 @@ import 
org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.SdkHarnessOptions; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.common.ReflectHelpers; @@ -188,7 +189,8 @@ public ProcessBundleHandler( executionStateSampler, REGISTERED_RUNNER_FACTORIES, processWideCache, - new BundleProcessorCache(), + new BundleProcessorCache( + options.as(SdkHarnessOptions.class).getBundleProcessorCacheTimeout()), dataSampler); } @@ -927,25 +929,25 @@ public int hashCode() { return super.hashCode(); } - BundleProcessorCache() { - this.cachedBundleProcessors = + BundleProcessorCache(Duration timeout) { + CacheBuilder> builder = CacheBuilder.newBuilder() - .expireAfterAccess(Duration.ofMinutes(1L)) .removalListener( - removalNotification -> { - ((ConcurrentLinkedQueue) removalNotification.getValue()) - .forEach( - bundleProcessor -> { - bundleProcessor.shutdown(); - }); - }) - .build( - new CacheLoader>() { - @Override - public ConcurrentLinkedQueue load(String s) throws Exception { - return new ConcurrentLinkedQueue<>(); - } - }); + removalNotification -> + removalNotification + .getValue() + .forEach(bundleProcessor -> bundleProcessor.shutdown())); + if (timeout.compareTo(Duration.ZERO) > 0) { + builder = builder.expireAfterAccess(timeout); + } + this.cachedBundleProcessors = + builder.build( + new CacheLoader>() { + @Override + public ConcurrentLinkedQueue load(String s) throws Exception { + return new ConcurrentLinkedQueue<>(); + } + }); // We specifically use a weak hash map so that references will automatically go out of scope // and not need to be freed explicitly from the cache. 
this.activeBundleProcessors = Collections.synchronizedMap(new WeakHashMap<>()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 95b404aa6203..a69ea5338dc3 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -48,6 +48,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -354,6 +355,10 @@ void reset() throws Exception { private static class TestBundleProcessorCache extends BundleProcessorCache { + TestBundleProcessorCache() { + super(Duration.ZERO); + } + @Override BundleProcessor get( InstructionRequest processBundleRequest, @@ -376,7 +381,7 @@ public void testTrySplitBeforeBundleDoesNotFail() { executionStateSampler, ImmutableMap.of(), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); BeamFnApi.InstructionResponse response = @@ -407,7 +412,7 @@ public void testProgressBeforeBundleDoesNotFail() throws Exception { executionStateSampler, ImmutableMap.of(), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.progress( @@ -487,7 +492,7 @@ public void testOrderOfStartAndFinishCalls() throws Exception { DATA_INPUT_URN, startFinishRecorder, DATA_OUTPUT_URN, startFinishRecorder), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.processBundle( @@ -592,7 +597,7 @@ public void testOrderOfSetupTeardownCalls() throws Exception { executionStateSampler, urnToPTransformRunnerFactoryMap, Caches.noop(), - new 
BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.processBundle( @@ -699,7 +704,7 @@ private static InstructionRequest processBundleRequestFor( public void testBundleProcessorIsFoundWhenActive() { BundleProcessor bundleProcessor = mock(BundleProcessor.class); when(bundleProcessor.getInstructionId()).thenReturn("known"); - BundleProcessorCache cache = new BundleProcessorCache(); + BundleProcessorCache cache = new BundleProcessorCache(Duration.ZERO); // Check that an unknown bundle processor is not found assertNull(cache.find("unknown")); @@ -811,7 +816,7 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { throw new IllegalStateException("TestException"); }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -862,7 +867,7 @@ public void testBundleFinalizationIsPropagated() throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); BeamFnApi.InstructionResponse.Builder response = handler.processBundle( @@ -916,7 +921,7 @@ public void testPTransformStartExceptionsArePropagated() { return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -1094,7 +1099,7 @@ public void onCompleted() {} executionStateSampler, urnToPTransformRunnerFactoryMap, Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); } @@ -1427,7 +1432,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -1500,7 +1505,7 @@ public void 
testDataProcessingExceptionsArePropagated() throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -1551,7 +1556,7 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "TestException", @@ -1647,7 +1652,7 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { } }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -1698,7 +1703,7 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { } }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "State API calls are unsupported", @@ -1787,7 +1792,7 @@ public void reset() { return null; }; - BundleProcessorCache bundleProcessorCache = new BundleProcessorCache(); + BundleProcessorCache bundleProcessorCache = new BundleProcessorCache(Duration.ZERO); ProcessBundleHandler handler = new ProcessBundleHandler( PipelineOptionsFactory.create(), @@ -1930,7 +1935,7 @@ public Object createRunnerForPTransform(Context context) throws IOException { } }), Caches.noop(), - new BundleProcessorCache(), + new BundleProcessorCache(Duration.ZERO), null /* dataSampler */); assertThrows( "Timers are unsupported", diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index 84bf90bd4121..9a7f3a05556c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -3747,7 +3747,7 @@ private WriteResult continueExpandTyped( if (rowWriterFactory.getOutputType() == OutputType.JsonTableRow) { LOG.warn( "Found JSON type in TableSchema for 'FILE_LOADS' write method. \n" - + "Make sure the TableRow value is a parsed JSON to ensure the read as a " + + "Make sure the TableRow value is a Jackson JsonNode to ensure the read as a " + "JSON type. Otherwise it will read as a raw (escaped) string.\n" + "See https://cloud.google.com/bigquery/docs/loading-data-cloud-storage-json#limitations " + "for limitations."); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java index 2b9f24f09541..933394982e30 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java @@ -135,7 +135,6 @@ String tryGetTableName() { return getTable(); } else if (getQuery() != null) { String query = getQuery().getSql(); - System.err.println(query); Matcher matcher = queryPattern.matcher(query); if (matcher.find()) { return matcher.group("table"); diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java index 5df8604699a3..c79b0a550051 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertTrue; +import com.google.api.services.storage.model.Objects; import java.io.IOException; import 
java.io.Serializable; import java.util.ArrayList; @@ -36,6 +37,9 @@ import java.util.stream.LongStream; import java.util.stream.Stream; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.extensions.gcp.options.GcsOptions; +import org.apache.beam.sdk.extensions.gcp.util.GcsUtil; +import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; @@ -80,6 +84,7 @@ import org.joda.time.DateTimeZone; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; @@ -87,10 +92,14 @@ import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Integration tests for {@link IcebergIO} source and sink. */ @RunWith(JUnit4.class) public class IcebergIOIT implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(IcebergIOIT.class); + private static final org.apache.beam.sdk.schemas.Schema DOUBLY_NESTED_ROW_SCHEMA = org.apache.beam.sdk.schemas.Schema.builder() .addStringField("doubly_nested_str") @@ -176,27 +185,45 @@ public Record apply(Row input) { @Rule public TestName testName = new TestName(); - private String warehouseLocation; + private static String warehouseLocation; private String tableId; - private Catalog catalog; + private static Catalog catalog; @BeforeClass public static void beforeClass() { options = TestPipeline.testingPipelineOptions().as(GcpOptions.class); - + warehouseLocation = + String.format("%s/IcebergIOIT/%s", options.getTempLocation(), UUID.randomUUID()); catalogHadoopConf = new Configuration(); catalogHadoopConf.set("fs.gs.project.id", options.getProject()); catalogHadoopConf.set("fs.gs.auth.type", "APPLICATION_DEFAULT"); + catalog = new 
HadoopCatalog(catalogHadoopConf, warehouseLocation); } @Before public void setUp() { - warehouseLocation = - String.format("%s/IcebergIOIT/%s", options.getTempLocation(), UUID.randomUUID()); - tableId = testName.getMethodName() + ".test_table"; - catalog = new HadoopCatalog(catalogHadoopConf, warehouseLocation); + } + + @AfterClass + public static void afterClass() { + try { + GcsUtil gcsUtil = options.as(GcsOptions.class).getGcsUtil(); + GcsPath path = GcsPath.fromUri(warehouseLocation); + + Objects objects = + gcsUtil.listObjects( + path.getBucket(), "IcebergIOIT/" + path.getFileName().toString(), null); + List filesToDelete = + objects.getItems().stream() + .map(obj -> "gs://" + path.getBucket() + "/" + obj.getName()) + .collect(Collectors.toList()); + + gcsUtil.remove(filesToDelete); + } catch (Exception e) { + LOG.warn("Failed to clean up files.", e); + } } /** Populates the Iceberg table and Returns a {@link List} of expected elements. */ diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java index 4884bb61e628..0a9ee4618b1e 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthSempClient.java @@ -17,14 +17,10 @@ */ package org.apache.beam.sdk.io.solace.broker; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.api.client.http.HttpRequestFactory; import com.solacesystems.jcsmp.JCSMPFactory; import java.io.IOException; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.io.solace.data.Semp.Queue; import org.apache.beam.sdk.util.SerializableSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ 
-40,8 +36,6 @@ @Internal public class BasicAuthSempClient implements SempClient { private static final Logger LOG = LoggerFactory.getLogger(BasicAuthSempClient.class); - private final ObjectMapper objectMapper = - new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); private final SempBasicAuthClientExecutor sempBasicAuthClientExecutor; @@ -58,13 +52,12 @@ public BasicAuthSempClient( @Override public boolean isQueueNonExclusive(String queueName) throws IOException { - LOG.info("SolaceIO.Read: SempOperations: query SEMP if queue {} is nonExclusive", queueName); - BrokerResponse response = sempBasicAuthClientExecutor.getQueueResponse(queueName); - if (response.content == null) { - throw new IOException("SolaceIO: response from SEMP is empty!"); - } - Queue q = mapJsonToClass(response.content, Queue.class); - return q.data().accessType().equals("non-exclusive"); + boolean queueNonExclusive = sempBasicAuthClientExecutor.isQueueNonExclusive(queueName); + LOG.info( + "SolaceIO.Read: SempOperations: queried SEMP if queue {} is non-exclusive: {}", + queueName, + queueNonExclusive); + return queueNonExclusive; } @Override @@ -77,12 +70,7 @@ public com.solacesystems.jcsmp.Queue createQueueForTopic(String queueName, Strin @Override public long getBacklogBytes(String queueName) throws IOException { - BrokerResponse response = sempBasicAuthClientExecutor.getQueueResponse(queueName); - if (response.content == null) { - throw new IOException("SolaceIO: response from SEMP is empty!"); - } - Queue q = mapJsonToClass(response.content, Queue.class); - return q.data().msgSpoolUsage(); + return sempBasicAuthClientExecutor.getBacklogBytes(queueName); } private void createQueue(String queueName) throws IOException { @@ -94,9 +82,4 @@ private void createSubscription(String queueName, String topicName) throws IOExc LOG.info("SolaceIO.Read: Creating new subscription {} for topic {}.", queueName, topicName); 
sempBasicAuthClientExecutor.createSubscriptionResponse(queueName, topicName); } - - private T mapJsonToClass(String content, Class mapSuccessToClass) - throws JsonProcessingException { - return objectMapper.readValue(content, mapSuccessToClass); - } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java index 99a81f716435..965fc8741374 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SempBasicAuthClientExecutor.java @@ -19,6 +19,9 @@ import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.api.client.http.GenericUrl; import com.google.api.client.http.HttpContent; import com.google.api.client.http.HttpHeaders; @@ -40,6 +43,7 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import org.apache.beam.sdk.io.solace.data.Semp.Queue; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.checkerframework.checker.nullness.qual.Nullable; @@ -52,7 +56,7 @@ * response is 401 Unauthorized, the client will execute an additional request with Basic Auth * header to refresh the token. */ -class SempBasicAuthClientExecutor implements Serializable { +public class SempBasicAuthClientExecutor implements Serializable { // Every request will be repeated 2 times in case of abnormal connection failures. 
private static final int REQUEST_NUM_RETRIES = 2; private static final Map COOKIE_MANAGER_MAP = @@ -65,8 +69,10 @@ class SempBasicAuthClientExecutor implements Serializable { private final String password; private final CookieManagerKey cookieManagerKey; private final transient HttpRequestFactory requestFactory; + private final ObjectMapper objectMapper = + new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - SempBasicAuthClientExecutor( + public SempBasicAuthClientExecutor( String host, String username, String password, @@ -78,7 +84,16 @@ class SempBasicAuthClientExecutor implements Serializable { this.password = password; this.requestFactory = httpRequestFactory; this.cookieManagerKey = new CookieManagerKey(this.baseUrl, this.username); - COOKIE_MANAGER_MAP.putIfAbsent(this.cookieManagerKey, new CookieManager()); + COOKIE_MANAGER_MAP.computeIfAbsent(this.cookieManagerKey, key -> new CookieManager()); + } + + public boolean isQueueNonExclusive(String queueName) throws IOException { + BrokerResponse response = getQueueResponse(queueName); + if (response.content == null) { + throw new IOException("SolaceIO: response from SEMP is empty!"); + } + Queue q = mapJsonToClass(response.content, Queue.class); + return q.data().accessType().equals("non-exclusive"); } private static String getQueueEndpoint(String messageVpn, String queueName) @@ -199,6 +214,20 @@ private static String urlEncode(String queueName) throws UnsupportedEncodingExce return URLEncoder.encode(queueName, StandardCharsets.UTF_8.name()); } + private T mapJsonToClass(String content, Class mapSuccessToClass) + throws JsonProcessingException { + return objectMapper.readValue(content, mapSuccessToClass); + } + + public long getBacklogBytes(String queueName) throws IOException { + BrokerResponse response = getQueueResponse(queueName); + if (response.content == null) { + throw new IOException("SolaceIO: response from SEMP is empty!"); + } + Queue q = 
mapJsonToClass(response.content, Queue.class); + return q.data().msgSpoolUsage(); + } + private static class CookieManagerKey implements Serializable { private final String baseUrl; private final String username; diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java index 12d8a8507d8a..c55d37942c72 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java @@ -23,8 +23,9 @@ import org.apache.beam.sdk.values.KV; /** - * This class a pseudo-key with a given cardinality. The downstream steps will use state {@literal - * &} timers to distribute the data and control for the number of parallel workers used for writing. + * This class adds pseudo-key with a given cardinality. The downstream steps will use state + * {@literal &} timers to distribute the data and control for the number of parallel workers used + * for writing. */ @Internal public class AddShardKeyDoFn extends DoFn> { diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClient.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClient.java new file mode 100644 index 000000000000..637cecdcfd15 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClient.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.it; + +import com.google.api.client.http.HttpRequestFactory; +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.sdk.io.solace.broker.BasicAuthSempClient; +import org.apache.beam.sdk.io.solace.broker.SempBasicAuthClientExecutor; +import org.apache.beam.sdk.util.SerializableSupplier; + +/** + * Example class showing how the {@link BasicAuthSempClient} can be extended or have functionalities + * overridden. In this case, the modified method is {@link + * BasicAuthSempClient#getBacklogBytes(String)}, which queries multiple SEMP endpoints to collect + * accurate backlog metrics. For usage, see {@link SolaceIOMultipleSempIT}. 
+ */ +public class BasicAuthMultipleSempClient extends BasicAuthSempClient { + private final List sempBacklogBasicAuthClientExecutors; + + public BasicAuthMultipleSempClient( + String mainHost, + List backlogHosts, + String username, + String password, + String vpnName, + SerializableSupplier httpRequestFactorySupplier) { + super(mainHost, username, password, vpnName, httpRequestFactorySupplier); + sempBacklogBasicAuthClientExecutors = + backlogHosts.stream() + .map( + host -> + new SempBasicAuthClientExecutor( + host, username, password, vpnName, httpRequestFactorySupplier.get())) + .collect(Collectors.toList()); + } + + @Override + public long getBacklogBytes(String queueName) throws IOException { + long backlog = 0; + for (SempBasicAuthClientExecutor client : sempBacklogBasicAuthClientExecutors) { + backlog += client.getBacklogBytes(queueName); + } + return backlog; + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClientFactory.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClientFactory.java new file mode 100644 index 000000000000..0a548c10555c --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/BasicAuthMultipleSempClientFactory.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.it; + +import com.google.api.client.http.HttpRequestFactory; +import com.google.api.client.http.javanet.NetHttpTransport; +import com.google.auto.value.AutoValue; +import java.util.List; +import org.apache.beam.sdk.io.solace.broker.SempClient; +import org.apache.beam.sdk.io.solace.broker.SempClientFactory; +import org.apache.beam.sdk.util.SerializableSupplier; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Example class showing how to implement a custom {@link SempClientFactory} with custom client. For + * usage, see {@link SolaceIOMultipleSempIT}. + */ +@AutoValue +public abstract class BasicAuthMultipleSempClientFactory implements SempClientFactory { + + public abstract String mainHost(); + + public abstract List backlogHosts(); + + public abstract String username(); + + public abstract String password(); + + public abstract String vpnName(); + + public abstract @Nullable SerializableSupplier httpRequestFactorySupplier(); + + public static Builder builder() { + return new AutoValue_BasicAuthMultipleSempClientFactory.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + /** Set Solace host, format: [Protocol://]Host[:Port]. */ + public abstract Builder mainHost(String host); + + public abstract Builder backlogHosts(List hosts); + + /** Set Solace username. */ + public abstract Builder username(String username); + /** Set Solace password. 
*/ + public abstract Builder password(String password); + + /** Set Solace vpn name. */ + public abstract Builder vpnName(String vpnName); + + abstract Builder httpRequestFactorySupplier( + SerializableSupplier httpRequestFactorySupplier); + + public abstract BasicAuthMultipleSempClientFactory build(); + } + + @Override + public SempClient create() { + return new BasicAuthMultipleSempClient( + mainHost(), + backlogHosts(), + username(), + password(), + vpnName(), + getHttpRequestFactorySupplier()); + } + + @SuppressWarnings("return") + private @NonNull SerializableSupplier getHttpRequestFactorySupplier() { + SerializableSupplier httpRequestSupplier = httpRequestFactorySupplier(); + return httpRequestSupplier != null + ? httpRequestSupplier + : () -> new NetHttpTransport().createRequestFactory(); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOMultipleSempIT.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOMultipleSempIT.java new file mode 100644 index 000000000000..77d00b4e41ec --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOMultipleSempIT.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.it; + +import static org.apache.beam.sdk.io.solace.it.SolaceContainerManager.TOPIC_NAME; +import static org.apache.beam.sdk.values.TypeDescriptors.strings; +import static org.junit.Assert.assertEquals; + +import com.solacesystems.jcsmp.DeliveryMode; +import java.io.IOException; +import java.util.Arrays; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.SolaceIO.WriterType; +import org.apache.beam.sdk.io.solace.broker.BasicAuthJcsmpSessionServiceFactory; +import org.apache.beam.sdk.io.solace.broker.SempClientFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Queue; +import org.apache.beam.sdk.io.solace.data.SolaceDataUtils; +import org.apache.beam.sdk.io.solace.write.SolaceOutput; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.StreamingOptions; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.testutils.metrics.MetricsReader; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; + 
+public class SolaceIOMultipleSempIT { + private static final String NAMESPACE = SolaceIOMultipleSempIT.class.getName(); + private static final String READ_COUNT = "read_count"; + private static final String QUEUE_NAME = "test_queue"; + private static final long PUBLISH_MESSAGE_COUNT = 20; + private static final TestPipelineOptions pipelineOptions; + private static SolaceContainerManager solaceContainerManager; + + static { + pipelineOptions = PipelineOptionsFactory.create().as(TestPipelineOptions.class); + pipelineOptions.as(StreamingOptions.class).setStreaming(true); + // For the read connector tests, we need to make sure that p.run() does not block + pipelineOptions.setBlockOnRun(false); + pipelineOptions.as(TestPipelineOptions.class).setBlockOnRun(false); + } + + @Rule public final TestPipeline pipeline = TestPipeline.fromOptions(pipelineOptions); + + @BeforeClass + public static void setup() throws IOException { + solaceContainerManager = new SolaceContainerManager(); + solaceContainerManager.start(); + solaceContainerManager.createQueueWithSubscriptionTopic(QUEUE_NAME); + } + + @AfterClass + public static void afterClass() { + if (solaceContainerManager != null) { + solaceContainerManager.stop(); + } + } + + /** + * This test verifies the functionality of reading data from a Solace queue using the + * SolaceIO.read() transform. This test does not actually test functionalities of {@link + * BasicAuthMultipleSempClientFactory}, but it demonstrates how to integrate a custom + * implementation of {@link SempClientFactory}, in this case, {@link + * BasicAuthMultipleSempClientFactory}, to handle authentication and configuration interactions + * with the Solace message broker. 
+ */ + @Test + public void test01writeAndReadWithMultipleSempClientFactory() { + Pipeline writerPipeline = + createWriterPipeline(WriterType.BATCHED, solaceContainerManager.jcsmpPortMapped); + writerPipeline + .apply( + "Read from Solace", + SolaceIO.read() + .from(Queue.fromName(QUEUE_NAME)) + .withMaxNumConnections(1) + .withDeduplicateRecords(true) + .withSempClientFactory( + BasicAuthMultipleSempClientFactory.builder() + .backlogHosts( + Arrays.asList( + "http://localhost:" + solaceContainerManager.sempPortMapped, + "http://localhost:" + solaceContainerManager.sempPortMapped)) + .mainHost("http://localhost:" + solaceContainerManager.sempPortMapped) + .username("admin") + .password("admin") + .vpnName(SolaceContainerManager.VPN_NAME) + .build()) + .withSessionServiceFactory( + BasicAuthJcsmpSessionServiceFactory.builder() + .host("localhost:" + solaceContainerManager.jcsmpPortMapped) + .username(SolaceContainerManager.USERNAME) + .password(SolaceContainerManager.PASSWORD) + .vpnName(SolaceContainerManager.VPN_NAME) + .build())) + .apply("Count", ParDo.of(new CountingFn<>(NAMESPACE, READ_COUNT))); + + PipelineResult pipelineResult = writerPipeline.run(); + // We need enough time for Beam to pull all messages from the queue, but we need a timeout too, + // as the Read connector will keep attempting to read forever. 
+ pipelineResult.waitUntilFinish(Duration.standardSeconds(15)); + + MetricsReader metricsReader = new MetricsReader(pipelineResult, NAMESPACE); + long actualRecordsCount = metricsReader.getCounterMetric(READ_COUNT); + assertEquals(PUBLISH_MESSAGE_COUNT, actualRecordsCount); + } + + private Pipeline createWriterPipeline( + SolaceIO.WriterType writerType, int solaceContainerJcsmpPort) { + TestStream.Builder> kvBuilder = + TestStream.create(KvCoder.of(AvroCoder.of(String.class), AvroCoder.of(String.class))) + .advanceWatermarkTo(Instant.EPOCH); + + for (int i = 0; i < PUBLISH_MESSAGE_COUNT; i++) { + String key = "Solace-Message-ID:m" + solaceContainerJcsmpPort + i; + String payload = String.format("{\"field_str\":\"value\",\"field_int\":123%d}", i); + kvBuilder = + kvBuilder + .addElements(KV.of(key, payload)) + .advanceProcessingTime(Duration.standardSeconds(60)); + } + + TestStream> testStream = kvBuilder.advanceWatermarkToInfinity(); + + PCollection> kvs = + pipeline.apply(String.format("Test stream %s", writerType), testStream); + + PCollection records = + kvs.apply( + String.format("To Record %s", writerType), + MapElements.into(TypeDescriptor.of(Solace.Record.class)) + .via(kv -> SolaceDataUtils.getSolaceRecord(kv.getValue(), kv.getKey()))); + + SolaceOutput result = + records.apply( + String.format("Write to Solace %s", writerType), + SolaceIO.write() + .to(Solace.Topic.fromName(TOPIC_NAME)) + .withSubmissionMode(SolaceIO.SubmissionMode.TESTING) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withNumberOfClientsPerWorker(1) + .withNumShards(1) + .withSessionServiceFactory( + BasicAuthJcsmpSessionServiceFactory.builder() + .host("localhost:" + solaceContainerJcsmpPort) + .username(SolaceContainerManager.USERNAME) + .password(SolaceContainerManager.PASSWORD) + .vpnName(SolaceContainerManager.VPN_NAME) + .build())); + result + .getSuccessfulPublish() + .apply( + String.format("Get ids %s", writerType), + 
MapElements.into(strings()).via(Solace.PublishResult::getMessageId)); + + return pipeline; + } + + private static class CountingFn extends DoFn { + + private final Counter elementCounter; + + CountingFn(String namespace, String name) { + elementCounter = Metrics.counter(namespace, name); + } + + @ProcessElement + public void processElement(@Element T record, OutputReceiver c) { + elementCounter.inc(1L); + c.output(record); + } + } +} diff --git a/sdks/java/testing/nexmark/build.gradle b/sdks/java/testing/nexmark/build.gradle index a09ed9238991..0a09c357ed57 100644 --- a/sdks/java/testing/nexmark/build.gradle +++ b/sdks/java/testing/nexmark/build.gradle @@ -119,11 +119,7 @@ def getNexmarkArgs = { } } else { def dataflowWorkerJar = project.findProperty('dataflowWorkerJar') ?: project(":runners:google-cloud-dataflow-java:worker").shadowJar.archivePath - // Provide job with a customizable worker jar. - // With legacy worker jar, containerImage is set to empty (i.e. to use the internal build). - // More context and discussions can be found in PR#6694. 
nexmarkArgsList.add("--dataflowWorkerJar=${dataflowWorkerJar}".toString()) - nexmarkArgsList.add('--workerHarnessContainerImage=') def nexmarkProfile = project.findProperty(nexmarkProfilingProperty) ?: "" if (nexmarkProfile.equals("true")) { diff --git a/sdks/python/apache_beam/internal/metrics/cells.py b/sdks/python/apache_beam/internal/metrics/cells.py index c7b546258a70..989dc7183045 100644 --- a/sdks/python/apache_beam/internal/metrics/cells.py +++ b/sdks/python/apache_beam/internal/metrics/cells.py @@ -28,7 +28,6 @@ from typing import TYPE_CHECKING from typing import Optional -from apache_beam.metrics.cells import MetricAggregator from apache_beam.metrics.cells import MetricCell from apache_beam.metrics.cells import MetricCellFactory from apache_beam.utils.histogram import Histogram @@ -50,10 +49,10 @@ class HistogramCell(MetricCell): """ def __init__(self, bucket_type): self._bucket_type = bucket_type - self.data = HistogramAggregator(bucket_type).identity_element() + self.data = HistogramData.identity_element(bucket_type) def reset(self): - self.data = HistogramAggregator(self._bucket_type).identity_element() + self.data = HistogramData.identity_element(self._bucket_type) def combine(self, other: 'HistogramCell') -> 'HistogramCell': result = HistogramCell(self._bucket_type) @@ -148,22 +147,6 @@ def combine(self, other: Optional['HistogramData']) -> 'HistogramData': return HistogramData(self.histogram.combine(other.histogram)) - -class HistogramAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. - - Aggregator for Histogram metric data during pipeline execution. - - Values aggregated should be ``HistogramData`` objects. 
- """ - def __init__(self, bucket_type: 'BucketType') -> None: - self._bucket_type = bucket_type - - def identity_element(self) -> HistogramData: - return HistogramData(Histogram(self._bucket_type)) - - def combine(self, x: HistogramData, y: HistogramData) -> HistogramData: - return x.combine(y) - - def result(self, x: HistogramData) -> HistogramResult: - return HistogramResult(x.get_cumulative()) + @staticmethod + def identity_element(bucket_type) -> 'HistogramData': + return HistogramData(Histogram(bucket_type)) diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd index 98bb5eff0977..c583dabeb0c0 100644 --- a/sdks/python/apache_beam/metrics/cells.pxd +++ b/sdks/python/apache_beam/metrics/cells.pxd @@ -33,6 +33,7 @@ cdef class CounterCell(MetricCell): cpdef bint update(self, value) except -1 +# Not using AbstractMetricCell so that data can be typed. cdef class DistributionCell(MetricCell): cdef readonly DistributionData data @@ -40,14 +41,18 @@ cdef class DistributionCell(MetricCell): cdef inline bint _update(self, value) except -1 -cdef class GaugeCell(MetricCell): - cdef readonly object data +cdef class AbstractMetricCell(MetricCell): + cdef readonly object data_class + cdef public object data + cdef bint _update_locked(self, value) except -1 -cdef class StringSetCell(MetricCell): - cdef readonly object data +cdef class GaugeCell(AbstractMetricCell): + pass - cdef inline bint _update(self, value) except -1 + +cdef class StringSetCell(AbstractMetricCell): + pass cdef class DistributionData(object): diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index 63fc9f3f7cc9..10ac7b3a1e69 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -27,11 +27,9 @@ import threading import time from datetime import datetime -from typing import Any from typing import Iterable from typing import Optional from typing import Set -from typing 
import SupportsInt try: import cython @@ -43,11 +41,7 @@ class fake_cython: globals()['cython'] = fake_cython __all__ = [ - 'MetricAggregator', - 'MetricCell', - 'MetricCellFactory', - 'DistributionResult', - 'GaugeResult' + 'MetricCell', 'MetricCellFactory', 'DistributionResult', 'GaugeResult' ] _LOGGER = logging.getLogger(__name__) @@ -110,11 +104,11 @@ class CounterCell(MetricCell): """ def __init__(self, *args): super().__init__(*args) - self.value = CounterAggregator.identity_element() + self.value = 0 def reset(self): # type: () -> None - self.value = CounterAggregator.identity_element() + self.value = 0 def combine(self, other): # type: (CounterCell) -> CounterCell @@ -175,11 +169,11 @@ class DistributionCell(MetricCell): """ def __init__(self, *args): super().__init__(*args) - self.data = DistributionAggregator.identity_element() + self.data = DistributionData.identity_element() def reset(self): # type: () -> None - self.data = DistributionAggregator.identity_element() + self.data = DistributionData.identity_element() def combine(self, other): # type: (DistributionCell) -> DistributionCell @@ -221,47 +215,65 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id): ptransform=transform_id) -class GaugeCell(MetricCell): +class AbstractMetricCell(MetricCell): """For internal use only; no backwards-compatibility guarantees. - Tracks the current value and delta for a gauge metric. - - Each cell tracks the state of a metric independently per context per bundle. - Therefore, each metric has a different cell in each bundle, that is later - aggregated. + Tracks the current value and delta for a metric with a data class. This class is thread safe. 
""" - def __init__(self, *args): - super().__init__(*args) - self.data = GaugeAggregator.identity_element() + def __init__(self, data_class): + super().__init__() + self.data_class = data_class + self.data = self.data_class.identity_element() def reset(self): - self.data = GaugeAggregator.identity_element() + self.data = self.data_class.identity_element() - def combine(self, other): - # type: (GaugeCell) -> GaugeCell - result = GaugeCell() + def combine(self, other: 'AbstractMetricCell') -> 'AbstractMetricCell': + result = type(self)() # type: ignore[call-arg] result.data = self.data.combine(other.data) return result def set(self, value): - self.update(value) + with self._lock: + self._update_locked(value) def update(self, value): - # type: (SupportsInt) -> None - value = int(value) with self._lock: - # Set the value directly without checking timestamp, because - # this value is naturally the latest value. - self.data.value = value - self.data.timestamp = time.time() + self._update_locked(value) + + def _update_locked(self, value): + raise NotImplementedError(type(self)) def get_cumulative(self): - # type: () -> GaugeData with self._lock: return self.data.get_cumulative() + def to_runner_api_monitoring_info_impl(self, name, transform_id): + raise NotImplementedError(type(self)) + + +class GaugeCell(AbstractMetricCell): + """For internal use only; no backwards-compatibility guarantees. + + Tracks the current value and delta for a gauge metric. + + Each cell tracks the state of a metric independently per context per bundle. + Therefore, each metric has a different cell in each bundle, that is later + aggregated. + + This class is thread safe. + """ + def __init__(self): + super().__init__(GaugeData) + + def _update_locked(self, value): + # Set the value directly without checking timestamp, because + # this value is naturally the latest value. 
+ self.data.value = int(value) + self.data.timestamp = time.time() + def to_runner_api_monitoring_info_impl(self, name, transform_id): from apache_beam.metrics import monitoring_infos return monitoring_infos.int64_user_gauge( @@ -271,7 +283,7 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id): ptransform=transform_id) -class StringSetCell(MetricCell): +class StringSetCell(AbstractMetricCell): """For internal use only; no backwards-compatibility guarantees. Tracks the current value for a StringSet metric. @@ -282,50 +294,23 @@ class StringSetCell(MetricCell): This class is thread safe. """ - def __init__(self, *args): - super().__init__(*args) - self.data = StringSetAggregator.identity_element() + def __init__(self): + super().__init__(StringSetData) def add(self, value): self.update(value) - def update(self, value): - # type: (str) -> None - if cython.compiled: - # We will hold the GIL throughout the entire _update. - self._update(value) - else: - with self._lock: - self._update(value) - - def _update(self, value): + def _update_locked(self, value): self.data.add(value) - def get_cumulative(self): - # type: () -> StringSetData - with self._lock: - return self.data.get_cumulative() - - def combine(self, other): - # type: (StringSetCell) -> StringSetCell - combined = StringSetAggregator().combine(self.data, other.data) - result = StringSetCell() - result.data = combined - return result - def to_runner_api_monitoring_info_impl(self, name, transform_id): from apache_beam.metrics import monitoring_infos - return monitoring_infos.user_set_string( name.namespace, name.name, self.get_cumulative(), ptransform=transform_id) - def reset(self): - # type: () -> None - self.data = StringSetAggregator.identity_element() - class DistributionResult(object): """The result of a Distribution metric.""" @@ -449,6 +434,10 @@ def get_cumulative(self): # type: () -> GaugeData return GaugeData(self.value, timestamp=self.timestamp) + def get_result(self): + # type: () -> 
GaugeResult + return GaugeResult(self.get_cumulative()) + def combine(self, other): # type: (Optional[GaugeData]) -> GaugeData if other is None: @@ -464,6 +453,11 @@ def singleton(value, timestamp=None): # type: (Optional[int], Optional[int]) -> GaugeData return GaugeData(value, timestamp=timestamp) + @staticmethod + def identity_element(): + # type: () -> GaugeData + return GaugeData(0, timestamp=0) + class DistributionData(object): """For internal use only; no backwards-compatibility guarantees. @@ -510,6 +504,9 @@ def get_cumulative(self): # type: () -> DistributionData return DistributionData(self.sum, self.count, self.min, self.max) + def get_result(self) -> DistributionResult: + return DistributionResult(self.get_cumulative()) + def combine(self, other): # type: (Optional[DistributionData]) -> DistributionData if other is None: @@ -526,6 +523,11 @@ def singleton(value): # type: (int) -> DistributionData return DistributionData(value, 1, value, value) + @staticmethod + def identity_element(): + # type: () -> DistributionData + return DistributionData(0, 0, 2**63 - 1, -2**63) + class StringSetData(object): """For internal use only; no backwards-compatibility guarantees. @@ -568,6 +570,9 @@ def __repr__(self) -> str: def get_cumulative(self) -> "StringSetData": return StringSetData(set(self.string_set), self.string_size) + def get_result(self) -> set[str]: + return set(self.string_set) + def add(self, *strings): """ Add strings into this StringSetData and return the result StringSetData. 
@@ -585,6 +590,11 @@ def combine(self, other: "StringSetData") -> "StringSetData": if other is None: return self + if not other.string_set: + return self + elif not self.string_set: + return other + combined = set(self.string_set) string_size = self.add_until_capacity( combined, self.string_size, other.string_set) @@ -614,113 +624,9 @@ def add_until_capacity( return current_size @staticmethod - def singleton(value): - # type: (int) -> DistributionData - return DistributionData(value, 1, value, value) - - -class MetricAggregator(object): - """For internal use only; no backwards-compatibility guarantees. - - Base interface for aggregating metric data during pipeline execution.""" - def identity_element(self): - # type: () -> Any - - """Returns the identical element of an Aggregation. - - For the identity element, it must hold that - Aggregator.combine(any_element, identity_element) == any_element. - """ - raise NotImplementedError + def singleton(value: str) -> "StringSetData": + return StringSetData({value}) - def combine(self, x, y): - # type: (Any, Any) -> Any - raise NotImplementedError - - def result(self, x): - # type: (Any) -> Any - raise NotImplementedError - - -class CounterAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. - - Aggregator for Counter metric data during pipeline execution. - - Values aggregated should be ``int`` objects. - """ @staticmethod - def identity_element(): - # type: () -> int - return 0 - - def combine(self, x, y): - # type: (SupportsInt, SupportsInt) -> int - return int(x) + int(y) - - def result(self, x): - # type: (SupportsInt) -> int - return int(x) - - -class DistributionAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. - - Aggregator for Distribution metric data during pipeline execution. - - Values aggregated should be ``DistributionData`` objects. 
- """ - @staticmethod - def identity_element(): - # type: () -> DistributionData - return DistributionData(0, 0, 2**63 - 1, -2**63) - - def combine(self, x, y): - # type: (DistributionData, DistributionData) -> DistributionData - return x.combine(y) - - def result(self, x): - # type: (DistributionData) -> DistributionResult - return DistributionResult(x.get_cumulative()) - - -class GaugeAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. - - Aggregator for Gauge metric data during pipeline execution. - - Values aggregated should be ``GaugeData`` objects. - """ - @staticmethod - def identity_element(): - # type: () -> GaugeData - return GaugeData(0, timestamp=0) - - def combine(self, x, y): - # type: (GaugeData, GaugeData) -> GaugeData - result = x.combine(y) - return result - - def result(self, x): - # type: (GaugeData) -> GaugeResult - return GaugeResult(x.get_cumulative()) - - -class StringSetAggregator(MetricAggregator): - @staticmethod - def identity_element(): - # type: () -> StringSetData + def identity_element() -> "StringSetData": return StringSetData() - - def combine(self, x, y): - # type: (StringSetData, StringSetData) -> StringSetData - if len(x.string_set) == 0: - return y - elif len(y.string_set) == 0: - return x - else: - return x.combine(y) - - def result(self, x): - # type: (StringSetData) -> set - return set(x.string_set) diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py index 52123516de1a..dc35aa016013 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py @@ -21,6 +21,7 @@ import shutil import tempfile import unittest +import uuid from typing import Any from typing import Dict from typing import Iterable @@ -127,7 +128,7 @@ def test_predict_tensor(self): def test_predict_tensor_with_batch_size(self): model = 
_create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as pipeline: @@ -173,7 +174,7 @@ def fake_batching_inference_fn( def test_predict_tensor_with_large_model(self): model = _create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as pipeline: @@ -220,7 +221,7 @@ def fake_batching_inference_fn( def test_predict_numpy_with_batch_size(self): model = _create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2_numpy.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as pipeline: @@ -263,7 +264,7 @@ def fake_batching_inference_fn( def test_predict_numpy_with_large_model(self): model = _create_mult2_model() - model_path = os.path.join(self.tmpdir, 'mult2_numpy.keras') + model_path = os.path.join(self.tmpdir, f'mult2_{uuid.uuid4()}.keras') tf.keras.models.save_model(model, model_path) with TestPipeline() as pipeline: diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 1a21f6caf7e1..3db5a63b9542 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -21,7 +21,6 @@ import shutil import tempfile import time -import typing import unittest from collections.abc import Sequence from typing import Any @@ -140,8 +139,8 @@ def test_ml_transform_on_list_dict(self): 'x': int, 'y': float }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), param( @@ -153,8 +152,8 @@ def test_ml_transform_on_list_dict(self): 'x': np.int32, 'y': 
np.float32 }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), param( @@ -165,8 +164,8 @@ def test_ml_transform_on_list_dict(self): 'x': list[int], 'y': list[float] }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), param( @@ -174,12 +173,12 @@ def test_ml_transform_on_list_dict(self): 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] }], input_types={ - 'x': typing.Sequence[int], - 'y': typing.Sequence[float], + 'x': Sequence[int], + 'y': Sequence[float], }, expected_dtype={ - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], }, ), ]) diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index 4b53026c36a4..bb5f9b5f0f70 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -20,9 +20,9 @@ import shutil import sys import tempfile -import typing import unittest import uuid +from collections.abc import Sequence from typing import NamedTuple from typing import Union @@ -276,9 +276,9 @@ def test_tft_process_handler_transformed_data_schema(self): schema_utils.schema_from_feature_spec(raw_data_feature_spec)) expected_transformed_data_schema = { - 'x': typing.Sequence[np.float32], - 'y': typing.Sequence[np.float32], - 'z': typing.Sequence[bytes] + 'x': Sequence[np.float32], + 'y': Sequence[np.float32], + 'z': Sequence[bytes] } actual_transformed_data_schema = ( diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py index f715ce3bf521..d20849d769af 100644 --- a/sdks/python/apache_beam/runners/direct/direct_metrics.py +++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py @@ 
-24,23 +24,84 @@ import threading from collections import defaultdict +from typing import Any +from typing import SupportsInt -from apache_beam.metrics.cells import CounterAggregator -from apache_beam.metrics.cells import DistributionAggregator -from apache_beam.metrics.cells import GaugeAggregator -from apache_beam.metrics.cells import StringSetAggregator +from apache_beam.metrics.cells import DistributionData +from apache_beam.metrics.cells import GaugeData +from apache_beam.metrics.cells import StringSetData from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.metric import MetricResults +class MetricAggregator(object): + """For internal use only; no backwards-compatibility guarantees. + + Base interface for aggregating metric data during pipeline execution.""" + def identity_element(self): + # type: () -> Any + + """Returns the identical element of an Aggregation. + + For the identity element, it must hold that + Aggregator.combine(any_element, identity_element) == any_element. + """ + raise NotImplementedError + + def combine(self, x, y): + # type: (Any, Any) -> Any + raise NotImplementedError + + def result(self, x): + # type: (Any) -> Any + raise NotImplementedError + + +class CounterAggregator(MetricAggregator): + """For internal use only; no backwards-compatibility guarantees. + + Aggregator for Counter metric data during pipeline execution. + + Values aggregated should be ``int`` objects. 
+ """ + @staticmethod + def identity_element(): + # type: () -> int + return 0 + + def combine(self, x, y): + # type: (SupportsInt, SupportsInt) -> int + return int(x) + int(y) + + def result(self, x): + # type: (SupportsInt) -> int + return int(x) + + +class GenericAggregator(MetricAggregator): + def __init__(self, data_class): + self._data_class = data_class + + def identity_element(self): + return self._data_class.identity_element() + + def combine(self, x, y): + return x.combine(y) + + def result(self, x): + return x.get_result() + + class DirectMetrics(MetricResults): def __init__(self): self._counters = defaultdict(lambda: DirectMetric(CounterAggregator())) self._distributions = defaultdict( - lambda: DirectMetric(DistributionAggregator())) - self._gauges = defaultdict(lambda: DirectMetric(GaugeAggregator())) - self._string_sets = defaultdict(lambda: DirectMetric(StringSetAggregator())) + lambda: DirectMetric(GenericAggregator(DistributionData))) + self._gauges = defaultdict( + lambda: DirectMetric(GenericAggregator(GaugeData))) + self._string_sets = defaultdict( + lambda: DirectMetric(GenericAggregator(StringSetData))) def _apply_operation(self, bundle, updates, op): for k, v in updates.counters.items(): diff --git a/sdks/python/apache_beam/transforms/sql_test.py b/sdks/python/apache_beam/transforms/sql_test.py index 854aec078ce5..a7da253c4617 100644 --- a/sdks/python/apache_beam/transforms/sql_test.py +++ b/sdks/python/apache_beam/transforms/sql_test.py @@ -20,6 +20,7 @@ # pytype: skip-file import logging +import subprocess import typing import unittest @@ -69,6 +70,22 @@ class SqlTransformTest(unittest.TestCase): """ _multiprocess_can_split_ = True + @staticmethod + def _disable_zetasql_test(): + # disable if run on Java8 which is no longer supported by ZetaSQL + try: + result = subprocess.run(['java', '-version'], + check=True, + capture_output=True, + text=True) + version_line = result.stderr.splitlines()[0] + version = 
version_line.split()[2].strip('\"') + if version.startswith("1."): + return True + return False + except: # pylint: disable=bare-except + return False + def test_generate_data(self): with TestPipeline() as p: out = p | SqlTransform( @@ -150,6 +167,9 @@ def test_row(self): assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)])) def test_zetasql_generate_data(self): + if self._disable_zetasql_test(): + raise unittest.SkipTest("ZetaSQL tests need Java11+") + with TestPipeline() as p: out = p | SqlTransform( """SELECT diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 9c0cc2b8af4e..7050df7016e5 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -82,7 +82,6 @@ def foo((a, b)): import inspect import itertools import logging -import sys import traceback import types from typing import Any @@ -686,9 +685,6 @@ def get_type_hints(fn: Any) -> IOTypeHints: # Can't add arbitrary attributes to this object, # but might have some restrictions anyways... hints = IOTypeHints.empty() - # Python 3.7 introduces annotations for _MethodDescriptorTypes. - if isinstance(fn, _MethodDescriptorType) and sys.version_info < (3, 7): - hints = hints.with_input_types(fn.__objclass__) # type: ignore return hints return fn._type_hints # pylint: enable=protected-access diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 71edc75f31a6..dd110ced5bb8 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -20,7 +20,6 @@ # pytype: skip-file import functools -import sys import typing import unittest @@ -70,14 +69,7 @@ def test_from_callable_builtin(self): def test_from_callable_method_descriptor(self): # from_callable() injects an annotation in this special type of builtin. 
th = decorators.IOTypeHints.from_callable(str.strip) - if sys.version_info >= (3, 7): - self.assertEqual(th.input_types, ((str, Any), {})) - else: - self.assertEqual( - th.input_types, - ((str, decorators._ANY_VAR_POSITIONAL), { - '__unknown__keywords': decorators._ANY_VAR_KEYWORD - })) + self.assertEqual(th.input_types, ((str, Any), {})) self.assertEqual(th.output_types, ((Any, ), {})) def test_strip_iterable_not_simple_output_noop(self): diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 621adc44507e..6f704b37a969 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -101,10 +101,7 @@ def _match_issubclass(match_against): def _match_is_exactly_mapping(user_type): # Avoid unintentionally catching all subtypes (e.g. strings and mappings). - if sys.version_info < (3, 7): - expected_origin = typing.Mapping - else: - expected_origin = collections.abc.Mapping + expected_origin = collections.abc.Mapping return getattr(user_type, '__origin__', None) is expected_origin @@ -112,10 +109,7 @@ def _match_is_exactly_iterable(user_type): if user_type is typing.Iterable: return True # Avoid unintentionally catching all subtypes (e.g. strings and mappings). 
- if sys.version_info < (3, 7): - expected_origin = typing.Iterable - else: - expected_origin = collections.abc.Iterable + expected_origin = collections.abc.Iterable return getattr(user_type, '__origin__', None) is expected_origin @@ -244,11 +238,10 @@ def convert_to_beam_type(typ): sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)): typ = typing.Union[typ] - if sys.version_info >= (3, 9) and isinstance(typ, types.GenericAlias): + if isinstance(typ, types.GenericAlias): typ = convert_builtin_to_typing(typ) - if sys.version_info >= (3, 9) and getattr(typ, '__module__', - None) == 'collections.abc': + if getattr(typ, '__module__', None) == 'collections.abc': typ = convert_collections_to_typing(typ) typ_module = getattr(typ, '__module__', None) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index 2e6db6a7733c..ae8e1a0b2906 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -21,7 +21,6 @@ import collections.abc import enum -import sys import typing import unittest @@ -128,105 +127,98 @@ def test_convert_to_beam_type(self): self.assertEqual(converted_typing_type, typing_type, description) def test_convert_to_beam_type_with_builtin_types(self): - if sys.version_info >= (3, 9): - test_cases = [ - ('builtin dict', dict[str, int], typehints.Dict[str, int]), - ('builtin list', list[str], typehints.List[str]), - ('builtin tuple', tuple[str], - typehints.Tuple[str]), ('builtin set', set[str], typehints.Set[str]), - ('builtin frozenset', frozenset[int], typehints.FrozenSet[int]), - ( - 'nested builtin', - dict[str, list[tuple[float]]], - typehints.Dict[str, typehints.List[typehints.Tuple[float]]]), - ( - 'builtin nested tuple', - tuple[str, list], - typehints.Tuple[str, typehints.List[typehints.Any]], - ) - ] - - for test_case in test_cases: - 
description = test_case[0] - builtins_type = test_case[1] - expected_beam_type = test_case[2] - converted_beam_type = convert_to_beam_type(builtins_type) - self.assertEqual(converted_beam_type, expected_beam_type, description) + test_cases = [ + ('builtin dict', dict[str, int], typehints.Dict[str, int]), + ('builtin list', list[str], typehints.List[str]), + ('builtin tuple', tuple[str], + typehints.Tuple[str]), ('builtin set', set[str], typehints.Set[str]), + ('builtin frozenset', frozenset[int], typehints.FrozenSet[int]), + ( + 'nested builtin', + dict[str, list[tuple[float]]], + typehints.Dict[str, typehints.List[typehints.Tuple[float]]]), + ( + 'builtin nested tuple', + tuple[str, list], + typehints.Tuple[str, typehints.List[typehints.Any]], + ) + ] + + for test_case in test_cases: + description = test_case[0] + builtins_type = test_case[1] + expected_beam_type = test_case[2] + converted_beam_type = convert_to_beam_type(builtins_type) + self.assertEqual(converted_beam_type, expected_beam_type, description) def test_convert_to_beam_type_with_collections_types(self): - if sys.version_info >= (3, 9): - test_cases = [ - ( - 'collection iterable', - collections.abc.Iterable[int], - typehints.Iterable[int]), - ( - 'collection generator', - collections.abc.Generator[int], - typehints.Generator[int]), - ( - 'collection iterator', - collections.abc.Iterator[int], - typehints.Iterator[int]), - ( - 'nested iterable', - tuple[bytes, collections.abc.Iterable[int]], - typehints.Tuple[bytes, typehints.Iterable[int]]), - ( - 'iterable over tuple', - collections.abc.Iterable[tuple[str, int]], - typehints.Iterable[typehints.Tuple[str, int]]), - ( - 'mapping not caught', - collections.abc.Mapping[str, int], - collections.abc.Mapping[str, int]), - ('set', collections.abc.Set[str], typehints.Set[str]), - ('mutable set', collections.abc.MutableSet[int], typehints.Set[int]), - ( - 'enum set', - collections.abc.Set[_TestEnum], - typehints.Set[_TestEnum]), - ( - 'enum mutable set', - 
collections.abc.MutableSet[_TestEnum], - typehints.Set[_TestEnum]), - ( - 'collection enum', - collections.abc.Collection[_TestEnum], - typehints.Collection[_TestEnum]), - ( - 'collection of tuples', - collections.abc.Collection[tuple[str, int]], - typehints.Collection[typehints.Tuple[str, int]]), - ] - - for test_case in test_cases: - description = test_case[0] - builtins_type = test_case[1] - expected_beam_type = test_case[2] - converted_beam_type = convert_to_beam_type(builtins_type) - self.assertEqual(converted_beam_type, expected_beam_type, description) + test_cases = [ + ( + 'collection iterable', + collections.abc.Iterable[int], + typehints.Iterable[int]), + ( + 'collection generator', + collections.abc.Generator[int], + typehints.Generator[int]), + ( + 'collection iterator', + collections.abc.Iterator[int], + typehints.Iterator[int]), + ( + 'nested iterable', + tuple[bytes, collections.abc.Iterable[int]], + typehints.Tuple[bytes, typehints.Iterable[int]]), + ( + 'iterable over tuple', + collections.abc.Iterable[tuple[str, int]], + typehints.Iterable[typehints.Tuple[str, int]]), + ( + 'mapping not caught', + collections.abc.Mapping[str, int], + collections.abc.Mapping[str, int]), + ('set', collections.abc.Set[str], typehints.Set[str]), + ('mutable set', collections.abc.MutableSet[int], typehints.Set[int]), + ('enum set', collections.abc.Set[_TestEnum], typehints.Set[_TestEnum]), + ( + 'enum mutable set', + collections.abc.MutableSet[_TestEnum], + typehints.Set[_TestEnum]), + ( + 'collection enum', + collections.abc.Collection[_TestEnum], + typehints.Collection[_TestEnum]), + ( + 'collection of tuples', + collections.abc.Collection[tuple[str, int]], + typehints.Collection[typehints.Tuple[str, int]]), + ] + + for test_case in test_cases: + description = test_case[0] + builtins_type = test_case[1] + expected_beam_type = test_case[2] + converted_beam_type = convert_to_beam_type(builtins_type) + self.assertEqual(converted_beam_type, expected_beam_type, 
description) def test_convert_builtin_to_typing(self): - if sys.version_info >= (3, 9): - test_cases = [ - ('dict', dict[str, int], typing.Dict[str, int]), - ('list', list[str], typing.List[str]), - ('tuple', tuple[str], typing.Tuple[str]), - ('set', set[str], typing.Set[str]), - ( - 'nested', - dict[str, list[tuple[float]]], - typing.Dict[str, typing.List[typing.Tuple[float]]]), - ] - - for test_case in test_cases: - description = test_case[0] - builtin_type = test_case[1] - expected_typing_type = test_case[2] - converted_typing_type = convert_builtin_to_typing(builtin_type) - self.assertEqual( - converted_typing_type, expected_typing_type, description) + test_cases = [ + ('dict', dict[str, int], typing.Dict[str, int]), + ('list', list[str], typing.List[str]), + ('tuple', tuple[str], typing.Tuple[str]), + ('set', set[str], typing.Set[str]), + ( + 'nested', + dict[str, list[tuple[float]]], + typing.Dict[str, typing.List[typing.Tuple[float]]]), + ] + + for test_case in test_cases: + description = test_case[0] + builtin_type = test_case[1] + expected_typing_type = test_case[2] + converted_typing_type = convert_builtin_to_typing(builtin_type) + self.assertEqual(converted_typing_type, expected_typing_type, description) def test_generator_converted_to_iterator(self): self.assertEqual( @@ -293,14 +285,11 @@ def test_convert_bare_types(self): typing.Tuple[typing.Iterator], typehints.Tuple[typehints.Iterator[typehints.TypeVariable('T_co')]] ), + ( + 'bare generator', + typing.Generator, + typehints.Generator[typehints.TypeVariable('T_co')]), ] - if sys.version_info >= (3, 7): - test_cases += [ - ( - 'bare generator', - typing.Generator, - typehints.Generator[typehints.TypeVariable('T_co')]), - ] for test_case in test_cases: description = test_case[0] typing_type = test_case[1] diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index 62c7a8fadc35..7bea621841f6 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ 
b/sdks/python/apache_beam/typehints/opcodes.py @@ -246,14 +246,10 @@ def set_add(state, arg): def map_add(state, arg): - if sys.version_info >= (3, 8): - # PEP 572 The MAP_ADD expects the value as the first element in the stack - # and the key as the second element. - new_value_type = Const.unwrap(state.stack.pop()) - new_key_type = Const.unwrap(state.stack.pop()) - else: - new_key_type = Const.unwrap(state.stack.pop()) - new_value_type = Const.unwrap(state.stack.pop()) + # PEP 572 The MAP_ADD expects the value as the first element in the stack + # and the key as the second element. + new_value_type = Const.unwrap(state.stack.pop()) + new_key_type = Const.unwrap(state.stack.pop()) state.stack[-arg] = Dict[Union[state.stack[-arg].key_type, new_key_type], Union[state.stack[-arg].value_type, new_value_type]] diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 57e7f44f6922..44318fa44a8c 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -19,7 +19,6 @@ # pytype: skip-file -import sys import typing import unittest @@ -874,12 +873,7 @@ def test_flat_type_hint(self): class AnnotationsTest(unittest.TestCase): def test_pardo_wrapper_builtin_method(self): th = beam.ParDo(str.strip).get_type_hints() - if sys.version_info < (3, 7): - self.assertEqual(th.input_types, ((str, ), {})) - else: - # Python 3.7+ has annotations for CPython builtins - # (_MethodDescriptorType). 
- self.assertEqual(th.input_types, ((str, typehints.Any), {})) + self.assertEqual(th.input_types, ((str, typehints.Any), {})) self.assertEqual(th.output_types, ((typehints.Any, ), {})) def test_pardo_wrapper_builtin_type(self): diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 912cb78dc095..0e18e887c2a0 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -391,12 +391,6 @@ def validate_composite_type_param(type_param, error_msg_prefix): if sys.version_info.major == 3 and sys.version_info.minor >= 10: if isinstance(type_param, types.UnionType): is_not_type_constraint = False - # Pre-Python 3.9 compositve type-hinting with built-in types was not - # supported, the typing module equivalents should be used instead. - if sys.version_info.major == 3 and sys.version_info.minor < 9: - is_not_type_constraint = is_not_type_constraint or ( - isinstance(type_param, type) and - type_param in DISALLOWED_PRIMITIVE_TYPES) if is_not_type_constraint: raise TypeError( @@ -1266,7 +1260,7 @@ def normalize(x, none_as_type=False): # Avoid circular imports from apache_beam.typehints import native_type_compatibility - if sys.version_info >= (3, 9) and isinstance(x, types.GenericAlias): + if isinstance(x, types.GenericAlias): x = native_type_compatibility.convert_builtin_to_typing(x) if none_as_type and x is None: diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 843c1498cac5..6611dcecab01 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -388,15 +388,10 @@ def test_getitem_params_must_be_type_or_constraint(self): typehints.Tuple[5, [1, 3]] self.assertTrue(e.exception.args[0].startswith(expected_error_prefix)) - if sys.version_info < (3, 9): - with self.assertRaises(TypeError) as e: - typehints.Tuple[list, dict] - 
self.assertTrue(e.exception.args[0].startswith(expected_error_prefix)) - else: - try: - typehints.Tuple[list, dict] - except TypeError: - self.fail("built-in composite raised TypeError unexpectedly") + try: + typehints.Tuple[list, dict] + except TypeError: + self.fail("built-in composite raised TypeError unexpectedly") def test_compatibility_arbitrary_length(self): self.assertNotCompatible( @@ -548,15 +543,13 @@ def test_type_check_invalid_composite_type_arbitrary_length(self): e.exception.args[0]) def test_normalize_with_builtin_tuple(self): - if sys.version_info >= (3, 9): - expected_beam_type = typehints.Tuple[int, int] - converted_beam_type = typehints.normalize(tuple[int, int], False) - self.assertEqual(converted_beam_type, expected_beam_type) + expected_beam_type = typehints.Tuple[int, int] + converted_beam_type = typehints.normalize(tuple[int, int], False) + self.assertEqual(converted_beam_type, expected_beam_type) def test_builtin_and_type_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(tuple, typing.Tuple) - self.assertCompatible(tuple[int, int], typing.Tuple[int, int]) + self.assertCompatible(tuple, typing.Tuple) + self.assertCompatible(tuple[int, int], typing.Tuple[int, int]) class ListHintTestCase(TypeHintTestCase): @@ -618,22 +611,19 @@ def test_enforce_list_type_constraint_invalid_composite_type(self): e.exception.args[0]) def test_normalize_with_builtin_list(self): - if sys.version_info >= (3, 9): - expected_beam_type = typehints.List[int] - converted_beam_type = typehints.normalize(list[int], False) - self.assertEqual(converted_beam_type, expected_beam_type) + expected_beam_type = typehints.List[int] + converted_beam_type = typehints.normalize(list[int], False) + self.assertEqual(converted_beam_type, expected_beam_type) def test_builtin_and_type_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(list, typing.List) - self.assertCompatible(list[int], typing.List[int]) + 
self.assertCompatible(list, typing.List) + self.assertCompatible(list[int], typing.List[int]) def test_is_typing_generic(self): self.assertTrue(typehints.is_typing_generic(typing.List[str])) def test_builtin_is_typing_generic(self): - if sys.version_info >= (3, 9): - self.assertTrue(typehints.is_typing_generic(list[str])) + self.assertTrue(typehints.is_typing_generic(list[str])) class KVHintTestCase(TypeHintTestCase): @@ -687,14 +677,10 @@ def test_getitem_param_must_have_length_2(self): e.exception.args[0]) def test_key_type_must_be_valid_composite_param(self): - if sys.version_info < (3, 9): - with self.assertRaises(TypeError): - typehints.Dict[list, int] - else: - try: - typehints.Tuple[list, int] - except TypeError: - self.fail("built-in composite raised TypeError unexpectedly") + try: + typehints.Tuple[list, int] + except TypeError: + self.fail("built-in composite raised TypeError unexpectedly") def test_value_type_must_be_valid_composite_param(self): with self.assertRaises(TypeError): @@ -777,35 +763,24 @@ def test_match_type_variables(self): hint.match_type_variables(typehints.Dict[int, str])) def test_normalize_with_builtin_dict(self): - if sys.version_info >= (3, 9): - expected_beam_type = typehints.Dict[str, int] - converted_beam_type = typehints.normalize(dict[str, int], False) - self.assertEqual(converted_beam_type, expected_beam_type) + expected_beam_type = typehints.Dict[str, int] + converted_beam_type = typehints.normalize(dict[str, int], False) + self.assertEqual(converted_beam_type, expected_beam_type) def test_builtin_and_type_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(dict, typing.Dict) - self.assertCompatible(dict[str, int], typing.Dict[str, int]) - self.assertCompatible( - dict[str, list[int]], typing.Dict[str, typing.List[int]]) + self.assertCompatible(dict, typing.Dict) + self.assertCompatible(dict[str, int], typing.Dict[str, int]) + self.assertCompatible( + dict[str, list[int]], typing.Dict[str, 
typing.List[int]]) class BaseSetHintTest: class CommonTests(TypeHintTestCase): def test_getitem_invalid_composite_type_param(self): - if sys.version_info < (3, 9): - with self.assertRaises(TypeError) as e: - self.beam_type[list] - self.assertEqual( - "Parameter to a {} hint must be a non-sequence, a " - "type, or a TypeConstraint. {} is an instance of " - "type.".format(self.string_type, list), - e.exception.args[0]) - else: - try: - self.beam_type[list] - except TypeError: - self.fail("built-in composite raised TypeError unexpectedly") + try: + self.beam_type[list] + except TypeError: + self.fail("built-in composite raised TypeError unexpectedly") def test_non_typing_generic(self): testCase = DummyTestClass1() @@ -855,16 +830,14 @@ class SetHintTestCase(BaseSetHintTest.CommonTests): string_type = 'Set' def test_builtin_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible(set[int], collections.abc.Set[int]) - self.assertCompatible(set[int], collections.abc.MutableSet[int]) + self.assertCompatible(set[int], collections.abc.Set[int]) + self.assertCompatible(set[int], collections.abc.MutableSet[int]) def test_collections_compatibility(self): - if sys.version_info >= (3, 9): - self.assertCompatible( - collections.abc.Set[int], collections.abc.MutableSet[int]) - self.assertCompatible( - collections.abc.MutableSet[int], collections.abc.Set[int]) + self.assertCompatible( + collections.abc.Set[int], collections.abc.MutableSet[int]) + self.assertCompatible( + collections.abc.MutableSet[int], collections.abc.Set[int]) class FrozenSetHintTestCase(BaseSetHintTest.CommonTests): @@ -1416,37 +1389,16 @@ def func(a, b_c, *d): func, *[Any, Any, Tuple[str, ...], int])) def test_getcallargs_forhints_builtins(self): - if sys.version_info < (3, 7): - # Signatures for builtins are not supported in 3.5 and 3.6. 
- self.assertEqual({ - '_': str, - '__unknown__varargs': Tuple[Any, ...], - '__unknown__keywords': typehints.Dict[Any, Any] - }, - getcallargs_forhints(str.upper, str)) - self.assertEqual({ - '_': str, - '__unknown__varargs': Tuple[str, ...], - '__unknown__keywords': typehints.Dict[Any, Any] - }, - getcallargs_forhints(str.strip, str, str)) - self.assertEqual({ - '_': str, - '__unknown__varargs': Tuple[typehints.List[int], ...], - '__unknown__keywords': typehints.Dict[Any, Any] - }, - getcallargs_forhints(str.join, str, typehints.List[int])) - else: - self.assertEqual({'self': str}, getcallargs_forhints(str.upper, str)) - # str.strip has an optional second argument. - self.assertEqual({ - 'self': str, 'chars': Any - }, - getcallargs_forhints(str.strip, str)) - self.assertEqual({ - 'self': str, 'iterable': typehints.List[int] - }, - getcallargs_forhints(str.join, str, typehints.List[int])) + self.assertEqual({'self': str}, getcallargs_forhints(str.upper, str)) + # str.strip has an optional second argument. + self.assertEqual({ + 'self': str, 'chars': Any + }, + getcallargs_forhints(str.strip, str)) + self.assertEqual({ + 'self': str, 'iterable': typehints.List[int] + }, + getcallargs_forhints(str.join, str, typehints.List[int])) class TestGetYieldedType(unittest.TestCase): diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 3bef1a0a1101..9f92f59f42b6 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -467,7 +467,9 @@ def expand(self, pcoll): break else: raise ValueError( - f"No field name matches one of {self._ERROR_FIELD_NAMES}") + 'The input to this transform does not appear to be an error ' + + "output. Expected a schema'd input with a field named " + + ' or '.join(repr(fld) for fld in self._ERROR_FIELD_NAMES)) if fld is None: # This handles with_exception_handling() that returns bare tuples. 
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index ab86a2aaff56..0190fe20413f 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -576,8 +576,8 @@ def is_not_output_of_last_transform(new_transforms, value): pass else: raise ValueError( - f'Transform {identify_object(transform)} is part of a chain, ' - 'must have implicit inputs and outputs.') + f'Transform {identify_object(transform)} is part of a chain. ' + 'Cannot define explicit inputs on chain pipeline') if ix == 0: if is_explicitly_empty(transform.get('input', None)): pass diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index bc0493509d5a..084e03cdb197 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -325,6 +325,23 @@ def test_chain_as_composite_with_input(self): self.assertEqual( chain_as_composite(spec)['transforms'][0]['input'], {"input": "input"}) + def test_chain_as_composite_with_transform_input(self): + spec = ''' + type: chain + transforms: + - type: Create + config: + elements: [0,1,2] + - type: LogForTesting + input: Create + ''' + spec = yaml.load(spec, Loader=SafeLineLoader) + with self.assertRaisesRegex( + ValueError, + r"Transform .* is part of a chain. " + r"Cannot define explicit inputs on chain pipeline"): + chain_as_composite(spec) + def test_normalize_source_sink(self): spec = ''' source: diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 3b45cbf82fc1..53c7a532e706 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -490,12 +490,9 @@ def get_portability_package_data(): 'sentence-transformers', 'skl2onnx', 'pillow', - # Support TF 2.16.0: https://github.com/apache/beam/issues/31294 - # Once TF version is unpinned, also don't restrict Python version. 
- 'tensorflow<2.16.0;python_version<"3.12"', + 'tensorflow', 'tensorflow-hub', - # https://github.com/tensorflow/transform/issues/313 - 'tensorflow-transform;python_version<"3.11"', + 'tensorflow-transform', 'tf2onnx', 'torch', 'transformers', @@ -504,6 +501,19 @@ def get_portability_package_data(): # https://github.com/apache/beam/issues/31285 # 'xgboost<2.0', # https://github.com/apache/beam/issues/31252 ], + 'p312_ml_test': [ + 'datatable', + 'embeddings', + 'onnxruntime', + 'sentence-transformers', + 'skl2onnx', + 'pillow', + 'tensorflow', + 'tensorflow-hub', + 'tf2onnx', + 'torch', + 'transformers', + ], 'aws': ['boto3>=1.9,<2'], 'azure': [ 'azure-storage-blob>=12.3.2,<13', diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index 85f71e559c15..68ac15ced70d 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -101,11 +101,23 @@ commands = python apache_beam/examples/complete/autocomplete_test.py bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" -[testenv:py{39,310,311,312}-ml] +[testenv:py{39,310,311}-ml] # Don't set TMPDIR to avoid "AF_UNIX path too long" errors in certain tests. setenv = extras = test,gcp,dataframe,ml_test commands = + # Log tensorflow version for debugging + /bin/sh -c "pip freeze | grep -E tensorflow" + bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" + +[testenv:py312-ml] +# many packages do not support py3.12 +# Don't set TMPDIR to avoid "AF_UNIX path too long" errors in certain tests. 
+setenv = +extras = test,gcp,dataframe,p312_ml_test +commands = + # Log tensorflow version for debugging + /bin/sh -c "pip freeze | grep -E tensorflow" bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" [testenv:py{39,310,311,312}-dask] @@ -410,8 +422,9 @@ deps = tensorflow>=2.12rc1,<2.13 # Help pip resolve conflict with typing-extensions for old version of TF https://github.com/apache/beam/issues/30852 pydantic<2.7 - protobuf==4.25.5 -extras = test,gcp,ml_test +extras = test,gcp +commands_pre = + pip install -U 'protobuf==4.25.5' commands = # Log tensorflow version for debugging /bin/sh -c "pip freeze | grep -E tensorflow" diff --git a/settings.gradle.kts b/settings.gradle.kts index ca30a5ea750a..d90bb3fb5b82 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -19,7 +19,7 @@ import com.gradle.enterprise.gradleplugin.internal.extension.BuildScanExtensionW pluginManagement { plugins { - id("org.javacc.javacc") version "3.0.2" // enable the JavaCC parser generator + id("org.javacc.javacc") version "3.0.3" // enable the JavaCC parser generator } } diff --git a/website/www/site/config.toml b/website/www/site/config.toml index d769f8434a7f..0ed8aef2a906 100644 --- a/website/www/site/config.toml +++ b/website/www/site/config.toml @@ -104,7 +104,7 @@ github_project_repo = "https://github.com/apache/beam" [params] description = "Apache Beam is an open source, unified model and set of language-specific SDKs for defining and executing data processing workflows, and also data ingestion and integration flows, supporting Enterprise Integration Patterns (EIPs) and Domain Specific Languages (DSLs). Dataflow pipelines simplify the mechanics of large-scale batch and streaming data processing and can run on a number of runtimes like Apache Flink, Apache Spark, and Google Cloud Dataflow (a cloud service). Beam also brings DSL in different languages, allowing users to easily implement their data integration processes." 
-release_latest = "2.60.0" +release_latest = "2.61.0" # The repository and branch where the files live in Github or Colab. This is used # to serve and stage from your local branch, but publish to the master branch. # e.g. https://github.com/{{< param branch_repo >}}/path/to/notebook.ipynb diff --git a/website/www/site/content/en/blog/beam-2.60.0.md b/website/www/site/content/en/blog/beam-2.60.0.md index ae5e0284ccdd..e5767cff5114 100644 --- a/website/www/site/content/en/blog/beam-2.60.0.md +++ b/website/www/site/content/en/blog/beam-2.60.0.md @@ -67,6 +67,7 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) * (Java) Fixed custom delimiter issues in TextIO ([#32249](https://github.com/apache/beam/issues/32249), [#32251](https://github.com/apache/beam/issues/32251)). * (Java, Python, Go) Fixed PeriodicSequence backlog bytes reporting, which was preventing Dataflow Runner autoscaling from functioning properly ([#32506](https://github.com/apache/beam/issues/32506)). * (Java) Fix improper decoding of rows with schemas containing nullable fields when encoded with a schema with equal encoding positions but modified field order. ([#32388](https://github.com/apache/beam/issues/32388)). +* (Java) Skip close on bundles in BigtableIO.Read ([#32661](https://github.com/apache/beam/pull/32661), [#32759](https://github.com/apache/beam/pull/32759)). ## Known Issues diff --git a/website/www/site/content/en/blog/beam-2.61.0.md b/website/www/site/content/en/blog/beam-2.61.0.md new file mode 100644 index 000000000000..a4c7ac0cefbd --- /dev/null +++ b/website/www/site/content/en/blog/beam-2.61.0.md @@ -0,0 +1,74 @@ +--- +title: "Apache Beam 2.61.0" +date: 2024-11-25 15:00:00 -0500 +categories: + - blog + - release +authors: + - damccorm +--- + + +We are happy to present the new 2.61.0 release of Beam. +This release includes both improvements and new functionality. +See the [download page](/get-started/downloads/#2610-2024-11-25) for this release. 
+ + + +For more information on changes in 2.61.0, check out the [detailed release notes](https://github.com/apache/beam/milestone/25). + +## Highlights + +* [Python] Introduce Managed Transforms API ([#31495](https://github.com/apache/beam/pull/31495)) +* Flink 1.19 support added ([#32648](https://github.com/apache/beam/pull/32648)) + +## I/Os + +* [Managed Iceberg] Support creating tables if needed ([#32686](https://github.com/apache/beam/pull/32686)) +* [Managed Iceberg] Now available in Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) +* [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) +* BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) +* [Managed Iceberg] Allow updating table partition specs during pipeline runtime ([#32879](https://github.com/apache/beam/pull/32879)) +* Added BigQueryIO as a Managed IO ([#31486](https://github.com/apache/beam/pull/31486)) +* Support for writing to [Solace messages queues](https://solace.com/) (`SolaceIO.Write`) added (Java) ([#31905](https://github.com/apache/beam/issues/31905)). + +## New Features / Improvements + +* Added support for read with metadata in MqttIO (Java) ([#32195](https://github.com/apache/beam/issues/32195)) +* Added support for processing events which use a global sequence to "ordered" extension (Java) ([#32540](https://github.com/apache/beam/pull/32540)) +* Add new meta-transform FlattenWith and Tee that allow one to introduce branching + without breaking the linear/chaining style of pipeline construction. 
+* Use Prism as a fallback to the Python Portable runner when running a pipeline with the Python Direct runner ([#32876](https://github.com/apache/beam/pull/32876)) + +## Deprecations + +* Removed support for Flink 1.15 and 1.16 +* Removed support for Python 3.8 + +## Bugfixes + +* (Java) Fixed tearDown not invoked when DoFn throws on Portable Runners ([#18592](https://github.com/apache/beam/issues/18592), [#31381](https://github.com/apache/beam/issues/31381)). +* (Java) Fixed protobuf error with MapState.remove() in Dataflow Streaming Java Legacy Runner without Streaming Engine ([#32892](https://github.com/apache/beam/issues/32892)). +* Adding flag to support conditionally disabling auto-commit in JdbcIO ReadFn ([#31111](https://github.com/apache/beam/issues/31111)) + +## Known Issues + +N/A + +For the most up to date list of known issues, see https://github.com/apache/beam/blob/master/CHANGES.md + +## List of Contributors + +According to git shortlog, the following people contributed to the 2.61.0 release. Thank you to all contributors! + +Ahmed Abualsaud, Ahmet Altay, Arun Pandian, Ayush Pandey, Chamikara Jayalath, Chris Ashcraft, Christoph Grotz, DKPHUONG, Damon, Danny Mccormick, Dmitry Ulyumdzhiev, Ferran Fernández Garrido, Hai Joey Tran, Hyeonho Kim, Idan Attias, Israel Herraiz, Jack McCluskey, Jan Lukavský, Jeff Kinard, Jeremy Edwards, Joey Tran, Kenneth Knowles, Maciej Szwaja, Manit Gupta, Mattie Fu, Michel Davit, Minbo Bae, Mohamed Awnallah, Naireen Hussain, Rebecca Szper, Reeba Qureshi, Reuven Lax, Robert Bradshaw, Robert Burke, S.
Veyrié, Sam Whittle, Sergei Lilichenko, Shunping Huang, Steven van Rossum, Tan Le, Thiago Nunes, Vitaly Terentyev, Vlado Djerek, Yi Hu, claudevdm, fozzie15, johnjcasey, kushmiD, liferoad, martin trieu, pablo rodriguez defino, razvanculea, s21lee, tvalentyn, twosom diff --git a/website/www/site/content/en/case-studies/accenture_baltics.md b/website/www/site/content/en/case-studies/accenture_baltics.md new file mode 100644 index 000000000000..98f9d9a8a687 --- /dev/null +++ b/website/www/site/content/en/case-studies/accenture_baltics.md @@ -0,0 +1,105 @@ +--- +title: "Accenture Baltics' Journey with Apache Beam to Streamlined Data Workflows for a Sustainable Energy Leader" +name: "Accenture Baltics" +icon: /images/logos/powered-by/accenture.png +hasNav: true +category: study +cardTitle: "Accenture Baltics' Journey with Apache Beam" +cardDescription: "Accenture Baltics uses Apache Beam on Google Cloud to build a robust data processing infrastructure for a sustainable energy leader. They use Beam to democratize data access, process data in real-time, and handle complex ETL tasks." +authorName: "Jana Polianskaja" +authorPosition: "Data Engineer @ Accenture Baltics" +authorImg: /images/case-study/accenture/Jana_Polianskaja_sm.jpg +publishDate: 2024-11-25T00:12:00+00:00 +--- + + +

+
+ +
+
+

+ “Apache Beam empowers team members who don’t have data engineering backgrounds to directly access and analyze BigQuery data by using SQL. The data scientists, the finance department, and production optimization teams all benefit from improved data accessibility, which gives them immediate access to critical information for faster analysis and decision-making.” +

+
+
+ +
+
+
+ Jana Polianskaja +
+
+ Data Engineer @ Accenture Baltics +
+
+
+
+
+ + +
+ +# Accenture Baltics' Journey with Apache Beam to Streamlined Data Workflows for a Sustainable Energy Leader + +## Background + +Accenture Baltics, a branch of the global professional services company Accenture, leverages its expertise across various industries to provide consulting, technology, and outsourcing solutions to clients worldwide. A specific project at Accenture Baltics highlights the effective implementation of Apache Beam to support a client who is a global leader in sustainable energy and uses Google Cloud. + +## Journey to Apache Beam + +The team responsible for transforming, curating, and preparing data, including transactional, analytics, and sensor data, for data scientists and other teams has been using Dataflow with Apache Beam for about five years. Dataflow with Beam is a natural choice for both streaming and batch data processing. For our workloads, we typically use the following configurations: worker machine types are `n1-standard-2` or `n1-standard-4`, and the maximum number of workers varies up to five, using the Dataflow runner. + +As an example, a streaming pipeline ingests transaction data from Pub/Sub, performs basic ETL and data cleaning, and outputs the results to BigQuery. A separate batch Dataflow pipeline evaluates a binary classification model, reading input and writing results to Google Cloud Storage. The following diagram shows a workflow that uses Pub/Sub to feed Dataflow pipelines across three Google Cloud projects. It also shows how Dataflow, Composer, Cloud Storage, BigQuery, and Grafana integrate into the architecture. + +
+ + Diagram of Accenture Baltics' Dataflow pipeline architecture + +
+ +## Use Cases + +Apache Beam is an invaluable tool for our use cases, particularly in the following areas: + +* **Democratizing data access:** Beam empowers team members without data engineering backgrounds to directly access and analyze BigQuery data using their SQL skills. The data scientists, the finance department, and production optimization teams all benefit from improved data accessibility, gaining immediate access to critical information for faster analysis and decision-making. +* **Real-time data processing:** Beam excels at ingesting and processing data in real time from sources like Pub/Sub. +* **ETL (extract, transform, load):** Beam effectively manages the full spectrum of data transformation and cleaning tasks, even when dealing with complex data structures. +* **Data routing and partitioning:** Beam enables sophisticated data routing and partitioning strategies. For example, it can automatically route failed transactions to a separate BigQuery table for further analysis. +* **Data deduplication and error handling:** Beam has been instrumental in tackling challenging tasks like deduplicating Pub/Sub messages and implementing robust error handling, such as for JSON parsing, that are crucial for maintaining data integrity and pipeline reliability. + +We also utilize Grafana (shown below) with custom notification emails and tickets for comprehensive monitoring of our Beam pipelines. Notifications are generated from Google’s Cloud Logging and Cloud Monitoring services to ensure we stay informed about the performance and health of our pipelines. The seamless integration of Airflow with Dataflow and Beam further enhances our workflow, allowing us to effortlessly use operators such as `DataflowCreatePythonJobOperator` and `BeamRunPythonPipelineOperator` in [Airflow 2](https://airflow.apache.org/docs/apache-airflow-providers-google/stable/_api/airflow/providers/google/cloud/operators/dataflow/index.html). + +
+ + scheme + +
+ +## Results + +Our data processing infrastructure uses 12 distinct pipelines to manage and transform data across various projects within the organization. These pipelines are divided into two primary categories: + +* **Streaming pipelines:** These pipelines are designed to handle real-time or near real-time data streams. In our current setup, these pipelines process an average of 10,000 messages per second from Pub/Sub and write about 200,000 rows per hour to BigQuery, ensuring that time-sensitive data is ingested and processed with minimal latency. +* **Batch pipelines:** These pipelines are optimized for processing large volumes of data in scheduled batches. Our current batch pipelines handle approximately two gigabytes of data per month, transforming and loading this data into our data warehouse for further analysis and reporting. + +Apache Beam has proven to be a highly effective solution for orchestrating and managing the complex data pipelines required by the client. By leveraging the capabilities of Dataflow, a fully managed service for executing Beam pipelines, we have successfully addressed and fulfilled the client's specific data processing needs. This powerful combination has enabled us to achieve scalability, reliability, and efficiency in handling large volumes of data, ensuring timely and accurate delivery of insights to the client. + +*Check out [my Medium blog](https://medium.com/@jana_om)! I usually post about using Beam/Dataflow as an ETL tool with Python and how it works with other data engineering tools. My focus is on building projects that are easy to understand and learn from, especially if you want to get some hands-on experience with Beam.* + + +{{< case_study_feedback "AccentureBalticsStreaming" >}} +
+
diff --git a/website/www/site/content/en/documentation/sdks/python-streaming.md b/website/www/site/content/en/documentation/sdks/python-streaming.md index 2d0bdfa9500b..7122cc8b9ae5 100644 --- a/website/www/site/content/en/documentation/sdks/python-streaming.md +++ b/website/www/site/content/en/documentation/sdks/python-streaming.md @@ -155,11 +155,3 @@ about executing streaming pipelines: - [DirectRunner streaming execution](/documentation/runners/direct/#streaming-execution) - [DataflowRunner streaming execution](/documentation/runners/dataflow/#streaming-execution) - [Portable Flink runner](/documentation/runners/flink/) - -## Unsupported features - -Python streaming execution does not currently support the following features: - -- Custom source API -- User-defined custom merging `WindowFn` (with fnapi) -- For portable runners, see [portability support table](https://s.apache.org/apache-beam-portability-support-table). diff --git a/website/www/site/content/en/get-started/downloads.md b/website/www/site/content/en/get-started/downloads.md index ff432996578d..dea5dc314b17 100644 --- a/website/www/site/content/en/get-started/downloads.md +++ b/website/www/site/content/en/get-started/downloads.md @@ -96,11 +96,19 @@ versions denoted `0.x.y`. ## Releases +### 2.61.0 (2024-11-25) + +Official [source code download](https://downloads.apache.org/beam/2.61.0/apache-beam-2.61.0-source-release.zip). +[SHA-512](https://downloads.apache.org/beam/2.61.0/apache-beam-2.61.0-source-release.zip.sha512). +[signature](https://downloads.apache.org/beam/2.61.0/apache-beam-2.61.0-source-release.zip.asc). + +[Release notes](https://github.com/apache/beam/releases/tag/v2.61.0) + ### 2.60.0 (2024-10-17) -Official [source code download](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip). -[SHA-512](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip.sha512). 
-[signature](https://downloads.apache.org/beam/2.60.0/apache-beam-2.60.0-source-release.zip.asc). +Official [source code download](https://archive.apache.org/dist/beam/2.60.0/apache-beam-2.60.0-source-release.zip). +[SHA-512](https://archive.apache.org/dist/beam/2.60.0/apache-beam-2.60.0-source-release.zip.sha512). +[signature](https://archive.apache.org/dist/beam/2.60.0/apache-beam-2.60.0-source-release.zip.asc). [Release notes](https://github.com/apache/beam/releases/tag/v2.60.0) diff --git a/website/www/site/data/en/quotes.yaml b/website/www/site/data/en/quotes.yaml index 4139d855fff5..db45c4344346 100644 --- a/website/www/site/data/en/quotes.yaml +++ b/website/www/site/data/en/quotes.yaml @@ -76,6 +76,11 @@ logoUrl: /images/logos/powered-by/yelp.png linkUrl: case-studies/yelp_streaming/index.html linkText: Learn more +- text: Accenture Baltics uses Apache Beam on Google Cloud to build a robust data processing infrastructure for a sustainable energy leader. They use Beam to democratize data access, process data in real time, and handle complex ETL tasks. + icon: icons/quote-icon.svg + logoUrl: /images/logos/powered-by/accenture.png + linkUrl: case-studies/accenture_baltics/index.html + linkText: Learn more - text: Have a story to share? Your logo could be here. 
icon: icons/quote-icon.svg logoUrl: images/logos/powered-by/blank.jpg diff --git a/website/www/site/static/images/case-study/accenture/Jana_Polianskaja_sm.jpg b/website/www/site/static/images/case-study/accenture/Jana_Polianskaja_sm.jpg new file mode 100644 index 000000000000..49ce23e5d7cd Binary files /dev/null and b/website/www/site/static/images/case-study/accenture/Jana_Polianskaja_sm.jpg differ diff --git a/website/www/site/static/images/case-study/accenture/dataflow_grafana.jpg b/website/www/site/static/images/case-study/accenture/dataflow_grafana.jpg new file mode 100644 index 000000000000..8a7471c1798f Binary files /dev/null and b/website/www/site/static/images/case-study/accenture/dataflow_grafana.jpg differ diff --git a/website/www/site/static/images/case-study/accenture/dataflow_pipelines.png b/website/www/site/static/images/case-study/accenture/dataflow_pipelines.png new file mode 100644 index 000000000000..423044910e36 Binary files /dev/null and b/website/www/site/static/images/case-study/accenture/dataflow_pipelines.png differ