diff --git a/.github/autolabeler.yml b/.github/autolabeler.yml index 5a8a22044da4..d1cc8296d303 100644 --- a/.github/autolabeler.yml +++ b/.github/autolabeler.yml @@ -31,6 +31,7 @@ python: ["sdks/python/**/*", "learning/katas/python/**/*"] typescript: ["sdks/typescript/**/*"] vendor: ["vendor/**/*"] website: ["website/**/*"] +yaml: ["sdks/python/apache_beam/yaml/**"] # Extensions extensions: ["sdks/java/extensions/**/*", "runners/extensions-java/**/*"] diff --git a/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml b/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml new file mode 100644 index 000000000000..952273e810d2 --- /dev/null +++ b/.github/workflows/beam_PostCommit_Java_BigQueryEarlyRollout.yml @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: PostCommit Java BigQueryEarlyRollout + +on: + issue_comment: + types: [created] + schedule: + - cron: '0 */6 * * *' + workflow_dispatch: + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.body || github.event.sender.login}}' + cancel-in-progress: true + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: write + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +env: + GRADLE_ENTERPRISE_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PostCommit_Java_BigQueryEarlyRollout: + name: ${{matrix.job_name}} (${{matrix.job_phrase}}) + runs-on: [self-hosted, ubuntu-20.04, main] + timeout-minutes: 100 + strategy: + matrix: + job_name: [beam_PostCommit_Java_BigQueryEarlyRollout] + job_phrase: [Run Java BigQueryEarlyRollout PostCommit] + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'schedule' || + github.event.comment.body == 'Run Java BigQueryEarlyRollout PostCommit' + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Authenticate on GCP + uses: google-github-actions/setup-gcloud@v0 + with: + service_account_email: ${{ secrets.GCP_SA_EMAIL }} + 
service_account_key: ${{ secrets.GCP_SA_KEY }} + project_id: ${{ secrets.GCP_PROJECT_ID }} + export_default_credentials: true + - name: run PostCommit Java BigQueryEarlyRollout script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:java:io:google-cloud-platform:bigQueryEarlyRolloutIntegrationTest + - name: Archive JUnit Test Results + uses: actions/upload-artifact@v3 + if: failure() + with: + name: JUnit Test Results + path: "**/build/reports/tests/" + - name: Publish JUnit Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + commit: '${{ env.prsha || env.GITHUB_SHA }}' + comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} + files: '**/build/test-results/**/*.xml' \ No newline at end of file diff --git a/.github/workflows/beam_PostCommit_Python.yml b/.github/workflows/beam_PostCommit_Python.yml index a7a214c7c5a9..6f4bc5e2ef0b 100644 --- a/.github/workflows/beam_PostCommit_Python.yml +++ b/.github/workflows/beam_PostCommit_Python.yml @@ -53,7 +53,7 @@ env: jobs: beam_PostCommit_Python: name: ${{matrix.job_name}} (${{matrix.job_phrase}} ${{matrix.python_version}}) - runs-on: [self-hosted, ubuntu-20.04, highmem] + runs-on: [self-hosted, ubuntu-20.04, main] timeout-minutes: 240 strategy: fail-fast: false diff --git a/.github/workflows/beam_PostCommit_Python_Arm.yml b/.github/workflows/beam_PostCommit_Python_Arm.yml index a77c4e96dc51..8be303a82d1d 100644 --- a/.github/workflows/beam_PostCommit_Python_Arm.yml +++ b/.github/workflows/beam_PostCommit_Python_Arm.yml @@ -18,10 +18,10 @@ name: PostCommit Python Arm on: - # issue_comment: - # types: [created] - # schedule: - # - cron: '0 */6 * * *' + issue_comment: + types: [created] + schedule: + - cron: '0 */6 * * *' workflow_dispatch: # This allows a subsequently queued workflow run to interrupt previous runs @@ -81,12 +81,20 @@ jobs: run: | sudo curl -L 
https://github.com/docker/compose/releases/download/1.22.0/docker-compose-$(uname -s)-$(uname -m) -o /usr/local/bin/docker-compose sudo chmod +x /usr/local/bin/docker-compose + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: GCloud Docker credential helper + run: | + gcloud auth configure-docker us.gcr.io - name: Set PY_VER_CLEAN id: set_py_ver_clean run: | PY_VER=${{ matrix.python_version }} PY_VER_CLEAN=${PY_VER//.} echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT + - name: Generate TAG unique variable based on timestamp + id: set_tag + run: echo "TAG=$(date +'%Y%m%d-%H%M%S%N')" >> $GITHUB_OUTPUT - name: run PostCommit Python ${{ matrix.python_version }} script uses: ./.github/actions/gradle-command-self-hosted-action with: @@ -94,8 +102,14 @@ jobs: arguments: | -PuseWheelDistribution \ -PpythonVersion=${{ matrix.python_version }} \ + -Pcontainer-architecture-list=arm64,amd64 \ + -Pdocker-repository-root=us.gcr.io/apache-beam-testing/github-actions \ + -Pdocker-tag=${{ steps.set_tag.outputs.TAG }} \ + -Ppush-containers \ env: CLOUDSDK_CONFIG: ${{ env.KUBELET_GCLOUD_CONFIG_PATH}} + MULTIARCH_TAG: ${{ steps.set_tag.outputs.TAG }} + USER: github-actions - name: Archive code coverage results uses: actions/upload-artifact@v3 with: diff --git a/.github/workflows/beam_PreCommit_CommunityMetrics.yml b/.github/workflows/beam_PreCommit_CommunityMetrics.yml index f044b154c0ab..bb44ca0b5464 100644 --- a/.github/workflows/beam_PreCommit_CommunityMetrics.yml +++ b/.github/workflows/beam_PreCommit_CommunityMetrics.yml @@ -19,10 +19,10 @@ on: push: tags: ['v*'] branches: ['master', 'release-*'] - paths: ['.test-infra/metrics/**', '.github/workflows/beam_PreCommit_CommunityMetrics.yml'] + paths: ['.test-infra/metrics/**', 'buildSrc/build.gradle.kts', '.github/workflows/beam_PreCommit_CommunityMetrics.yml'] pull_request_target: branches: ['master', 'release-*'] - paths: ['.test-infra/metrics/**'] + paths: ['.test-infra/metrics/**', 
'buildSrc/build.gradle.kts'] issue_comment: types: [created] schedule: diff --git a/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml b/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml index 7c821a024742..30e8d6d6c33c 100644 --- a/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml +++ b/.github/workflows/beam_PreCommit_Java_GCP_IO_Direct.yml @@ -90,7 +90,7 @@ jobs: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || github.event.comment.body == 'Run Java_GCP_IO_Direct PreCommit' - runs-on: [self-hosted, ubuntu-20.04, highmem] + runs-on: [self-hosted, ubuntu-20.04, main] steps: - uses: actions/checkout@v4 - name: Setup repository diff --git a/.github/workflows/run_perf_alert_tool.yml b/.github/workflows/run_perf_alert_tool.yml index 6946011f0617..1bd8d525c2fb 100644 --- a/.github/workflows/run_perf_alert_tool.yml +++ b/.github/workflows/run_perf_alert_tool.yml @@ -30,7 +30,7 @@ on: jobs: python_run_change_point_analysis: name: Run Change Point Analysis. - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-20.04, main] permissions: issues: write steps: diff --git a/.test-infra/metrics/build.gradle b/.test-infra/metrics/build.gradle index febe2849ef56..f1ecba05f84d 100644 --- a/.test-infra/metrics/build.gradle +++ b/.test-infra/metrics/build.gradle @@ -106,7 +106,7 @@ task deploy { standardOutput = stdout } - // All images have the same tag, it doesn't matter which we choose. + // All images have the same tag, it doesn't matter which we choose. 
String image = (stdout.toString().split(' ') as List)[0] String currentImageTag = (image.split(':') as List)[1] println "Current image tag: ${currentImageTag}" diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts index edd10ee108f6..968829caeb8b 100644 --- a/buildSrc/build.gradle.kts +++ b/buildSrc/build.gradle.kts @@ -44,20 +44,19 @@ dependencies { implementation("gradle.plugin.com.github.johnrengelman:shadow:7.1.1") implementation("com.github.spotbugs.snom:spotbugs-gradle-plugin:5.0.14") - runtimeOnly("com.google.protobuf:protobuf-gradle-plugin:0.8.13") // Enable proto code generation - runtimeOnly("com.github.davidmc24.gradle-avro-plugin:gradle-avro-plugin:0.16.0") // Enable Avro code generation - runtimeOnly("com.diffplug.spotless:spotless-plugin-gradle:5.6.1") // Enable a code formatting plugin - runtimeOnly("com.palantir.gradle.docker:gradle-docker:0.34.0") // Enable building Docker containers - runtimeOnly("gradle.plugin.com.dorongold.plugins:task-tree:1.5") // Adds a 'taskTree' task to print task dependency tree - runtimeOnly("gradle.plugin.com.github.johnrengelman:shadow:7.1.1") // Enable shading Java dependencies + runtimeOnly("com.google.protobuf:protobuf-gradle-plugin:0.8.13") // Enable proto code generation + runtimeOnly("com.github.davidmc24.gradle-avro-plugin:gradle-avro-plugin:0.16.0") // Enable Avro code generation + runtimeOnly("com.diffplug.spotless:spotless-plugin-gradle:5.6.1") // Enable a code formatting plugin + runtimeOnly("gradle.plugin.com.dorongold.plugins:task-tree:1.5") // Adds a 'taskTree' task to print task dependency tree + runtimeOnly("gradle.plugin.com.github.johnrengelman:shadow:7.1.1") // Enable shading Java dependencies runtimeOnly("net.linguica.gradle:maven-settings-plugin:0.5") runtimeOnly("gradle.plugin.io.pry.gradle.offline_dependencies:gradle-offline-dependencies-plugin:0.5.0") // Enable creating an offline repository - runtimeOnly("net.ltgt.gradle:gradle-errorprone-plugin:1.2.1") // Enable errorprone Java 
static analysis + runtimeOnly("net.ltgt.gradle:gradle-errorprone-plugin:3.1.0") // Enable errorprone Java static analysis runtimeOnly("org.ajoberstar.grgit:grgit-gradle:4.1.1") // Enable website git publish to asf-site branch - runtimeOnly("com.avast.gradle:gradle-docker-compose-plugin:0.17.5") // Enable docker compose tasks + runtimeOnly("com.avast.gradle:gradle-docker-compose-plugin:0.16.12") // Enable docker compose tasks runtimeOnly("ca.cutterslade.gradle:gradle-dependency-analyze:1.8.3") // Enable dep analysis runtimeOnly("gradle.plugin.net.ossindex:ossindex-gradle-plugin:0.4.11") // Enable dep vulnerability analysis - runtimeOnly("org.checkerframework:checkerframework-gradle-plugin:0.6.33") // Enable enhanced static checking plugin + runtimeOnly("org.checkerframework:checkerframework-gradle-plugin:0.6.33") // Enable enhanced static checking plugin } // Because buildSrc is built and tested automatically _before_ gradle diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy new file mode 100644 index 000000000000..442b35439cae --- /dev/null +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.gradle + +import java.util.regex.Pattern +import org.gradle.api.GradleException +import org.gradle.api.Plugin +import org.gradle.api.Project +import org.gradle.api.Task +import org.gradle.api.file.CopySpec +import org.gradle.api.logging.LogLevel +import org.gradle.api.logging.Logger +import org.gradle.api.logging.Logging +import org.gradle.api.tasks.Copy +import org.gradle.api.tasks.Delete +import org.gradle.api.tasks.Exec + +/** + * A gradle plug-in interacting with docker. Originally replicated from + * com.palantir.docker plugin. + */ +class BeamDockerPlugin implements Plugin { + private static final Logger logger = Logging.getLogger(BeamDockerPlugin.class) + private static final Pattern LABEL_KEY_PATTERN = Pattern.compile('^[a-z0-9.-]*$') + + static class DockerExtension { + Project project + + private static final String DEFAULT_DOCKERFILE_PATH = 'Dockerfile' + String name = null + File dockerfile = null + String dockerComposeTemplate = 'docker-compose.yml.template' + String dockerComposeFile = 'docker-compose.yml' + Set dependencies = [] as Set + Set tags = [] as Set + Map namedTags = [:] + Map labels = [:] + Map buildArgs = [:] + boolean pull = false + boolean noCache = false + String network = null + boolean buildx = false + Set platform = [] as Set + boolean load = false + boolean push = false + String builder = null + + File resolvedDockerfile = null + File resolvedDockerComposeTemplate = null + File resolvedDockerComposeFile = null + + // The CopySpec defining the Docker Build Context files + final CopySpec copySpec + + DockerExtension(Project project) { + this.project = project + this.copySpec = project.copySpec() + } + + void resolvePathsAndValidate() { + if (dockerfile != null) { + resolvedDockerfile = dockerfile + } else { + resolvedDockerfile = project.file(DEFAULT_DOCKERFILE_PATH) + } + resolvedDockerComposeFile = 
project.file(dockerComposeFile) + resolvedDockerComposeTemplate = project.file(dockerComposeTemplate) + } + + void dependsOn(Task... args) { + this.dependencies = args as Set + } + + Set getDependencies() { + return dependencies + } + + void files(Object... files) { + copySpec.from(files) + } + + void tags(String... args) { + this.tags = args as Set + } + + Set getTags() { + return this.tags + project.getVersion().toString() + } + + Set getPlatform() { + return platform + } + + void platform(String... args) { + this.platform = args as Set + } + } + + @Override + void apply(Project project) { + DockerExtension ext = project.extensions.create('docker', DockerExtension, project) + + Delete clean = project.tasks.create('dockerClean', Delete, { + group = 'Docker' + description = 'Cleans Docker build directory.' + }) + + Copy prepare = project.tasks.create('dockerPrepare', Copy, { + group = 'Docker' + description = 'Prepares Docker build directory.' + dependsOn clean + }) + + Exec exec = project.tasks.create('docker', Exec, { + group = 'Docker' + description = 'Builds Docker image.' + dependsOn prepare + }) + + Task tag = project.tasks.create('dockerTag', { + group = 'Docker' + description = 'Applies all tags to the Docker image.' + dependsOn exec + }) + + Task pushAllTags = project.tasks.create('dockerTagsPush', { + group = 'Docker' + description = 'Pushes all tagged Docker images to configured Docker Hub.' + }) + + project.tasks.create('dockerPush', { + group = 'Docker' + description = 'Pushes named Docker image to configured Docker Hub.' 
+ dependsOn pushAllTags + }) + + project.afterEvaluate { + ext.resolvePathsAndValidate() + String dockerDir = "${project.buildDir}/docker" + clean.delete dockerDir + + prepare.with { + with ext.copySpec + from(ext.resolvedDockerfile) { + rename { fileName -> + fileName.replace(ext.resolvedDockerfile.getName(), 'Dockerfile') + } + } + into dockerDir + } + + exec.with { + workingDir dockerDir + commandLine buildCommandLine(ext) + dependsOn ext.getDependencies() + logging.captureStandardOutput LogLevel.INFO + logging.captureStandardError LogLevel.ERROR + } + + Map tags = ext.namedTags.collectEntries { taskName, tagName -> + [ + generateTagTaskName(taskName), + [ + tagName: tagName, + tagTask: { + -> tagName } + ] + ] + } + + if (!ext.tags.isEmpty()) { + ext.tags.each { unresolvedTagName -> + String taskName = generateTagTaskName(unresolvedTagName) + + if (tags.containsKey(taskName)) { + throw new IllegalArgumentException("Task name '${taskName}' is existed.") + } + + tags[taskName] = [ + tagName: unresolvedTagName, + tagTask: { + -> computeName(ext.name, unresolvedTagName) } + ] + } + } + + tags.each { taskName, tagConfig -> + Exec tagSubTask = project.tasks.create('dockerTag' + taskName, Exec, { + group = 'Docker' + description = "Tags Docker image with tag '${tagConfig.tagName}'" + workingDir dockerDir + commandLine 'docker', 'tag', "${-> ext.name}", "${-> tagConfig.tagTask()}" + dependsOn exec + }) + tag.dependsOn tagSubTask + + Exec pushSubTask = project.tasks.create('dockerPush' + taskName, Exec, { + group = 'Docker' + description = "Pushes the Docker image with tag '${tagConfig.tagName}' to configured Docker Hub" + workingDir dockerDir + commandLine 'docker', 'push', "${-> tagConfig.tagTask()}" + dependsOn tagSubTask + }) + pushAllTags.dependsOn pushSubTask + } + } + } + + private List buildCommandLine(DockerExtension ext) { + List buildCommandLine = ['docker'] + if (ext.buildx) { + buildCommandLine.addAll(['buildx', 'build']) + if (!ext.platform.isEmpty()) { + 
buildCommandLine.addAll('--platform', String.join(',', ext.platform)) + } + if (ext.load) { + buildCommandLine.add '--load' + } + if (ext.push) { + buildCommandLine.add '--push' + if (ext.load) { + throw new Exception("cannot combine 'push' and 'load' options") + } + } + if (ext.builder != null) { + buildCommandLine.addAll('--builder', ext.builder) + } + } else { + buildCommandLine.add 'build' + } + if (ext.noCache) { + buildCommandLine.add '--no-cache' + } + if (ext.getNetwork() != null) { + buildCommandLine.addAll('--network', ext.network) + } + if (!ext.buildArgs.isEmpty()) { + for (Map.Entry buildArg : ext.buildArgs.entrySet()) { + buildCommandLine.addAll('--build-arg', "${buildArg.getKey()}=${buildArg.getValue()}" as String) + } + } + if (!ext.labels.isEmpty()) { + for (Map.Entry label : ext.labels.entrySet()) { + if (!label.getKey().matches(LABEL_KEY_PATTERN)) { + throw new GradleException(String.format("Docker label '%s' contains illegal characters. " + + "Label keys must only contain lowercase alphanumberic, `.`, or `-` characters (must match %s).", + label.getKey(), LABEL_KEY_PATTERN.pattern())) + } + buildCommandLine.addAll('--label', "${label.getKey()}=${label.getValue()}" as String) + } + } + if (ext.pull) { + buildCommandLine.add '--pull' + } + buildCommandLine.addAll(['-t', "${-> ext.name}", '.']) + logger.debug("${buildCommandLine}" as String) + return buildCommandLine + } + + private static String computeName(String name, String tag) { + int firstAt = tag.indexOf("@") + + String tagValue + if (firstAt > 0) { + tagValue = tag.substring(firstAt + 1, tag.length()) + } else { + tagValue = tag + } + + if (tagValue.contains(':') || tagValue.contains('/')) { + // tag with ':' or '/' -> force use the tag value + return tagValue + } else { + // tag without ':' and '/' -> replace the tag part of original name + int lastColon = name.lastIndexOf(':') + int lastSlash = name.lastIndexOf('/') + + int endIndex; + + // image_name -> this should remain + // 
host:port/image_name -> this should remain. + // host:port/image_name:v1 -> v1 should be replaced + if (lastColon > lastSlash) endIndex = lastColon + else endIndex = name.length() + + return name.substring(0, endIndex) + ":" + tagValue + } + } + + private static String generateTagTaskName(String name) { + String tagTaskName = name + int firstAt = name.indexOf("@") + + if (firstAt > 0) { + // Get substring of task name + tagTaskName = name.substring(0, firstAt) + } else if (firstAt == 0) { + // Task name must not be empty + throw new GradleException("Task name of docker tag '${name}' must not be empty.") + } else if (name.contains(':') || name.contains('/')) { + // Tags which with repo or name must have a task name + throw new GradleException("Docker tag '${name}' must have a task name.") + } + + StringBuffer sb = new StringBuffer(tagTaskName) + // Uppercase the first letter of task name + sb.replace(0, 1, tagTaskName.substring(0, 1).toUpperCase()); + return sb.toString() + } +} diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerRunPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerRunPlugin.groovy new file mode 100644 index 000000000000..5297c7018139 --- /dev/null +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerRunPlugin.groovy @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.gradle + +import org.gradle.api.Plugin +import org.gradle.api.Project +import org.gradle.api.tasks.Exec + +/** + * A gradle plug-in handling 'docker run' command. Originally replicated from + * com.palantir.docker-run plugin. + */ +class BeamDockerRunPlugin implements Plugin { + + /** A class defining the configurations of dockerRun task. */ + static class DockerRunExtension { + String name + String image + Set ports = [] as Set + Map env = [:] + List arguments = [] + Map volumes = [:] + boolean daemonize = true + boolean clean = false + + public String getName() { + return name + } + + public void setName(String name) { + this.name = name + } + } + + @Override + void apply(Project project) { + DockerRunExtension ext = project.extensions.create('dockerRun', DockerRunExtension) + + Exec dockerRunStatus = project.tasks.create('dockerRunStatus', Exec, { + group = 'Docker Run' + description = 'Checks the run status of the container' + }) + + Exec dockerRun = project.tasks.create('dockerRun', Exec, { + group = 'Docker Run' + description = 'Runs the specified container with port mappings' + }) + + Exec dockerStop = project.tasks.create('dockerStop', Exec, { + group = 'Docker Run' + description = 'Stops the named container if it is running' + ignoreExitValue = true + }) + + Exec dockerRemoveContainer = project.tasks.create('dockerRemoveContainer', Exec, { + group = 'Docker Run' + description = 'Removes the persistent container associated with the Docker Run tasks' + ignoreExitValue = true + }) + + project.afterEvaluate { + /** 
Inspect status of docker. */ + dockerRunStatus.with { + standardOutput = new ByteArrayOutputStream() + commandLine 'docker', 'inspect', '--format={{.State.Running}}', ext.name + doLast { + if (standardOutput.toString().trim() != 'true') { + println "Docker container '${ext.name}' is STOPPED." + return 1 + } else { + println "Docker container '${ext.name}' is RUNNING." + } + } + } + + /** + * Run a docker container. See {@link DockerRunExtension} for supported + * arguments. + * + * Replication of dockerRun task of com.palantir.docker-run plugin. + */ + dockerRun.with { + List args = new ArrayList() + args.addAll(['docker', 'run']) + + if (ext.daemonize) { + args.add('-d') + } + if (ext.clean) { + args.add('--rm') + } else { + finalizedBy dockerRunStatus + } + for (String port : ext.ports) { + args.add('-p') + args.add(port) + } + for (Map.Entry volume : ext.volumes.entrySet()) { + File localFile = project.file(volume.key) + + if (!localFile.exists()) { + logger.error("ERROR: Local folder ${localFile} doesn't exist. 
Mounted volume will not be visible to container") + throw new IllegalStateException("Local folder ${localFile} doesn't exist.") + } + args.add('-v') + args.add("${localFile.absolutePath}:${volume.value}") + } + args.addAll(ext.env.collect{ k, v -> ['-e', "${k}=${v}"]}.flatten()) + args.add('--name') + args.add(ext.name) + if (!ext.arguments.isEmpty()) { + args.addAll(ext.arguments) + } + args.add(ext.image) + + commandLine args + } + + dockerStop.with { + commandLine 'docker', 'stop', ext.name + } + + dockerRemoveContainer.with { + commandLine 'docker', 'rm', ext.name + } + } + } +} diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index c31482d577e0..c7a62237086e 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -23,7 +23,6 @@ import static java.util.UUID.randomUUID import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar import groovy.json.JsonOutput import groovy.json.JsonSlurper -import java.net.ServerSocket import java.util.logging.Logger import org.gradle.api.attributes.Category import org.gradle.api.GradleException @@ -1252,7 +1251,7 @@ class BeamModulePlugin implements Plugin { if (configuration.shadowClosure) { // Ensure that tests are packaged and part of the artifact set. project.task('packageTests', type: Jar) { - classifier = 'tests-unshaded' + archiveClassifier = 'tests-unshaded' from project.sourceSets.test.output } project.artifacts.archives project.packageTests @@ -1560,13 +1559,13 @@ class BeamModulePlugin implements Plugin { } } - // Always configure the shadowJar classifier and merge service files. + // Always configure the shadowJar archiveClassifier and merge service files. if (configuration.shadowClosure) { // Only set the classifer on the unshaded classes if we are shading. 
- project.jar { classifier = "unshaded" } + project.jar { archiveClassifier = "unshaded" } project.shadowJar({ - classifier = null + archiveClassifier = null mergeServiceFiles() zip64 true into("META-INF/") { @@ -1575,11 +1574,11 @@ class BeamModulePlugin implements Plugin { } } << configuration.shadowClosure) - // Always configure the shadowTestJar classifier and merge service files. + // Always configure the shadowTestJar archiveClassifier and merge service files. project.task('shadowTestJar', type: ShadowJar, { group = "Shadow" description = "Create a combined JAR of project and test dependencies" - classifier = "tests" + archiveClassifier = "tests" from project.sourceSets.test.output configurations = [ project.configurations.testRuntimeMigration @@ -1639,7 +1638,7 @@ class BeamModulePlugin implements Plugin { project.tasks.register("testJar", Jar) { group = "Jar" description = "Create a JAR of test classes" - classifier = "tests" + archiveClassifier = "tests" from project.sourceSets.test.output zip64 true exclude "META-INF/INDEX.LIST" @@ -1794,18 +1793,18 @@ class BeamModulePlugin implements Plugin { project.task('sourcesJar', type: Jar) { from project.sourceSets.main.allSource - classifier = 'sources' + archiveClassifier = 'sources' } project.artifacts.archives project.sourcesJar project.task('testSourcesJar', type: Jar) { from project.sourceSets.test.allSource - classifier = 'test-sources' + archiveClassifier = 'test-sources' } project.artifacts.archives project.testSourcesJar project.task('javadocJar', type: Jar, dependsOn: project.javadoc) { - classifier = 'javadoc' + archiveClassifier = 'javadoc' from project.javadoc.destinationDir } project.artifacts.archives project.javadocJar @@ -1915,8 +1914,8 @@ class BeamModulePlugin implements Plugin { def dependencyNode = dependenciesNode.appendNode('dependency') def appendClassifier = { dep -> dep.artifacts.each { art -> - if (art.hasProperty('classifier')) { - dependencyNode.appendNode('classifier', 
art.classifier) + if (art.hasProperty('archiveClassifier')) { + dependencyNode.appendNode('archiveClassifier', art.archiveClassifier) } } } @@ -2210,7 +2209,7 @@ class BeamModulePlugin implements Plugin { /** ***********************************************************************************************/ project.ext.applyDockerNature = { - project.apply plugin: "com.palantir.docker" + project.apply plugin: BeamDockerPlugin project.docker { noCache true } project.tasks.create(name: "copyLicenses", type: Copy) { from "${project.rootProject.projectDir}/LICENSE" @@ -2222,7 +2221,7 @@ class BeamModulePlugin implements Plugin { } project.ext.applyDockerRunNature = { - project.apply plugin: "com.palantir.docker-run" + project.apply plugin: BeamDockerRunPlugin } /** ***********************************************************************************************/ diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy index 061ccf27cce2..97d96e6cf1eb 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy @@ -126,7 +126,7 @@ artifactId=${project.name} } config.exclusions.each { exclude it } - classifier = null + archiveClassifier = null mergeServiceFiles() zip64 true exclude "META-INF/INDEX.LIST" diff --git a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb index 67fe51af1253..9cbab0a14178 100644 --- a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb +++ b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb @@ -1,605 +1,530 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, - "cells": [{ - "cell_type": "code", - "source": [ 
- "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", - "\n", - "# Licensed to the Apache Software Foundation (ASF) under one\n", - "# or more contributor license agreements. See the NOTICE file\n", - "# distributed with this work for additional information\n", - "# regarding copyright ownership. The ASF licenses this file\n", - "# to you under the Apache License, Version 2.0 (the\n", - "# \"License\"); you may not use this file except in compliance\n", - "# with the License. You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing,\n", - "# software distributed under the License is distributed on an\n", - "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", - "# KIND, either express or implied. See the License for the\n", - "# specific language governing permissions and limitations\n", - "# under the License" - ], - "metadata": { - "cellView": "form", - "id": "OsFaZscKSPvo" - }, - "execution_count": null, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "# Update ML models in running pipelines\n", - "\n", - "\n", - " \n", - " \n", - "
\n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - "
\n" - ], - "metadata": { - "id": "ZUSiAR62SgO8" - } - }, - { - "cell_type": "markdown", - "source": [ - "This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline.\n", - "You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a `ModelHandler` configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the `WatchFilePattern`, or by configuring a custom side input `PCollection` that defines the logic for the model update.\n", - "\n", - "The pipeline in this notebook uses a RunInference `PTransform` with TensorFlow machine learning (ML) models to run inference on images. To update the model, it uses a side input `PCollection` that emits `ModelMetadata`.\n", - "For more information about side inputs, see the [Side inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs) section in the Apache Beam Programming Guide.\n", - "\n", - "This example uses `WatchFilePattern` as a side input. `WatchFilePattern` is used to watch for file updates that match the `file_pattern` based on timestamps. It emits the latest `ModelMetadata`, which is used in the RunInference `PTransform` to automatically update the ML model without stopping the Apache Beam pipeline.\n" - ], - "metadata": { - "id": "tBtqF5UpKJNZ" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Before you begin\n", - "Install the dependencies required to run this notebook.\n", - "\n", - "To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later." 
- ], - "metadata": { - "id": "SPuXFowiTpWx" - } - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "1RyTYsFEIOlA", - "outputId": "0e6b88a7-82d8-4d94-951c-046a9b8b7abb", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }], - "source": [ - "!pip install apache_beam[gcp]>=2.46.0 --quiet\n", - "!pip install tensorflow\n", - "!pip install tensorflow_hub" - ] - }, - { - "cell_type": "code", - "source": [ - "# Imports required for the notebook.\n", - "import logging\n", - "import time\n", - "from typing import Iterable\n", - "from typing import Tuple\n", - "\n", - "import apache_beam as beam\n", - "from apache_beam.examples.inference.tensorflow_imagenet_segmentation import PostProcessor\n", - "from apache_beam.examples.inference.tensorflow_imagenet_segmentation import read_image\n", - "from apache_beam.ml.inference.base import PredictionResult\n", - "from apache_beam.ml.inference.base import RunInference\n", - "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n", - "from apache_beam.ml.inference.utils import WatchFilePattern\n", - "from apache_beam.options.pipeline_options import GoogleCloudOptions\n", - "from apache_beam.options.pipeline_options import PipelineOptions\n", - "from apache_beam.options.pipeline_options import SetupOptions\n", - "from apache_beam.options.pipeline_options import StandardOptions\n", - "from apache_beam.transforms.periodicsequence import PeriodicImpulse\n", - "import numpy\n", - "from PIL import Image\n", - "import tensorflow as tf" - ], - "metadata": { - "id": "Rs4cwwNrIV9H" - }, - "execution_count": 2, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "code", - "source": [ - "# Authenticate to your Google Cloud account.\n", - "from google.colab import auth\n", - "auth.authenticate_user()" - ], - "metadata": { - "id": 
"jAKpPcmmGm03" - }, - "execution_count": 3, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "## Configure the runner\n", - "\n", - "This pipeline uses the Dataflow Runner. To run the pipeline, you need to complete the following tasks:\n", - "\n", - "* Ensure that you have all the required permissions to run the pipeline on Dataflow.\n", - "* Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.\n", - "\n", - "In the following code, replace `BUCKET_NAME` with the the name of your Cloud Storage bucket." - ], - "metadata": { - "id": "ORYNKhH3WQyP" - } - }, - { - "cell_type": "code", - "source": [ - "options = PipelineOptions()\n", - "options.view_as(StandardOptions).streaming = True\n", - "\n", - "# Provide required pipeline options for the Dataflow Runner.\n", - "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n", - "\n", - "# Set the project to the default project in your current Google Cloud environment.\n", - "options.view_as(GoogleCloudOptions).project = 'your-project'\n", - "\n", - "# Set the Google Cloud region that you want to run Dataflow in.\n", - "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", - "\n", - "# IMPORTANT: Replace BUCKET_NAME with the the name of your Cloud Storage bucket.\n", - "dataflow_gcs_location = \"gs://BUCKET_NAME/tmp/\"\n", - "\n", - "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", - "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", - "\n", - "# The Dataflow temp location. 
This location is used to store temporary files or intermediate results before outputting to the sink.\n", - "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n", - "\n" - ], - "metadata": { - "id": "wWjbnq6X-4uE" - }, - "execution_count": 4, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies." - ], - "metadata": { - "id": "HTJV8pO2Wcw4" - } - }, - { - "cell_type": "code", - "source": [ - "# In a requirements file, define the dependencies required for the pipeline.\n", - "deps_required_for_pipeline = ['tensorflow>=2.12.0', 'tensorflow-hub>=0.10.0', 'Pillow>=9.0.0']\n", - "requirements_file_path = './requirements.txt'\n", - "# Write the dependencies to the requirements file.\n", - "with open(requirements_file_path, 'w') as f:\n", - " for dep in deps_required_for_pipeline:\n", - " f.write(dep + '\\n')\n", - "\n", - "# Install the pipeline dependencies on Dataflow.\n", - "options.view_as(SetupOptions).requirements_file = requirements_file_path" - ], - "metadata": { - "id": "lEy4PkluWbdm" - }, - "execution_count": 5, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "## Use the TensorFlow model handler\n", - " This example uses `TFModelHandlerTensor` as the model handler and the `resnet_101` model trained on [ImageNet](https://www.image-net.org/).\n", - "\n", - " Download the model from [Google Cloud Storage](https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels.h5) (link downloads the model), and place it in the directory that you want to use to update your model.\n", - "\n", - "In the following code, replace `BUCKET_NAME` with the the name of your 
Cloud Storage bucket." - ], - "metadata": { - "id": "_AUNH_GJk_NE" - } - }, - { - "cell_type": "code", - "source": [ - "model_handler = TFModelHandlerTensor(\n", - " model_uri=\"gs://BUCKET_NAME/resnet101_weights_tf_dim_ordering_tf_kernels.h5\")" - ], - "metadata": { - "id": "kkSnsxwUk-Sp" - }, - "execution_count": 6, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "## Preprocess images\n", - "\n", - "Use `preprocess_image` to run the inference, read the image, and convert the image to a TensorFlow tensor." - ], - "metadata": { - "id": "tZH0r0sL-if5" - } - }, - { - "cell_type": "code", - "source": [ - "def preprocess_image(image_name, image_dir):\n", - " img = tf.keras.utils.get_file(image_name, image_dir + image_name)\n", - " img = Image.open(img).resize((224, 224))\n", - " img = numpy.array(img) / 255.0\n", - " img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)\n", - " return img_tensor" - ], - "metadata": { - "id": "dU5imgTt-8Ne" - }, - "execution_count": 7, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "code", - "source": [ - "class PostProcessor(beam.DoFn):\n", - " \"\"\"Process the PredictionResult to get the predicted label.\n", - " Returns predicted label.\n", - " \"\"\"\n", - " def process(self, element: PredictionResult) -> Iterable[Tuple[str, str]]:\n", - " predicted_class = numpy.argmax(element.inference, axis=-1)\n", - " labels_path = tf.keras.utils.get_file(\n", - " 'ImageNetLabels.txt',\n", - " 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt' # pylint: disable=line-too-long\n", - " )\n", - " imagenet_labels = numpy.array(open(labels_path).read().splitlines())\n", - " predicted_class_name = imagenet_labels[predicted_class]\n", - " yield predicted_class_name.title(), element.model_id" - ], - "metadata": { - "id": "6V5tJxO6-gyt" - }, - 
"execution_count": 8, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "code", - "source": [ - "# Define the pipeline object.\n", - "pipeline = beam.Pipeline(options=options)" - ], - "metadata": { - "id": "GpdKk72O_NXT", - "outputId": "bcbaa8a6-0408-427a-de9e-78a6a7eefd7b", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 400 - } - }, - "execution_count": 9, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "Next, review the pipeline steps and examine the code.\n", - "\n", - "### Pipeline steps\n" - ], - "metadata": { - "id": "elZ53uxc_9Hv" - } - }, - { - "cell_type": "markdown", - "source": [ - "1. Create a `PeriodicImpulse` transform, which emits output every `n` seconds. The `PeriodicImpulse` transform generates an infinite sequence of elements with a given runtime interval.\n", - "\n", - " In this example, `PeriodicImpulse` mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use `PeriodicImpulse` to output elements at `m` intervals.\n", - "To learn more about `PeriodicImpulse`, see the [`PeriodicImpulse` code](https://github.com/apache/beam/blob/9c52e0594d6f0e59cd17ee005acfb41da508e0d5/sdks/python/apache_beam/transforms/periodicsequence.py#L150)." 
- ], - "metadata": { - "id": "305tkV2sAD-S" - } - }, - { - "cell_type": "code", - "source": [ - "start_timestamp = time.time() # start timestamp of the periodic impulse\n", - "end_timestamp = start_timestamp + 60 * 20 # end timestamp of the periodic impulse (will run for 20 minutes).\n", - "main_input_fire_interval = 60 # interval in seconds at which the main input PCollection is emitted.\n", - "side_input_fire_interval = 60 # interval in seconds at which the side input PCollection is emitted.\n", - "\n", - "periodic_impulse = (\n", - " pipeline\n", - " | \"MainInputPcoll\" >> PeriodicImpulse(\n", - " start_timestamp=start_timestamp,\n", - " stop_timestamp=end_timestamp,\n", - " fire_interval=main_input_fire_interval))" - ], - "metadata": { - "id": "vUFStz66_Tbb", - "outputId": "39f2704b-021e-4d41-fce3-a2fac90a5bad", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 133 - } - }, - "execution_count": 10, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "2. To read and preprocess the images, use the `read_image` function. This example uses `Cat-with-beanie.jpg` for all inferences.\n", - "\n", - " **Note**: Image used for prediction is licensed in CC-BY. The creator is listed in the [LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt) file." 
- ], - "metadata": { - "id": "8-sal2rFAxP2" - } - }, - { - "cell_type": "markdown", - "source": [ - "![download.png]()" - ], - "metadata": { - "id": "gW4cE8bhXS-d" - } - }, - { - "cell_type": "code", - "source": [ - "image_data = (periodic_impulse | beam.Map(lambda x: \"Cat-with-beanie.jpg\")\n", - " | \"ReadImage\" >> beam.Map(lambda image_name: read_image(\n", - " image_name=image_name, image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))" - ], - "metadata": { - "id": "dGg11TpV_aV6", - "outputId": "a57e8197-6756-4fd8-a664-f51ef2fea730", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 204 - } - }, - "execution_count": 11, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "3. Pass the images to the RunInference `PTransform`. RunInference takes `model_handler` and `model_metadata_pcoll` as input parameters.\n", - " * `model_metadata_pcoll` is a side input `PCollection` to the RunInference `PTransform`. This side input is used to update the `model_uri` in the `model_handler` without needing to stop the Apache Beam pipeline\n", - " * Use `WatchFilePattern` as side input to watch a `file_pattern` matching `.h5` files. 
In this case, the `file_pattern` is `'gs://BUCKET_NAME/*.h5'`.\n", - "\n" - ], - "metadata": { - "id": "eB0-ewd-BCKE" - } - }, - { - "cell_type": "code", - "source": [ - " # The side input used to watch for the .h5 file and update the model_uri of the TFModelHandlerTensor.\n", - "file_pattern = 'gs://BUCKET_NAME/*.h5'\n", - "side_input_pcoll = (\n", - " pipeline\n", - " | \"WatchFilePattern\" >> WatchFilePattern(file_pattern=file_pattern,\n", - " interval=side_input_fire_interval,\n", - " stop_timestamp=end_timestamp))\n", - "inferences = (\n", - " image_data\n", - " | \"ApplyWindowing\" >> beam.WindowInto(beam.window.FixedWindows(10))\n", - " | \"RunInference\" >> RunInference(model_handler=model_handler,\n", - " model_metadata_pcoll=side_input_pcoll))" - ], - "metadata": { - "id": "_AjvvexJ_hUq", - "outputId": "291fcc38-0abb-4b11-f840-4a850097a56f", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 133 - } - }, - "execution_count": 12, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "4. Post-process the `PredictionResult` object.\n", - "When the inference is complete, RunInference outputs a `PredictionResult` object that contains the fields `example`, `inference`, and `model_id`. The `model_id` field identifies the model used to run the inference. The `PostProcessor` returns the predicted label and the model ID used to run the inference on the predicted label." 
- ], - "metadata": { - "id": "lTA4wRWNDVis" - } - }, - { - "cell_type": "code", - "source": [ - "post_processor = (\n", - " inferences\n", - " | \"PostProcessResults\" >> beam.ParDo(PostProcessor())\n", - " | \"LogResults\" >> beam.Map(logging.info))" - ], - "metadata": { - "id": "9TB76fo-_vZJ", - "outputId": "3e12d482-1bdf-4136-fbf7-9d5bb4bb62c3", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 222 - } - }, - "execution_count": 13, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - }, - { - "cell_type": "markdown", - "source": [ - "### Watch for the model update\n", - "\n", - "After the pipeline starts processing data and when you see output emitted from the RunInference `PTransform`, upload a `resnet152` model saved in `.h5` format to a Google Cloud Storage bucket location that matches the `file_pattern` you defined earlier. You can [download a copy of the model](https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels.h5) (link downloads the model). RunInference uses `WatchFilePattern` as a side input to update the `model_uri` of `TFModelHandlerTensor`." - ], - "metadata": { - "id": "wYp-mBHHjOjA" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Run the pipeline\n", - "\n", - "Use the following code to run the pipeline." 
- ], - "metadata": { - "id": "_ty03jDnKdKR" - } - }, - { - "cell_type": "code", - "source": [ - "# Run the pipeline.\n", - "result = pipeline.run().wait_until_finish()" - ], - "metadata": { - "id": "wd0VJLeLEWBU", - "outputId": "3489c891-05d2-4739-d693-1899cfe78859", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 186 - } - }, - "execution_count": 14, - "outputs": [{ - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - }] - } - ] -} + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. 
See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License" + ], + "metadata": { + "cellView": "form", + "id": "OsFaZscKSPvo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Update ML models in running pipelines\n", + "\n", + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
\n" + ], + "metadata": { + "id": "ZUSiAR62SgO8" + } + }, + { + "cell_type": "markdown", + "source": [ + "This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline.\n", + "You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a `ModelHandler` configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the `WatchFilePattern`, or by configuring a custom side input `PCollection` that defines the logic for the model update.\n", + "\n", + "The pipeline in this notebook uses a RunInference `PTransform` with TensorFlow machine learning (ML) models to run inference on images. To update the model, it uses a side input `PCollection` that emits `ModelMetadata`.\n", + "For more information about side inputs, see the [Side inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs) section in the Apache Beam Programming Guide.\n", + "\n", + "This example uses `WatchFilePattern` as a side input. `WatchFilePattern` is used to watch for file updates that match the `file_pattern` based on timestamps. It emits the latest `ModelMetadata`, which is used in the RunInference `PTransform` to automatically update the ML model without stopping the Apache Beam pipeline.\n" + ], + "metadata": { + "id": "tBtqF5UpKJNZ" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Before you begin\n", + "Install the dependencies required to run this notebook.\n", + "\n", + "To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later." 
+ ], + "metadata": { + "id": "SPuXFowiTpWx" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1RyTYsFEIOlA" + }, + "outputs": [], + "source": [ + "!pip install apache_beam[gcp]>=2.46.0 --quiet\n", + "!pip install tensorflow --quiet\n", + "!pip install tensorflow_hub --quiet" + ] + }, + { + "cell_type": "code", + "source": [ + "# Imports required for the notebook.\n", + "import logging\n", + "import time\n", + "from typing import Iterable\n", + "from typing import Tuple\n", + "\n", + "import apache_beam as beam\n", + "from apache_beam.ml.inference.base import PredictionResult\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n", + "from apache_beam.ml.inference.utils import WatchFilePattern\n", + "from apache_beam.options.pipeline_options import GoogleCloudOptions\n", + "from apache_beam.options.pipeline_options import PipelineOptions\n", + "from apache_beam.options.pipeline_options import SetupOptions\n", + "from apache_beam.options.pipeline_options import StandardOptions\n", + "from apache_beam.options.pipeline_options import WorkerOptions\n", + "from apache_beam.transforms.periodicsequence import PeriodicImpulse\n", + "import numpy\n", + "from PIL import Image\n", + "import tensorflow as tf" + ], + "metadata": { + "id": "Rs4cwwNrIV9H" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Authenticate to your Google Cloud account.\n", + "def auth_to_colab():\n", + " from google.colab import auth\n", + " auth.authenticate_user()\n", + "\n", + "auth_to_colab()" + ], + "metadata": { + "id": "jAKpPcmmGm03" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Configure the runner\n", + "\n", + "This pipeline uses the Dataflow Runner. 
To run the pipeline, you need to complete the following tasks:\n", + "\n", + "* Ensure that you have all the required permissions to run the pipeline on Dataflow.\n", + "* Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.\n", + "\n", + "In the following code, replace `BUCKET_NAME` with the name of your Cloud Storage bucket." + ], + "metadata": { + "id": "ORYNKhH3WQyP" + } + }, + { + "cell_type": "code", + "source": [ + "options = PipelineOptions()\n", + "options.view_as(StandardOptions).streaming = True\n", + "\n", + "BUCKET_NAME = '' # Replace with your bucket name.\n", + "\n", + "# Provide required pipeline options for the Dataflow Runner.\n", + "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n", + "\n", + "# Set the project to the default project in your current Google Cloud environment.\n", + "options.view_as(GoogleCloudOptions).project = ''\n", + "\n", + "# Set the Google Cloud region that you want to run Dataflow in.\n", + "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", + "\n", + "# IMPORTANT: Replace BUCKET_NAME with the name of your Cloud Storage bucket.\n", + "dataflow_gcs_location = \"gs://%s/dataflow\" % BUCKET_NAME\n", + "\n", + "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", + "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", + "\n", + "\n", + "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", + "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", + "\n", + "# The Dataflow temp location.
This location is used to store temporary files or intermediate results before outputting to the sink.\n", + "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n", + "\n", + "options.view_as(SetupOptions).save_main_session = True\n", + "\n", + "# Launching Dataflow with only one worker might result in processing delays due to\n", + "# initial input processing. This could further postpone the side input model updates.\n", + "# To expedite the model update process, it's recommended to set num_workers>1.\n", + "# https://github.com/apache/beam/issues/28776\n", + "options.view_as(WorkerOptions).num_workers = 5" + ], + "metadata": { + "id": "wWjbnq6X-4uE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies." + ], + "metadata": { + "id": "HTJV8pO2Wcw4" + } + }, + { + "cell_type": "code", + "source": [ + "# In a requirements file, define the dependencies required for the pipeline.\n", + "!printf 'tensorflow>=2.12.0\\ntensorflow_hub>=0.10.0\\nPillow>=9.0.0' > ./requirements.txt\n", + "# Install the pipeline dependencies on Dataflow.\n", + "options.view_as(SetupOptions).requirements_file = './requirements.txt'" + ], + "metadata": { + "id": "lEy4PkluWbdm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Use the TensorFlow model handler\n", + " This example uses `TFModelHandlerTensor` as the model handler and the `resnet_101` model trained on [ImageNet](https://www.image-net.org/).\n", + "\n", + "\n", + "For DataflowRunner, the model needs to be stored in a remote location accessible by the Beam pipeline.
So we will download `ResNet101` model and upload it to the GCS location.\n" + ], + "metadata": { + "id": "_AUNH_GJk_NE" + } + }, + { + "cell_type": "code", + "source": [ + "model = tf.keras.applications.resnet.ResNet101()\n", + "model.save('resnet101_weights_tf_dim_ordering_tf_kernels.keras')\n", + "# After saving the model locally, upload the model to GCS bucket and provide that gcs bucket `URI` as `model_uri` to the `TFModelHandler`\n", + "# Replace `BUCKET_NAME` value with actual bucket name.\n", + "!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs:///dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras" + ], + "metadata": { + "id": "ibkWiwVNvyrn" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model_handler = TFModelHandlerTensor(\n", + " model_uri=dataflow_gcs_location + \"/resnet101_weights_tf_dim_ordering_tf_kernels.keras\")" + ], + "metadata": { + "id": "kkSnsxwUk-Sp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Preprocess images\n", + "\n", + "Use `preprocess_image` to run the inference, read the image, and convert the image to a TensorFlow tensor." 
+ ], + "metadata": { + "id": "tZH0r0sL-if5" + } + }, + { + "cell_type": "code", + "source": [ + "def preprocess_image(image_name, image_dir):\n", + " img = tf.keras.utils.get_file(image_name, image_dir + image_name)\n", + " img = Image.open(img).resize((224, 224))\n", + " img = numpy.array(img) / 255.0\n", + " img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)\n", + " return img_tensor" + ], + "metadata": { + "id": "dU5imgTt-8Ne" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class PostProcessor(beam.DoFn):\n", + " \"\"\"Process the PredictionResult to get the predicted label.\n", + " Returns predicted label.\n", + " \"\"\"\n", + " def process(self, element: PredictionResult) -> Iterable[Tuple[str, str]]:\n", + " predicted_class = numpy.argmax(element.inference, axis=-1)\n", + " labels_path = tf.keras.utils.get_file(\n", + " 'ImageNetLabels.txt',\n", + " 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt' # pylint: disable=line-too-long\n", + " )\n", + " imagenet_labels = numpy.array(open(labels_path).read().splitlines())\n", + " predicted_class_name = imagenet_labels[predicted_class]\n", + " yield predicted_class_name.title(), element.model_id" + ], + "metadata": { + "id": "6V5tJxO6-gyt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Define the pipeline object.\n", + "pipeline = beam.Pipeline(options=options)" + ], + "metadata": { + "id": "GpdKk72O_NXT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, review the pipeline steps and examine the code.\n", + "\n", + "### Pipeline steps\n" + ], + "metadata": { + "id": "elZ53uxc_9Hv" + } + }, + { + "cell_type": "markdown", + "source": [ + "1. Create a `PeriodicImpulse` transform, which emits output every `n` seconds. 
The `PeriodicImpulse` transform generates an infinite sequence of elements with a given runtime interval.\n", + "\n", + " In this example, `PeriodicImpulse` mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use `PeriodicImpulse` to output elements at `m` intervals.\n", + "To learn more about `PeriodicImpulse`, see the [`PeriodicImpulse` code](https://github.com/apache/beam/blob/9c52e0594d6f0e59cd17ee005acfb41da508e0d5/sdks/python/apache_beam/transforms/periodicsequence.py#L150)." + ], + "metadata": { + "id": "305tkV2sAD-S" + } + }, + { + "cell_type": "code", + "source": [ + "start_timestamp = time.time() # start timestamp of the periodic impulse\n", + "end_timestamp = start_timestamp + 60 * 20 # end timestamp of the periodic impulse (will run for 20 minutes).\n", + "main_input_fire_interval = 60 # interval in seconds at which the main input PCollection is emitted.\n", + "side_input_fire_interval = 60 # interval in seconds at which the side input PCollection is emitted.\n", + "\n", + "periodic_impulse = (\n", + " pipeline\n", + " | \"MainInputPcoll\" >> PeriodicImpulse(\n", + " start_timestamp=start_timestamp,\n", + " stop_timestamp=end_timestamp,\n", + " fire_interval=main_input_fire_interval))" + ], + "metadata": { + "id": "vUFStz66_Tbb" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "2. To read and preprocess the images, use the `preprocess_image` function. This example uses `Cat-with-beanie.jpg` for all inferences.\n", + "\n", + " **Note**: Image used for prediction is licensed in CC-BY. The creator is listed in the [LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt) file." 
+ ], + "metadata": { + "id": "8-sal2rFAxP2" + } + }, + { + "cell_type": "markdown", + "source": [ + "![download.png]()" + ], + "metadata": { + "id": "gW4cE8bhXS-d" + } + }, + { + "cell_type": "code", + "source": [ + "image_data = (periodic_impulse | beam.Map(lambda x: \"Cat-with-beanie.jpg\")\n", + " | \"ReadImage\" >> beam.Map(lambda image_name: preprocess_image(\n", + " image_name=image_name, image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))" + ], + "metadata": { + "id": "dGg11TpV_aV6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "3. Pass the images to the RunInference `PTransform`. RunInference takes `model_handler` and `model_metadata_pcoll` as input parameters.\n", + " * `model_metadata_pcoll` is a side input `PCollection` to the RunInference `PTransform`. This side input is used to update the `model_uri` in the `model_handler` without needing to stop the Apache Beam pipeline\n", + " * Use `WatchFilePattern` as side input to watch a `file_pattern` matching `.keras` files. 
In this case, the `file_pattern` is `'gs://BUCKET_NAME/dataflow/*.keras'`.\n", + "\n" + ], + "metadata": { + "id": "eB0-ewd-BCKE" + } + }, + { + "cell_type": "code", + "source": [ + " # The side input used to watch for the .keras file and update the model_uri of the TFModelHandlerTensor.\n", + "file_pattern = dataflow_gcs_location + '/*.keras'\n", + "side_input_pcoll = (\n", + " pipeline\n", + " | \"WatchFilePattern\" >> WatchFilePattern(file_pattern=file_pattern,\n", + " interval=side_input_fire_interval,\n", + " stop_timestamp=end_timestamp))\n", + "inferences = (\n", + " image_data\n", + " | \"ApplyWindowing\" >> beam.WindowInto(beam.window.FixedWindows(10))\n", + " | \"RunInference\" >> RunInference(model_handler=model_handler,\n", + " model_metadata_pcoll=side_input_pcoll))" + ], + "metadata": { + "id": "_AjvvexJ_hUq" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "4. Post-process the `PredictionResult` object.\n", + "When the inference is complete, RunInference outputs a `PredictionResult` object that contains the fields `example`, `inference`, and `model_id`. The `model_id` field identifies the model used to run the inference. The `PostProcessor` returns the predicted label and the model ID used to run the inference on the predicted label."
+ ], + "metadata": { + "id": "lTA4wRWNDVis" + } + }, + { + "cell_type": "code", + "source": [ + "post_processor = (\n", + " inferences\n", + " | \"PostProcessResults\" >> beam.ParDo(PostProcessor())\n", + " | \"LogResults\" >> beam.Map(logging.info))" + ], + "metadata": { + "id": "9TB76fo-_vZJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Watch for the model update\n", + "\n", + "After the pipeline starts processing data and when you see output emitted from the RunInference `PTransform`, upload a `resnet152` model saved in `.keras` format to a Google Cloud Storage bucket location that matches the `file_pattern` you defined earlier.\n" + ], + "metadata": { + "id": "wYp-mBHHjOjA" + } + }, + { + "cell_type": "code", + "source": [ + "model = tf.keras.applications.resnet.ResNet152()\n", + "model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')\n", + "# Replace the `BUCKET_NAME` with the actual bucket name.\n", + "!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs://BUCKET_NAME/resnet152_weights_tf_dim_ordering_tf_kernels.keras" + ], + "metadata": { + "id": "FpUfNBSWH9Xy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run the pipeline\n", + "\n", + "Use the following code to run the pipeline." 
+ ], + "metadata": { + "id": "_ty03jDnKdKR" + } + }, + { + "cell_type": "code", + "source": [ + "# Run the pipeline.\n", + "result = pipeline.run().wait_until_finish()" + ], + "metadata": { + "id": "wd0VJLeLEWBU" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/examples/notebooks/healthcare/beam_nlp.ipynb b/examples/notebooks/healthcare/beam_nlp.ipynb index 5106aaa607d9..4ba4a5e0a739 100644 --- a/examples/notebooks/healthcare/beam_nlp.ipynb +++ b/examples/notebooks/healthcare/beam_nlp.ipynb @@ -146,7 +146,7 @@ { "cell_type": "markdown", "source": [ - "Then, download [this raw CSV file](https://https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading." + "Then, download [this raw CSV file](https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading." 
], "metadata": { "id": "1IArtEm8QuCR" diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index afba109285af..7f93135c49b7 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 4e86b9270786..ac72c34e8acc 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.3-bin.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 65dcd68d65c8..0adc8e1a5321 100755 --- a/gradlew +++ b/gradlew @@ -83,10 +83,8 @@ done # This is normally unused # shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} -APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum @@ -133,10 +131,13 @@ location of your Java installation." fi else JAVACMD=java - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. 
@@ -144,7 +145,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then case $MAX_FD in #( max*) # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC3045 MAX_FD=$( ulimit -H -n ) || warn "Could not query maximum file descriptor limit" esac @@ -152,7 +153,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then '' | soft) :;; #( *) # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. - # shellcheck disable=SC3045 + # shellcheck disable=SC3045 ulimit -n "$MAX_FD" || warn "Could not set maximum file descriptor limit to $MAX_FD" esac @@ -197,6 +198,10 @@ if "$cygwin" || "$msys" ; then done fi + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + # Collect all arguments for the java command; # * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of # shell script including quotes and variable substitutions, so put them in diff --git a/playground/kafka-emulator/build.gradle b/playground/kafka-emulator/build.gradle index 486a232f9b99..2d3f70aa9883 100644 --- a/playground/kafka-emulator/build.gradle +++ b/playground/kafka-emulator/build.gradle @@ -24,11 +24,11 @@ plugins { applyJavaNature(exportJavadoc: false, publish: false) distZip { - archiveName "${baseName}.zip" + archiveFileName = "${archiveBaseName}.zip" } distTar { - archiveName "${baseName}.tar" + archiveFileName = "${archiveBaseName}.tar" } dependencies { diff --git a/sdks/go.mod b/sdks/go.mod index e17427227eba..d817ae549857 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -28,7 +28,7 @@ require ( cloud.google.com/go/datastore v1.14.0 cloud.google.com/go/profiler v0.3.1 cloud.google.com/go/pubsub v1.33.0 - cloud.google.com/go/spanner v1.49.0 + cloud.google.com/go/spanner v1.50.0 cloud.google.com/go/storage v1.33.0 github.com/aws/aws-sdk-go-v2 v1.21.0 
github.com/aws/aws-sdk-go-v2/config v1.18.43 @@ -67,8 +67,8 @@ require ( ) require ( - github.com/fsouza/fake-gcs-server v1.47.4 - golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 + github.com/fsouza/fake-gcs-server v1.47.5 + golang.org/x/exp v0.0.0-20230807204917-050eac23e9de ) require ( @@ -88,7 +88,7 @@ require ( cloud.google.com/go v0.110.7 // indirect cloud.google.com/go/compute v1.23.0 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect - cloud.google.com/go/iam v1.1.1 // indirect + cloud.google.com/go/iam v1.1.2 // indirect cloud.google.com/go/longrunning v0.5.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 71c1c4545c89..9f43e9a53abc 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -26,8 +26,8 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7 cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/datastore v1.14.0 h1:Mq0ApTRdLW3/dyiw+DkjTk0+iGIUvkbzaC8sfPwWTH4= cloud.google.com/go/datastore v1.14.0/go.mod h1:GAeStMBIt9bPS7jMJA85kgkpsMkvseWWXiaHya9Jes8= -cloud.google.com/go/iam v1.1.1 h1:lW7fzj15aVIXYHREOqjRBV9PsH0Z6u8Y46a1YGvQP4Y= -cloud.google.com/go/iam v1.1.1/go.mod h1:A5avdyVL2tCppe4unb0951eI9jreack+RJ0/d+KUZOU= +cloud.google.com/go/iam v1.1.2 h1:gacbrBdWcoVmGLozRuStX45YKvJtzIjJdAolzUs1sm4= +cloud.google.com/go/iam v1.1.2/go.mod h1:A5avdyVL2tCppe4unb0951eI9jreack+RJ0/d+KUZOU= cloud.google.com/go/kms v1.15.0 h1:xYl5WEaSekKYN5gGRyhjvZKM22GVBBCzegGNVPy+aIs= cloud.google.com/go/longrunning v0.5.1 h1:Fr7TXftcqTudoyRJa113hyaqlGdiBQkp0Gq7tErFDWI= cloud.google.com/go/longrunning v0.5.1/go.mod h1:spvimkwdz6SPWKEt/XBij79E9fiTkHSQl/fRUUQJYJc= @@ -38,8 +38,8 @@ cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+ cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= 
cloud.google.com/go/pubsub v1.33.0 h1:6SPCPvWav64tj0sVX/+npCBKhUi/UjJehy9op/V3p2g= cloud.google.com/go/pubsub v1.33.0/go.mod h1:f+w71I33OMyxf9VpMVcZbnG5KSUkCOUHYpFd5U1GdRc= -cloud.google.com/go/spanner v1.49.0 h1:+HY8C4uztU7XyLz3xMi/LCXdetLEOExhvRFJu2NiVXM= -cloud.google.com/go/spanner v1.49.0/go.mod h1:eGj9mQGK8+hkgSVbHNQ06pQ4oS+cyc4tXXd6Dif1KoM= +cloud.google.com/go/spanner v1.50.0 h1:QrJFOpaxCXdXF+GkiruLz642PHxkdj68PbbnLw3O2Zw= +cloud.google.com/go/spanner v1.50.0/go.mod h1:eGj9mQGK8+hkgSVbHNQ06pQ4oS+cyc4tXXd6Dif1KoM= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= @@ -195,8 +195,8 @@ github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoD github.com/frankban/quicktest v1.2.2/go.mod h1:Qh/WofXFeiAFII1aEBu529AtJo6Zg2VHscnEsbBnJ20= github.com/frankban/quicktest v1.11.3 h1:8sXhOn0uLys67V8EsXLc6eszDs8VXWxL3iRvebPhedY= github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= -github.com/fsouza/fake-gcs-server v1.47.4 h1:gfBhBxEra20/Om02cvcyL8EnekV8KDb01Yffjat6AKQ= -github.com/fsouza/fake-gcs-server v1.47.4/go.mod h1:vqUZbI12uy9IkRQ54Q4p5AniQsSiUq8alO9Nv2egMmA= +github.com/fsouza/fake-gcs-server v1.47.5 h1:o+wL01s01j/2OdkIaduDogXw2bZveq9TFb8f+BqEHtM= +github.com/fsouza/fake-gcs-server v1.47.5/go.mod h1:PhN8F1rHAOCL5jWyXcw8nPfLfHnka6D9fT7ctL9nbkA= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -348,7 +348,7 @@ github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod 
h1:mC1jAcs github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= -github.com/minio/minio-go/v7 v7.0.61 h1:87c+x8J3jxQ5VUGimV9oHdpjsAvy3fhneEBKuoKEVUI= +github.com/minio/minio-go/v7 v7.0.63 h1:GbZ2oCvaUdgT5640WJOpyDhhDxvknAJU2/T3yurwcbQ= github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM= github.com/moby/patternmatcher v0.5.0 h1:YCZgJOeULcxLw1Q+sVR636pmS7sPEn1Qo2iAN6M7DBo= github.com/moby/patternmatcher v0.5.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc= @@ -497,8 +497,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= -golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/exp v0.0.0-20230807204917-050eac23e9de h1:l5Za6utMv/HsBWWqzt4S8X17j+kt1uVETUX5UFhn2rE= +golang.org/x/exp v0.0.0-20230807204917-050eac23e9de/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/sdks/java/container/common.gradle b/sdks/java/container/common.gradle 
index bf4c122ca91f..cc427494ed6e 100644 --- a/sdks/java/container/common.gradle +++ b/sdks/java/container/common.gradle @@ -63,6 +63,8 @@ task copyDockerfileDependencies(type: Copy) { task copySdkHarnessLauncher(type: Copy) { dependsOn ":sdks:java:container:downloadCloudProfilerAgent" + // if licenses are required, they should be present before this task runs. + mustRunAfter ":sdks:java:container:pullLicenses" from configurations.sdkHarnessLauncher into "build/target" diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java index b853ab792e08..fd49b759fd6d 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java @@ -30,40 +30,60 @@ import org.apache.http.params.BasicHttpParams; import org.apache.http.params.HttpConnectionParams; import org.apache.http.params.HttpParams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** */ public class GceMetadataUtil { private static final String BASE_METADATA_URL = "http://metadata/computeMetadata/v1/"; + private static final Logger LOG = LoggerFactory.getLogger(GceMetadataUtil.class); + static String fetchMetadata(String key) { + String requestUrl = BASE_METADATA_URL + key; int timeoutMillis = 5000; final HttpParams httpParams = new BasicHttpParams(); HttpConnectionParams.setConnectionTimeout(httpParams, timeoutMillis); - HttpClient client = new DefaultHttpClient(httpParams); - HttpGet request = new HttpGet(BASE_METADATA_URL + key); - request.setHeader("Metadata-Flavor", "Google"); - + String ret = ""; try { + HttpClient client = new DefaultHttpClient(httpParams); + + HttpGet request = new HttpGet(requestUrl); + 
request.setHeader("Metadata-Flavor", "Google"); + HttpResponse response = client.execute(request); - if (response.getStatusLine().getStatusCode() != 200) { - // May mean its running on a non DataflowRunner, in which case it's perfectly normal. - return ""; + if (response.getStatusLine().getStatusCode() == 200) { + InputStream in = response.getEntity().getContent(); + try (final Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) { + ret = CharStreams.toString(reader); + } } - InputStream in = response.getEntity().getContent(); - try (final Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) { - return CharStreams.toString(reader); - } - } catch (IOException e) { - // May mean its running on a non DataflowRunner, in which case it's perfectly normal. + } catch (IOException ignored) { } - return ""; + + // The return value can be an empty string, which may mean it's running on a non DataflowRunner. + LOG.debug("Fetched GCE Metadata at '{}' and got '{}'", requestUrl, ret); + + return ret; + } + + private static String fetchVmInstanceMetadata(String instanceMetadataKey) { + return GceMetadataUtil.fetchMetadata("instance/" + instanceMetadataKey); } private static String fetchCustomGceMetadata(String customMetadataKey) { - return GceMetadataUtil.fetchMetadata("instance/attributes/" + customMetadataKey); + return GceMetadataUtil.fetchVmInstanceMetadata("attributes/" + customMetadataKey); } public static String fetchDataflowJobId() { return GceMetadataUtil.fetchCustomGceMetadata("job_id"); } + + public static String fetchDataflowJobName() { + return GceMetadataUtil.fetchCustomGceMetadata("job_name"); + } + + public static String fetchDataflowWorkerId() { + return GceMetadataUtil.fetchVmInstanceMetadata("id"); + } } diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index 560b27aae162..c4a508680186 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ 
b/sdks/java/io/google-cloud-platform/build.gradle @@ -202,10 +202,8 @@ task integrationTest(type: Test, dependsOn: processTestResources) { exclude '**/BigQueryIOReadIT.class' exclude '**/BigQueryIOStorageQueryIT.class' exclude '**/BigQueryIOStorageReadIT.class' - exclude '**/BigQueryIOStorageReadTableRowIT.class' exclude '**/BigQueryIOStorageWriteIT.class' exclude '**/BigQueryToTableIT.class' - exclude '**/BigQueryIOJsonTest.class' maxParallelForks 4 classpath = sourceSets.test.runtimeClasspath @@ -244,6 +242,48 @@ task integrationTestKms(type: Test) { } } +/* + Integration tests for BigQueryIO that run on BigQuery's early rollout region (us-east7) + with the intended purpose of catching breaking changes from new BigQuery releases. + If these tests fail here but not in `Java_GCP_IO_Direct`, there may be a new BigQuery change + that is breaking the connector. If this is the case, we should verify with the appropriate + BigQuery infrastructure API team. + + To test in a BigQuery location, we just need to create our datasets in that location. 
+ */ +task bigQueryEarlyRolloutIntegrationTest(type: Test, dependsOn: processTestResources) { + group = "Verification" + def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' + def gcpTempRoot = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-bigquery-day0-tests' + systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ + "--runner=DirectRunner", + "--project=${gcpProject}", + "--tempRoot=${gcpTempRoot}", + "--bigQueryLocation=us-east7", + ]) + + outputs.upToDateWhen { false } + + // export and direct read + include '**/BigQueryToTableIT.class' + include '**/BigQueryIOJsonIT.class' + include '**/BigQueryIOStorageReadTableRowIT.class' + // storage write api + include '**/StorageApiDirectWriteProtosIT.class' + include '**/StorageApiSinkFailedRowsIT.class' + include '**/StorageApiSinkRowUpdateIT.class' + include '**/StorageApiSinkSchemaUpdateIT.class' + include '**/TableRowToStorageApiProtoIT.class' + // file loads + include '**/BigQuerySchemaUpdateOptionsIT.class' + include '**/BigQueryTimePartitioningClusteringIT.class' + include '**/FileLoadsStreamingIT.class' + + maxParallelForks 4 + classpath = sourceSets.test.runtimeClasspath + testClassesDirs = sourceSets.test.output.classesDirs +} + // path(s) for Cloud Spanner related classes def spannerIncludes = [ '**/org/apache/beam/sdk/io/gcp/spanner/**', @@ -267,8 +307,8 @@ task spannerCodeCoverageReport(type: JacocoReport, dependsOn: test) { sourceDirectories.setFrom(files(project.sourceSets.main.allSource.srcDirs)) executionData.setFrom(file("${buildDir}/jacoco/test.exec")) reports { - html.enabled true - html.destination file("${buildDir}/reports/jacoco/spanner/") + html.getRequired().set(true) + html.getOutputLocation().set(file("${buildDir}/reports/jacoco/spanner/")) } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java index ee64a7ab9ddb..1893418dedb3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java @@ -28,8 +28,15 @@ final class BigQueryIOMetadata { private @Nullable String beamJobId; - private BigQueryIOMetadata(@Nullable String beamJobId) { + private @Nullable String beamJobName; + + private @Nullable String beamWorkerId; + + private BigQueryIOMetadata( + @Nullable String beamJobId, @Nullable String beamJobName, @Nullable String beamWorkerId) { this.beamJobId = beamJobId; + this.beamJobName = beamJobName; + this.beamWorkerId = beamWorkerId; } private static final Pattern VALID_CLOUD_LABEL_PATTERN = @@ -41,17 +48,24 @@ private BigQueryIOMetadata(@Nullable String beamJobId) { */ public static BigQueryIOMetadata create() { String dataflowJobId = GceMetadataUtil.fetchDataflowJobId(); + String dataflowJobName = GceMetadataUtil.fetchDataflowJobName(); + String dataflowWorkerId = GceMetadataUtil.fetchDataflowWorkerId(); + // If a Dataflow job id is returned on GCE metadata. Then it means // this program is running on a Dataflow GCE VM. 
- boolean isDataflowRunner = dataflowJobId != null && !dataflowJobId.isEmpty(); + boolean isDataflowRunner = !dataflowJobId.isEmpty(); String beamJobId = null; + String beamJobName = null; + String beamWorkerId = null; if (isDataflowRunner) { if (BigQueryIOMetadata.isValidCloudLabel(dataflowJobId)) { beamJobId = dataflowJobId; + beamJobName = dataflowJobName; + beamWorkerId = dataflowWorkerId; } } - return new BigQueryIOMetadata(beamJobId); + return new BigQueryIOMetadata(beamJobId, beamJobName, beamWorkerId); } public Map addAdditionalJobLabels(Map jobLabels) { @@ -68,6 +82,20 @@ public Map addAdditionalJobLabels(Map jobLabels) return this.beamJobId; } + /* + * Returns the beam job name. Can be null if it is not running on Dataflow. + */ + public @Nullable String getBeamJobName() { + return this.beamJobName; + } + + /* + * Returns the beam worker id. Can be null if it is not running on Dataflow. + */ + public @Nullable String getBeamWorkerId() { + return this.beamWorkerId; + } + /** * Returns true if label_value is a valid cloud label string. This function can return false in * cases where the label value is valid. However, it will not return true in a case where the diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java index 3d4565cb086e..1b6cc555511d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java @@ -1364,6 +1364,15 @@ public StreamAppendClient getStreamAppendClient( .setChannelsPerCpu(2) .build(); + String traceId = + String.format( + "Dataflow:%s:%s:%s", + bqIOMetadata.getBeamJobName() == null + ? 
options.getJobName() + : bqIOMetadata.getBeamJobName(), + bqIOMetadata.getBeamJobId() == null ? "" : bqIOMetadata.getBeamJobId(), + bqIOMetadata.getBeamWorkerId() == null ? "" : bqIOMetadata.getBeamWorkerId()); + StreamWriter streamWriter = StreamWriter.newBuilder(streamName, newWriteClient) .setExecutorProvider( @@ -1374,11 +1383,7 @@ public StreamAppendClient getStreamAppendClient( .setEnableConnectionPool(useConnectionPool) .setMaxInflightRequests(storageWriteMaxInflightRequests) .setMaxInflightBytes(storageWriteMaxInflightBytes) - .setTraceId( - "Dataflow:" - + (bqIOMetadata.getBeamJobId() != null - ? bqIOMetadata.getBeamJobId() - : options.getJobName())) + .setTraceId(traceId) .build(); return new StreamAppendClient() { private int pins = 0; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java index 3574c12ee3a9..4d8095c1879d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java @@ -24,10 +24,17 @@ /** {@link TestPipelineOptions} for {@link TestBigQuery}. */ public interface TestBigQueryOptions extends TestPipelineOptions, BigQueryOptions, GcpOptions { + String BIGQUERY_EARLY_ROLLOUT_REGION = "us-east7"; @Description("Dataset used in the integration tests. 
Default is integ_test") @Default.String("integ_test") String getTargetDataset(); void setTargetDataset(String value); + + @Description("Region to perform BigQuery operations in.") + @Default.String("") + String getBigQueryLocation(); + + void setBigQueryLocation(String location); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java index 1ed9ed6cb6c3..f1ff827fc633 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java @@ -20,6 +20,7 @@ import com.google.auto.service.AutoService; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions; +import org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions; import org.apache.beam.sdk.io.gcp.firestore.FirestoreOptions; import org.apache.beam.sdk.io.gcp.pubsub.PubsubOptions; import org.apache.beam.sdk.options.PipelineOptions; @@ -36,6 +37,7 @@ public Iterable> getPipelineOptions() { .add(BigQueryOptions.class) .add(PubsubOptions.class) .add(FirestoreOptions.class) + .add(TestBigQueryOptions.class) .build(); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java index b21fdd669596..0e9476e6a226 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java @@ -292,6 +292,21 @@ private QueryResponse getTypedTableRows(QueryResponse response) { public List 
queryUnflattened( String query, String projectId, boolean typed, boolean useStandardSql) throws IOException, InterruptedException { + return queryUnflattened(query, projectId, typed, useStandardSql, null); + } + + /** + * Performs a query without flattening results. May choose a location (GCP region) to perform this + * operation in. + */ + @Nonnull + public List queryUnflattened( + String query, + String projectId, + boolean typed, + boolean useStandardSql, + @Nullable String location) + throws IOException, InterruptedException { Random rnd = new Random(System.currentTimeMillis()); String temporaryDatasetId = String.format("_dataflow_temporary_dataset_%s_%s", System.nanoTime(), rnd.nextInt(1000000)); @@ -302,9 +317,11 @@ public List queryUnflattened( .setDatasetId(temporaryDatasetId) .setTableId(temporaryTableId); - createNewDataset(projectId, temporaryDatasetId); + createNewDataset(projectId, temporaryDatasetId, null, location); createNewTable( - projectId, temporaryDatasetId, new Table().setTableReference(tempTableReference)); + projectId, + temporaryDatasetId, + new Table().setTableReference(tempTableReference).setLocation(location)); JobConfigurationQuery jcQuery = new JobConfigurationQuery() @@ -325,6 +342,7 @@ public List queryUnflattened( bqClient .jobs() .getQueryResults(projectId, insertedJob.getJobReference().getJobId()) + .setLocation(location) .execute(); } while (!qResponse.getJobComplete()); @@ -395,6 +413,18 @@ public void createNewDataset(String projectId, String datasetId) public void createNewDataset( String projectId, String datasetId, @Nullable Long defaultTableExpirationMs) throws IOException, InterruptedException { + createNewDataset(projectId, datasetId, defaultTableExpirationMs, null); + } + + /** + * Creates a new dataset with defaultTableExpirationMs and in a specified location (GCP region). 
+ */ + public void createNewDataset( + String projectId, + String datasetId, + @Nullable Long defaultTableExpirationMs, + @Nullable String location) + throws IOException, InterruptedException { Sleeper sleeper = Sleeper.DEFAULT; BackOff backoff = BackOffAdapter.toGcpBackOff(BACKOFF_FACTORY.backoff()); IOException lastException = null; @@ -410,7 +440,8 @@ public void createNewDataset( projectId, new Dataset() .setDatasetReference(new DatasetReference().setDatasetId(datasetId)) - .setDefaultTableExpirationMs(defaultTableExpirationMs)) + .setDefaultTableExpirationMs(defaultTableExpirationMs) + .setLocation(location)) .execute(); if (response != null) { LOG.info("Successfully created new dataset : " + response.getId()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java index 692a12c0f4a7..d355d6bb9336 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; + import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; @@ -52,7 +54,13 @@ public class BigQueryIOStorageQueryIT { "1G", 11110839L, "1T", 11110839000L); - private static final String DATASET_ID = "big_query_storage"; + private static final String DATASET_ID = + TestPipeline.testingPipelineOptions() + .as(TestBigQueryOptions.class) + .getBigQueryLocation() + .equals(BIGQUERY_EARLY_ROLLOUT_REGION) + ? 
"big_query_storage_day0" + : "big_query_storage"; private static final String TABLE_PREFIX = "storage_read_"; private BigQueryIOStorageQueryOptions options; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java index 570938470b9d..b4f6ddb76f72 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; import static org.junit.Assert.assertEquals; import com.google.cloud.bigquery.storage.v1.DataFormat; @@ -65,7 +66,13 @@ public class BigQueryIOStorageReadIT { "1T", 11110839000L, "multi_field", 11110839L); - private static final String DATASET_ID = "big_query_storage"; + private static final String DATASET_ID = + TestPipeline.testingPipelineOptions() + .as(TestBigQueryOptions.class) + .getBigQueryLocation() + .equals(BIGQUERY_EARLY_ROLLOUT_REGION) + ? 
"big_query_storage_day0" + : "big_query_storage"; private static final String TABLE_PREFIX = "storage_read_"; private BigQueryIOStorageReadOptions options; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java index 734c3af2c4d4..35e2676c70ef 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; + import com.google.api.services.bigquery.model.TableRow; import java.util.HashSet; import java.util.Set; @@ -52,7 +54,13 @@ @RunWith(JUnit4.class) public class BigQueryIOStorageReadTableRowIT { - private static final String DATASET_ID = "big_query_import_export"; + private static final String DATASET_ID = + TestPipeline.testingPipelineOptions() + .as(TestBigQueryOptions.class) + .getBigQueryLocation() + .equals(BIGQUERY_EARLY_ROLLOUT_REGION) + ? "big_query_import_export_day0" + : "big_query_import_export"; private static final String TABLE_PREFIX = "parallel_read_table_row_"; private BigQueryIOStorageReadTableRowOptions options; @@ -67,12 +75,11 @@ public interface BigQueryIOStorageReadTableRowOptions void setInputTable(String table); } - private static class TableRowToKVPairFn extends SimpleFunction> { + private static class TableRowToKVPairFn extends SimpleFunction> { @Override - public KV apply(TableRow input) { - CharSequence sampleString = (CharSequence) input.get("sample_string"); - String key = sampleString != null ? 
sampleString.toString() : "null"; - return KV.of(key, BigQueryHelpers.toJsonString(input)); + public KV apply(TableRow input) { + Integer rowId = Integer.parseInt((String) input.get("id")); + return KV.of(rowId, BigQueryHelpers.toJsonString(input)); } } @@ -87,7 +94,7 @@ private void setUpTestEnvironment(String tableName) { private static void runPipeline(BigQueryIOStorageReadTableRowOptions pipelineOptions) { Pipeline pipeline = Pipeline.create(pipelineOptions); - PCollection> jsonTableRowsFromExport = + PCollection> jsonTableRowsFromExport = pipeline .apply( "ExportTable", @@ -96,7 +103,7 @@ private static void runPipeline(BigQueryIOStorageReadTableRowOptions pipelineOpt .withMethod(Method.EXPORT)) .apply("MapExportedRows", MapElements.via(new TableRowToKVPairFn())); - PCollection> jsonTableRowsFromDirectRead = + PCollection> jsonTableRowsFromDirectRead = pipeline .apply( "DirectReadTable", @@ -108,16 +115,16 @@ private static void runPipeline(BigQueryIOStorageReadTableRowOptions pipelineOpt final TupleTag exportTag = new TupleTag<>(); final TupleTag directReadTag = new TupleTag<>(); - PCollection>> unmatchedRows = + PCollection>> unmatchedRows = KeyedPCollectionTuple.of(exportTag, jsonTableRowsFromExport) .and(directReadTag, jsonTableRowsFromDirectRead) .apply(CoGroupByKey.create()) .apply( ParDo.of( - new DoFn, KV>>() { + new DoFn, KV>>() { @ProcessElement - public void processElement(ProcessContext c) throws Exception { - KV element = c.element(); + public void processElement(ProcessContext c) { + KV element = c.element(); // Add all the exported rows for the key to a collection. 
Set uniqueRows = new HashSet<>(); @@ -147,20 +154,20 @@ public void processElement(ProcessContext c) throws Exception { } @Test - public void testBigQueryStorageReadTableRow1() throws Exception { - setUpTestEnvironment("1"); + public void testBigQueryStorageReadTableRow100() { + setUpTestEnvironment("100"); runPipeline(options); } @Test - public void testBigQueryStorageReadTableRow10k() throws Exception { - setUpTestEnvironment("10k"); + public void testBigQueryStorageReadTableRow1k() { + setUpTestEnvironment("1K"); runPipeline(options); } @Test - public void testBigQueryStorageReadTableRow100k() throws Exception { - setUpTestEnvironment("100k"); + public void testBigQueryStorageReadTableRow10k() { + setUpTestEnvironment("10K"); runPipeline(options); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java index fc3ce0be4b69..d061898d55c7 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java @@ -26,11 +26,11 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import java.io.IOException; +import java.security.SecureRandom; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.MapElements; @@ -43,6 +43,8 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -60,24 +62,37 @@ private enum WriteMode { AT_LEAST_ONCE } - private String project; - private static final String DATASET_ID = "big_query_storage"; + private static String project; + private static final String DATASET_ID = + "big_query_storage_write_it_" + + System.currentTimeMillis() + + "_" + + new SecureRandom().nextInt(32); private static final String TABLE_PREFIX = "storage_write_"; - private BigQueryOptions bqOptions; + private static TestBigQueryOptions bqOptions; private static final BigqueryClient BQ_CLIENT = new BigqueryClient("BigQueryStorageIOWriteIT"); + @BeforeClass + public static void setup() throws Exception { + bqOptions = TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class); + project = bqOptions.as(GcpOptions.class).getProject(); + // Create one BQ dataset for all test cases. 
+ BQ_CLIENT.createNewDataset(project, DATASET_ID, null, bqOptions.getBigQueryLocation()); + } + + @AfterClass + public static void cleanup() { + BQ_CLIENT.deleteDataset(project, DATASET_ID); + } + private void setUpTestEnvironment(WriteMode writeMode) { - PipelineOptionsFactory.register(BigQueryOptions.class); - bqOptions = TestPipeline.testingPipelineOptions().as(BigQueryOptions.class); - bqOptions.setProject(TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject()); bqOptions.setUseStorageWriteApi(true); if (writeMode == WriteMode.AT_LEAST_ONCE) { bqOptions.setUseStorageWriteApiAtLeastOnce(true); } bqOptions.setNumStorageWriteApiStreams(2); bqOptions.setStorageWriteApiTriggeringFrequencySec(1); - project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); } static class FillRowFn extends DoFn { diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java index 611c691dca12..833a0a0829c7 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java @@ -87,7 +87,11 @@ public class BigQuerySchemaUpdateOptionsIT { @BeforeClass public static void setupTestEnvironment() throws Exception { project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); - BQ_CLIENT.createNewDataset(project, BIG_QUERY_DATASET_ID); + BQ_CLIENT.createNewDataset( + project, + BIG_QUERY_DATASET_ID, + null, + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation()); } @AfterClass diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java index 3ceb6f0966b7..da5f396e8d89 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java @@ -24,9 +24,11 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import com.google.api.services.bigquery.model.TimePartitioning; +import java.security.SecureRandom; import java.util.Arrays; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; @@ -38,8 +40,10 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.ValueInSingleWindow; import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,7 +53,15 @@ public class BigQueryTimePartitioningClusteringIT { private static final String WEATHER_SAMPLES_TABLE = "apache-beam-testing.samples.weather_stations"; - private static final String DATASET_NAME = "BigQueryTimePartitioningIT"; + + private static String project; + private static final BigqueryClient BQ_CLIENT = + new BigqueryClient("BigQueryTimePartitioningClusteringIT"); + private static final String DATASET_NAME = + "BigQueryTimePartitioningIT_" + + System.currentTimeMillis() + + "_" + + new SecureRandom().nextInt(32); private static final TimePartitioning TIME_PARTITIONING = new 
TimePartitioning().setField("date").setType("DAY"); private static final Clustering CLUSTERING = @@ -64,6 +76,16 @@ public class BigQueryTimePartitioningClusteringIT { private Bigquery bqClient; private BigQueryClusteringITOptions options; + @BeforeClass + public static void setupTestEnvironment() throws Exception { + project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + BQ_CLIENT.createNewDataset( + project, + DATASET_NAME, + null, + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation()); + } + @Before public void setUp() { PipelineOptionsFactory.register(BigQueryClusteringITOptions.class); @@ -72,6 +94,11 @@ public void setUp() { bqClient = BigqueryClient.getNewBigqueryClient(options.getAppName()); } + @AfterClass + public static void cleanup() { + BQ_CLIENT.deleteDataset(project, DATASET_NAME); + } + /** Customized PipelineOptions for BigQueryClustering Integration Test. */ public interface BigQueryClusteringITOptions extends TestPipelineOptions, ExperimentalOptions, BigQueryOptions { @@ -110,8 +137,7 @@ public ClusteredDestinations(String tableName) { @Override public TableDestination getDestination(ValueInSingleWindow element) { - return new TableDestination( - String.format("%s.%s", DATASET_NAME, tableName), null, TIME_PARTITIONING, CLUSTERING); + return new TableDestination(tableName, null, TIME_PARTITIONING, CLUSTERING); } @Override @@ -176,6 +202,7 @@ public void testE2EBigQueryClustering() throws Exception { @Test public void testE2EBigQueryClusteringTableFunction() throws Exception { String tableName = "weather_stations_clustered_table_function_" + System.currentTimeMillis(); + String destination = String.format("%s.%s", DATASET_NAME, tableName); Pipeline p = Pipeline.create(options); @@ -185,11 +212,7 @@ public void testE2EBigQueryClusteringTableFunction() throws Exception { BigQueryIO.writeTableRows() .to( (ValueInSingleWindow vsw) -> - new TableDestination( - String.format("%s.%s", 
DATASET_NAME, tableName), - null, - TIME_PARTITIONING, - CLUSTERING)) + new TableDestination(destination, null, TIME_PARTITIONING, CLUSTERING)) .withClustering() .withSchema(SCHEMA) .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) @@ -206,6 +229,7 @@ public void testE2EBigQueryClusteringTableFunction() throws Exception { public void testE2EBigQueryClusteringDynamicDestinations() throws Exception { String tableName = "weather_stations_clustered_dynamic_destinations_" + System.currentTimeMillis(); + String destination = String.format("%s.%s", DATASET_NAME, tableName); Pipeline p = Pipeline.create(options); @@ -213,7 +237,7 @@ public void testE2EBigQueryClusteringDynamicDestinations() throws Exception { .apply(ParDo.of(new KeepStationNumberAndConvertDate())) .apply( BigQueryIO.writeTableRows() - .to(new ClusteredDestinations(tableName)) + .to(new ClusteredDestinations(destination)) .withClustering() .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java index d6b7f8e16412..1abe7752b2e0 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java @@ -46,7 +46,6 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.Validation; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.WithKeys; @@ -214,7 +213,7 @@ private void 
verifyStandardQueryRes(String outputTable) throws Exception { } /** Customized PipelineOption for BigQueryToTable Pipeline. */ - public interface BigQueryToTableOptions extends TestPipelineOptions, ExperimentalOptions { + public interface BigQueryToTableOptions extends TestBigQueryOptions, ExperimentalOptions { @Description("The BigQuery query to be used for creating the source") @Validation.Required @@ -252,9 +251,11 @@ public interface BigQueryToTableOptions extends TestPipelineOptions, Experimenta @BeforeClass public static void setupTestEnvironment() throws Exception { PipelineOptionsFactory.register(BigQueryToTableOptions.class); - project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + BigQueryToTableOptions options = + TestPipeline.testingPipelineOptions().as(BigQueryToTableOptions.class); + project = options.as(GcpOptions.class).getProject(); // Create one BQ dataset for all test cases. - BQ_CLIENT.createNewDataset(project, BIG_QUERY_DATASET_ID); + BQ_CLIENT.createNewDataset(project, BIG_QUERY_DATASET_ID, null, options.getBigQueryLocation()); // Create table and insert data for new type query test cases. 
BQ_CLIENT.createNewTable( diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java index 012afed6fb43..678708062b8d 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java @@ -106,11 +106,16 @@ public static Iterable data() { private final Random randomGenerator = new Random(); + // used when test suite specifies a particular GCP location for BigQuery operations + private static String bigQueryLocation; + @BeforeClass public static void setUpTestEnvironment() throws IOException, InterruptedException { // Create one BQ dataset for all test cases. cleanUp(); - BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID); + bigQueryLocation = + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation(); + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null, bigQueryLocation); } @AfterClass @@ -293,7 +298,7 @@ private static void checkRowCompleteness( throws IOException, InterruptedException { List actualTableRows = BQ_CLIENT.queryUnflattened( - String.format("SELECT * FROM [%s]", tableSpec), PROJECT, true, false); + String.format("SELECT * FROM [%s]", tableSpec), PROJECT, true, false, bigQueryLocation); Schema rowSchema = BigQueryUtils.fromTableSchema(schema); List actualBeamRows = diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDirectWriteProtosIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDirectWriteProtosIT.java index 93bc4162409f..3da93c42a480 100644 --- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDirectWriteProtosIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDirectWriteProtosIT.java @@ -80,10 +80,15 @@ private BigQueryIO.Write.Method getMethod() { : BigQueryIO.Write.Method.STORAGE_WRITE_API; } + // used when test suite specifies a particular GCP location for BigQuery operations + private static String bigQueryLocation; + @BeforeClass public static void setUpTestEnvironment() throws IOException, InterruptedException { // Create one BQ dataset for all test cases. - BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID); + bigQueryLocation = + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation(); + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null, bigQueryLocation); } @AfterClass @@ -191,7 +196,7 @@ public void testDirectWriteProtos() throws Exception { void assertRowsWritten(String tableSpec, Iterable expectedItems) throws Exception { List rows = BQ_CLIENT.queryUnflattened( - String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true); + String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true, bigQueryLocation); assertThat(rows, containsInAnyOrder(Iterables.toArray(expectedItems, TableRow.class))); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java index 3dcde8f39cd7..f721f57147e3 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java @@ -108,10 +108,15 @@ private BigQueryIO.Write.Method getMethod() { : 
BigQueryIO.Write.Method.STORAGE_WRITE_API; } + // used when test suite specifies a particular GCP location for BigQuery operations + private static String bigQueryLocation; + @BeforeClass public static void setUpTestEnvironment() throws IOException, InterruptedException { // Create one BQ dataset for all test cases. - BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID); + bigQueryLocation = + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation(); + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null, bigQueryLocation); } @AfterClass @@ -217,7 +222,11 @@ private void assertGoodRowsWritten(String tableSpec, Iterable goodRows TableRow queryResponse = Iterables.getOnlyElement( BQ_CLIENT.queryUnflattened( - String.format("SELECT COUNT(*) FROM %s", tableSpec), PROJECT, true, true)); + String.format("SELECT COUNT(*) FROM `%s`", tableSpec), + PROJECT, + true, + true, + bigQueryLocation)); int numRowsWritten = Integer.parseInt((String) queryResponse.get("f0_")); if (useAtLeastOnce) { assertThat(numRowsWritten, Matchers.greaterThanOrEqualTo(Iterables.size(goodRows))); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkRowUpdateIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkRowUpdateIT.java index d5366fe29613..f8cc797a87cd 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkRowUpdateIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkRowUpdateIT.java @@ -49,10 +49,15 @@ public class StorageApiSinkRowUpdateIT { private static final String BIG_QUERY_DATASET_ID = "storage_api_sink_rows_update" + System.nanoTime(); + // used when test suite specifies a particular GCP location for BigQuery operations + private static String bigQueryLocation; + @BeforeClass public static void 
setUpTestEnvironment() throws IOException, InterruptedException { // Create one BQ dataset for all test cases. - BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID); + bigQueryLocation = + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation(); + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null, bigQueryLocation); } @AfterClass @@ -129,7 +134,7 @@ private void assertRowsWritten(String tableSpec, Iterable expected) throws IOException, InterruptedException { List queryResponse = BQ_CLIENT.queryUnflattened( - String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true); + String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true, bigQueryLocation); assertThat(queryResponse, containsInAnyOrder(Iterables.toArray(expected, TableRow.class))); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkSchemaUpdateIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkSchemaUpdateIT.java index 6931b7ac9b98..bc99a4f50f70 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkSchemaUpdateIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkSchemaUpdateIT.java @@ -121,17 +121,21 @@ public static Iterable data() { // an updated schema. If that happens consistently, just increase these two numbers // to give it more time. 
// Total number of rows written to the sink - private static final int TOTAL_N = 60; + private static final int TOTAL_N = 70; // Number of rows with the original schema - private static final int ORIGINAL_N = 50; + private static final int ORIGINAL_N = 60; private final Random randomGenerator = new Random(); + // used when test suite specifies a particular GCP location for BigQuery operations + private static String bigQueryLocation; + @BeforeClass public static void setUpTestEnvironment() throws IOException, InterruptedException { // Create one BQ dataset for all test cases. - LOG.info("Creating dataset {}.", BIG_QUERY_DATASET_ID); - BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID); + bigQueryLocation = + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation(); + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null, bigQueryLocation); } @AfterClass @@ -459,7 +463,8 @@ private static void checkRowCompleteness( String.format("SELECT COUNT(DISTINCT(id)), COUNT(id) FROM [%s]", tableSpec), PROJECT, true, - false)); + false, + bigQueryLocation)); int distinctCount = Integer.parseInt((String) queryResponse.get("f0_")); int totalCount = Integer.parseInt((String) queryResponse.get("f1_")); @@ -479,7 +484,7 @@ public void checkRowsWithUpdatedSchema( throws IOException, InterruptedException { List actualRows = BQ_CLIENT.queryUnflattened( - String.format("SELECT * FROM [%s]", tableSpec), PROJECT, true, false); + String.format("SELECT * FROM [%s]", tableSpec), PROJECT, true, false, bigQueryLocation); for (TableRow row : actualRows) { // Rows written to the table should not have the extra field if diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java index 218aa7411414..f28ae588a5ec 100644 --- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProtoIT.java @@ -318,10 +318,15 @@ public class TableRowToStorageApiProtoIT { .setFields(BASE_TABLE_SCHEMA.getFields())) .build()); + // used when test suite specifies a particular GCP location for BigQuery operations + private static String bigQueryLocation; + @BeforeClass public static void setUpTestEnvironment() throws IOException, InterruptedException { // Create one BQ dataset for all test cases. - BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID); + bigQueryLocation = + TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation(); + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null, bigQueryLocation); } @AfterClass @@ -338,7 +343,7 @@ public void testBaseTableRow() throws IOException, InterruptedException { List actualTableRows = BQ_CLIENT.queryUnflattened( - String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true); + String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true, bigQueryLocation); assertEquals(1, actualTableRows.size()); assertEquals(BASE_TABLE_ROW_EXPECTED, actualTableRows.get(0)); @@ -364,7 +369,7 @@ public void testNestedRichTypesAndNull() throws IOException, InterruptedExceptio List actualTableRows = BQ_CLIENT.queryUnflattened( - String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true); + String.format("SELECT * FROM %s", tableSpec), PROJECT, true, true, bigQueryLocation); assertEquals(1, actualTableRows.size()); assertEquals(BASE_TABLE_ROW_EXPECTED, actualTableRows.get(0).get("nestedValue1")); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java index bf2ccd454bb5..9ffa61c93078 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangeStreamErrorTest.java @@ -52,7 +52,9 @@ import com.google.spanner.v1.TypeCode; import io.grpc.Status; import java.io.Serializable; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import org.apache.beam.runners.direct.DirectOptions; import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.sdk.Pipeline; @@ -68,7 +70,6 @@ import org.joda.time.Duration; import org.junit.After; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -114,16 +115,22 @@ public void tearDown() throws NoSuchFieldException, IllegalAccessException { } @Test - @Ignore("BEAM-12164 Reenable this test when databaseClient.getDialect returns the right message.") - public void testResourceExhaustedDoesNotRetry() { + // Error code UNAVAILABLE is retried repeatedly until the RPC times out. 
+ public void testUnavailableExceptionRetries() throws InterruptedException { + DirectOptions options = PipelineOptionsFactory.as(DirectOptions.class); + options.setBlockOnRun(false); + options.setRunner(DirectRunner.class); + Pipeline nonBlockingPipeline = TestPipeline.create(options); + mockSpannerService.setExecuteStreamingSqlExecutionTime( - SimulatedExecutionTime.ofStickyException(Status.RESOURCE_EXHAUSTED.asRuntimeException())); + SimulatedExecutionTime.ofStickyException(Status.UNAVAILABLE.asRuntimeException())); final Timestamp startTimestamp = Timestamp.ofTimeSecondsAndNanos(0, 1000); final Timestamp endTimestamp = Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() + 1); + try { - pipeline.apply( + nonBlockingPipeline.apply( SpannerIO.readChangeStream() .withSpannerConfig(getSpannerConfig()) .withChangeStreamName(TEST_CHANGE_STREAM) @@ -131,33 +138,36 @@ public void testResourceExhaustedDoesNotRetry() { .withMetadataTable(TEST_TABLE) .withInclusiveStartAt(startTimestamp) .withInclusiveEndAt(endTimestamp)); - pipeline.run().waitUntilFinish(); + PipelineResult result = nonBlockingPipeline.run(); + while (result.getState() != RUNNING) { + Thread.sleep(50); + } + // The pipeline continues making requests to Spanner to retry the Unavailable errors. + assertNull(result.waitUntilFinish(Duration.millis(500))); } finally { - thrown.expect(SpannerException.class); // databaseClient.getDialect does not currently bubble up the correct message. 
// Instead, the error returned is: "DEADLINE_EXCEEDED: Operation did not complete " // "in the given time" - thrown.expectMessage("RESOURCE_EXHAUSTED - Statement: 'SELECT 'POSTGRESQL' AS DIALECT"); + thrown.expectMessage("DEADLINE_EXCEEDED"); assertThat( mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.equalTo(0)); } } @Test - @Ignore("BEAM-12164 Reenable this test when databaseClient.getDialect returns the right message.") - public void testUnavailableExceptionRetries() throws InterruptedException { + // Error code ABORTED is retried repeatedly until it times out. + public void testAbortedExceptionRetries() throws InterruptedException { + mockSpannerService.setExecuteStreamingSqlExecutionTime( + SimulatedExecutionTime.ofStickyException(Status.ABORTED.asRuntimeException())); + DirectOptions options = PipelineOptionsFactory.as(DirectOptions.class); options.setBlockOnRun(false); options.setRunner(DirectRunner.class); Pipeline nonBlockingPipeline = TestPipeline.create(options); - mockSpannerService.setExecuteStreamingSqlExecutionTime( - SimulatedExecutionTime.ofStickyException(Status.UNAVAILABLE.asRuntimeException())); - final Timestamp startTimestamp = Timestamp.ofTimeSecondsAndNanos(0, 1000); final Timestamp endTimestamp = Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() + 1); - try { nonBlockingPipeline.apply( SpannerIO.readChangeStream() @@ -171,23 +181,20 @@ public void testUnavailableExceptionRetries() throws InterruptedException { while (result.getState() != RUNNING) { Thread.sleep(50); } - // The pipeline continues making requests to Spanner to retry the Unavailable errors. + // The pipeline continues making requests to Spanner to retry the Aborted errors. assertNull(result.waitUntilFinish(Duration.millis(500))); } finally { - // databaseClient.getDialect does not currently bubble up the correct message. 
- // Instead, the error returned is: "DEADLINE_EXCEEDED: Operation did not complete " - // "in the given time" - thrown.expectMessage("UNAVAILABLE - Statement: 'SELECT 'POSTGRESQL' AS DIALECT"); + thrown.expectMessage("DEADLINE_EXCEEDED"); assertThat( mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.equalTo(0)); } } @Test - @Ignore("BEAM-12164 Reenable this test when databaseClient.getDialect returns the right message.") - public void testAbortedExceptionNotRetried() { + // Error code UNKNOWN is not retried. + public void testUnknownExceptionDoesNotRetry() { mockSpannerService.setExecuteStreamingSqlExecutionTime( - SimulatedExecutionTime.ofStickyException(Status.ABORTED.asRuntimeException())); + SimulatedExecutionTime.ofStickyException(Status.UNKNOWN.asRuntimeException())); final Timestamp startTimestamp = Timestamp.ofTimeSecondsAndNanos(0, 1000); final Timestamp endTimestamp = @@ -204,19 +211,43 @@ public void testAbortedExceptionNotRetried() { pipeline.run().waitUntilFinish(); } finally { thrown.expect(SpannerException.class); - // databaseClient.getDialect does not currently bubble up the correct message. - // Instead, the error returned is: "DEADLINE_EXCEEDED: Operation did not complete " - // "in the given time" - thrown.expectMessage("ABORTED - Statement: 'SELECT 'POSTGRESQL' AS DIALECT"); + thrown.expectMessage("UNKNOWN"); assertThat( mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.equalTo(0)); } } @Test - public void testAbortedExceptionNotRetriedithDefaultsForStreamSqlRetrySettings() { + // Error code RESOURCE_EXHAUSTED is retried repeatedly. 
+ public void testResourceExhaustedRetry() { mockSpannerService.setExecuteStreamingSqlExecutionTime( - SimulatedExecutionTime.ofStickyException(Status.ABORTED.asRuntimeException())); + SimulatedExecutionTime.ofStickyException(Status.RESOURCE_EXHAUSTED.asRuntimeException())); + + final Timestamp startTimestamp = Timestamp.ofTimeSecondsAndNanos(0, 1000); + final Timestamp endTimestamp = + Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() + 1); + + try { + pipeline.apply( + SpannerIO.readChangeStream() + .withSpannerConfig(getSpannerConfig()) + .withChangeStreamName(TEST_CHANGE_STREAM) + .withMetadataDatabase(TEST_DATABASE) + .withMetadataTable(TEST_TABLE) + .withInclusiveStartAt(startTimestamp) + .withInclusiveEndAt(endTimestamp)); + pipeline.run().waitUntilFinish(); + } finally { + thrown.expectMessage("DEADLINE_EXCEEDED"); + assertThat( + mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.equalTo(0)); + } + } + + @Test + public void testResourceExhaustedRetryWithDefaultSettings() { + mockSpannerService.setExecuteStreamingSqlExecutionTime( + SimulatedExecutionTime.ofStickyException(Status.RESOURCE_EXHAUSTED.asRuntimeException())); final Timestamp startTimestamp = Timestamp.ofTimeSecondsAndNanos(0, 1000); final Timestamp endTimestamp = @@ -230,6 +261,7 @@ public void testAbortedExceptionNotRetriedithDefaultsForStreamSqlRetrySettings() .withProjectId(TEST_PROJECT) .withInstanceId(TEST_INSTANCE) .withDatabaseId(TEST_DATABASE); + try { pipeline.apply( SpannerIO.readChangeStream() @@ -241,24 +273,34 @@ public void testAbortedExceptionNotRetriedithDefaultsForStreamSqlRetrySettings() .withInclusiveEndAt(endTimestamp)); pipeline.run().waitUntilFinish(); } finally { - // databaseClient.getDialect does not currently bubble up the correct message. 
- // Instead, the error returned is: "DEADLINE_EXCEEDED: Operation did not complete " - // "in the given time" thrown.expect(SpannerException.class); - thrown.expectMessage("ABORTED - Statement: 'SELECT 'POSTGRESQL' AS DIALECT"); + thrown.expectMessage("RESOURCE_EXHAUSTED"); assertThat( mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.equalTo(0)); } } @Test - public void testUnknownExceptionDoesNotRetry() { - mockSpannerService.setExecuteStreamingSqlExecutionTime( - SimulatedExecutionTime.ofStickyException(Status.UNKNOWN.asRuntimeException())); - + public void testInvalidRecordReceived() { final Timestamp startTimestamp = Timestamp.ofTimeSecondsAndNanos(0, 1000); final Timestamp endTimestamp = Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() + 1); + + mockGetDialect(); + mockTableExists(); + mockGetWatermark(startTimestamp); + ResultSet getPartitionResultSet = mockGetParentPartition(startTimestamp, endTimestamp); + mockGetPartitionsAfter( + Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() - 1), + getPartitionResultSet); + mockGetPartitionsAfter( + Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos()), + ResultSet.newBuilder().setMetadata(PARTITION_METADATA_RESULT_SET_METADATA).build()); + mockGetPartitionsAfter( + Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() + 1), + ResultSet.newBuilder().setMetadata(PARTITION_METADATA_RESULT_SET_METADATA).build()); + mockInvalidChangeStreamRecordReceived(startTimestamp, endTimestamp); + try { pipeline.apply( SpannerIO.readChangeStream() @@ -271,15 +313,16 @@ public void testUnknownExceptionDoesNotRetry() { pipeline.run().waitUntilFinish(); } finally { thrown.expect(SpannerException.class); - thrown.expectMessage("UNKNOWN - Statement: 'SELECT 'POSTGRESQL' AS DIALECT"); + // DatabaseClient.getDialect returns "DEADLINE_EXCEEDED: Operation did not complete in 
the " + // given time" even though we mocked it out. + thrown.expectMessage("DEADLINE_EXCEEDED"); assertThat( mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.equalTo(0)); } } @Test - @Ignore("BEAM-12164 Reenable this test when databaseClient.getDialect works.") - public void testInvalidRecordReceived() { + public void testInvalidRecordReceivedWithDefaultSettings() { final Timestamp startTimestamp = Timestamp.ofTimeSecondsAndNanos(0, 1000); final Timestamp endTimestamp = Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() + 1); @@ -288,6 +331,8 @@ public void testInvalidRecordReceived() { mockTableExists(); mockGetWatermark(startTimestamp); ResultSet getPartitionResultSet = mockGetParentPartition(startTimestamp, endTimestamp); + mockchangePartitionState(startTimestamp, endTimestamp, "CREATED"); + mockchangePartitionState(startTimestamp, endTimestamp, "SCHEDULED"); mockGetPartitionsAfter( Timestamp.ofTimeSecondsAndNanos(startTimestamp.getSeconds(), startTimestamp.getNanos() - 1), getPartitionResultSet); @@ -300,9 +345,26 @@ public void testInvalidRecordReceived() { mockInvalidChangeStreamRecordReceived(startTimestamp, endTimestamp); try { + RetrySettings quickRetrySettings = + RetrySettings.newBuilder() + .setInitialRetryDelay(org.threeten.bp.Duration.ofMillis(250)) + .setMaxRetryDelay(org.threeten.bp.Duration.ofSeconds(1)) + .setRetryDelayMultiplier(5) + .setTotalTimeout(org.threeten.bp.Duration.ofSeconds(1)) + .build(); + final SpannerConfig changeStreamConfig = + SpannerConfig.create() + .withEmulatorHost(StaticValueProvider.of(SPANNER_HOST)) + .withIsLocalChannelProvider(StaticValueProvider.of(true)) + .withCommitRetrySettings(quickRetrySettings) + .withExecuteStreamingSqlRetrySettings(null) + .withProjectId(TEST_PROJECT) + .withInstanceId(TEST_INSTANCE) + .withDatabaseId(TEST_DATABASE); + pipeline.apply( SpannerIO.readChangeStream() - .withSpannerConfig(getSpannerConfig()) + 
.withSpannerConfig(changeStreamConfig) .withChangeStreamName(TEST_CHANGE_STREAM) .withMetadataDatabase(TEST_DATABASE) .withMetadataTable(TEST_TABLE) @@ -311,11 +373,9 @@ public void testInvalidRecordReceived() { pipeline.run().waitUntilFinish(); } finally { thrown.expect(PipelineExecutionException.class); - // DatabaseClient.getDialect returns "DEADLINE_EXCEEDED: Operation did not complete in the " - // given time" even though we mocked it out. thrown.expectMessage("Field not found"); assertThat( - mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.equalTo(0)); + mockSpannerService.countRequestsOfType(ExecuteSqlRequest.class), Matchers.greaterThan(0)); } } @@ -487,6 +547,41 @@ private void mockTableExists() { StatementResult.query(tableExistsStatement, tableExistsResultSet)); } + private ResultSet mockchangePartitionState( + Timestamp startTimestamp, Timestamp after3Seconds, String state) { + List tokens = new ArrayList<>(); + tokens.add("Parent0"); + Statement getPartitionStatement = + Statement.newBuilder( + "SELECT * FROM my-metadata-table WHERE PartitionToken IN UNNEST(@partitionTokens) AND State = @state") + .bind("partitionTokens") + .toStringArray(tokens) + .bind("state") + .to(state) + .build(); + ResultSet getPartitionResultSet = + ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("Parent0")) + .addValues(Value.newBuilder().setListValue(ListValue.newBuilder().build())) + .addValues(Value.newBuilder().setStringValue(startTimestamp.toString())) + .addValues(Value.newBuilder().setStringValue(after3Seconds.toString())) + .addValues(Value.newBuilder().setStringValue("500")) + .addValues(Value.newBuilder().setStringValue(State.CREATED.name())) + .addValues(Value.newBuilder().setStringValue(startTimestamp.toString())) + .addValues(Value.newBuilder().setStringValue(startTimestamp.toString())) + .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()) + 
.addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()) + .addValues(Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build()) + .build()) + .setMetadata(PARTITION_METADATA_RESULT_SET_METADATA) + .build(); + mockSpannerService.putStatementResult( + StatementResult.query(getPartitionStatement, getPartitionResultSet)); + return getPartitionResultSet; + } + private void mockGetDialect() { Statement determineDialectStatement = Statement.newBuilder( diff --git a/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java b/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java index e8b85f63b36a..6570b7fe81b2 100644 --- a/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java +++ b/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java @@ -283,7 +283,8 @@ public static void runUsingSqlTransform(String[] args) throws Exception { // Make an array of pipelines, each pipeline is responsible for running a corresponding query. Pipeline[] pipelines = new Pipeline[queryNames.length]; - CSVFormat csvFormat = CSVFormat.MYSQL.withDelimiter('|').withNullString(""); + CSVFormat csvFormat = + CSVFormat.MYSQL.withDelimiter('|').withTrailingDelimiter().withNullString(""); // Execute all queries, transform each result into a PCollection, write them into // the txt file and store in a GCP directory. 
diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py index a74ccbba041a..b7aa130fbbd8 100644 --- a/sdks/python/apache_beam/dataframe/frames.py +++ b/sdks/python/apache_beam/dataframe/frames.py @@ -1388,7 +1388,7 @@ def align(self, other, join, axis, level, method, **kwargs): Only the default, ``method=None``, is allowed.""" if level is not None: raise NotImplementedError('per-level align') - if method is not None: + if method is not None and method != lib.no_default: raise frame_base.WontImplementError( f"align(method={method!r}) is not supported because it is " "order sensitive. Only align(method=None) is supported.", @@ -2580,7 +2580,7 @@ def align(self, other, join, axis, copy, level, method, **kwargs): "align(copy=False) is not supported because it might be an inplace " "operation depending on the data. Please prefer the default " "align(copy=True).") - if method is not None: + if method is not None and method != lib.no_default: raise frame_base.WontImplementError( f"align(method={method!r}) is not supported because it is " "order sensitive. 
Only align(method=None) is supported.", @@ -2978,6 +2978,8 @@ def aggregate(self, func, axis, *args, **kwargs): agg = aggregate applymap = frame_base._elementwise_method('applymap', base=pd.DataFrame) + if PD_VERSION >= (2, 1): + map = frame_base._elementwise_method('map', base=pd.DataFrame) add_prefix = frame_base._elementwise_method('add_prefix', base=pd.DataFrame) add_suffix = frame_base._elementwise_method('add_suffix', base=pd.DataFrame) @@ -4594,8 +4596,9 @@ def wrapper(self, *args, **kwargs): return _unliftable_agg(meth)(self, *args, **kwargs) to_group = self._ungrouped.proxy().index - is_categorical_grouping = any(to_group.get_level_values(i).is_categorical() - for i in self._grouping_indexes) + is_categorical_grouping = any( + isinstance(to_group.get_level_values(i).dtype, pd.CategoricalDtype) + for i in self._grouping_indexes) groupby_kwargs = self._kwargs group_keys = self._group_keys @@ -4647,8 +4650,9 @@ def wrapper(self, *args, **kwargs): to_group = self._ungrouped.proxy().index group_keys = self._group_keys - is_categorical_grouping = any(to_group.get_level_values(i).is_categorical() - for i in self._grouping_indexes) + is_categorical_grouping = any( + isinstance(to_group.get_level_values(i).dtype, pd.CategoricalDtype) + for i in self._grouping_indexes) groupby_kwargs = self._kwargs project = _maybe_project_func(self._projection) diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py index 6f7a63c29164..6e32acefc61b 100644 --- a/sdks/python/apache_beam/dataframe/frames_test.py +++ b/sdks/python/apache_beam/dataframe/frames_test.py @@ -865,6 +865,8 @@ def test_corrwith_bad_axis(self): self._run_error_test(lambda df: df.corrwith(df, axis=5), df) @unittest.skipIf(PD_VERSION < (1, 2), "na_action added in pandas 1.2.0") + @pytest.mark.filterwarnings( + "ignore:The default of observed=False is deprecated:FutureWarning") def test_applymap_na_action(self): # Replicates a doctest for na_action which is 
incompatible with # doctest framework @@ -875,6 +877,17 @@ def test_applymap_na_action(self): # TODO: generate proxy using naive type inference on fn check_proxy=False) + @unittest.skipIf(PD_VERSION < (2, 1), "map added in 2.1.0") + def test_map_na_action(self): + # Replicates a doctest for na_action which is incompatible with + # doctest framework + df = pd.DataFrame([[pd.NA, 2.12], [3.356, 4.567]]) + self._run_test( + lambda df: df.map(lambda x: len(str(x)), na_action='ignore'), + df, + # TODO: generate proxy using naive type inference on fn + check_proxy=False) + def test_dataframe_eval_query(self): df = pd.DataFrame(np.random.randn(20, 3), columns=['a', 'b', 'c']) self._run_test(lambda df: df.eval('foo = a + b - c'), df) @@ -1021,8 +1034,14 @@ def test_categorical_groupby(self): df = df.set_index('B') # TODO(BEAM-11190): These aggregations can be done in index partitions, but # it will require a little more complex logic - self._run_test(lambda df: df.groupby(level=0).sum(), df, nonparallel=True) - self._run_test(lambda df: df.groupby(level=0).mean(), df, nonparallel=True) + self._run_test( + lambda df: df.groupby(level=0, observed=False).sum(), + df, + nonparallel=True) + self._run_test( + lambda df: df.groupby(level=0, observed=False).mean(), + df, + nonparallel=True) def test_astype_categorical(self): df = pd.DataFrame({'A': np.arange(6), 'B': list('aabbca')}) diff --git a/sdks/python/apache_beam/examples/ml-orchestration/kfp/components/preprocessing/requirements.txt b/sdks/python/apache_beam/examples/ml-orchestration/kfp/components/preprocessing/requirements.txt index e902ead34151..706adf9de0aa 100644 --- a/sdks/python/apache_beam/examples/ml-orchestration/kfp/components/preprocessing/requirements.txt +++ b/sdks/python/apache_beam/examples/ml-orchestration/kfp/components/preprocessing/requirements.txt @@ -18,4 +18,4 @@ requests==2.31.0 torch==1.13.1 torchvision==0.13.0 numpy==1.22.4 -Pillow==9.3.0 +Pillow==10.0.1 diff --git 
a/sdks/python/apache_beam/testing/analyzers/README.md b/sdks/python/apache_beam/testing/analyzers/README.md index 076f173f9d71..91b21076f88a 100644 --- a/sdks/python/apache_beam/testing/analyzers/README.md +++ b/sdks/python/apache_beam/testing/analyzers/README.md @@ -35,16 +35,13 @@ update already created GitHub issue or ignore performance alert by not creating ## Config file structure -The config file defines the structure to run change point analysis on a given test. To add a test to the config file, +The yaml defines the structure to run change point analysis on a given test. To add a test config to the yaml file, please follow the below structure. -**NOTE**: The Change point analysis only supports reading the metric data from Big Query for now. +**NOTE**: Change point analysis currently supports reading the metric data from `BigQuery` only. ``` -# the test_1 must be a unique id. -test_1: - test_description: Pytorch image classification on 50k images of size 224 x 224 with resnet 152 - test_target: apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks +test_1: # a unique id for each test config. metrics_dataset: beam_run_inference metrics_table: torch_inference_imagenet_results_resnet152 project: apache-beam-testing @@ -55,11 +52,17 @@ test_1: num_runs_in_change_point_window: 30 # optional parameter ``` -**NOTE**: `test_target` is optional. It is used for identifying the test that was causing the regression. +#### Optional Parameters: -**Note**: By default, the tool fetches metrics from BigQuery tables. `metrics_dataset`, `metrics_table`, `project` and `metric_name` should match with the values defined for performance/load tests. -The above example uses this [test configuration](https://github.com/apache/beam/blob/0a91d139dea4276dc46176c4cdcdfce210fc50c4/.test-infra/jenkins/job_InferenceBenchmarkTests_Python.groovy#L30) -to fill up the values required to fetch the data from source. 
+These are the optional parameters that can be added to the test config in addition to the parameters mentioned above. + +- `test_target`: Identifies the test responsible for the regression. + +- `test_description`: Provides a brief overview of the test's function. + +- `test_name`: Denotes the name of the test as stored in the BigQuery table. + +**Note**: The tool, by default, pulls metrics from BigQuery tables. Ensure that the values for `metrics_dataset`, `metrics_table`, `project`, and `metric_name` align with those defined for performance/load tests. The provided example utilizes this [test configuration](https://github.com/apache/beam/blob/0a91d139dea4276dc46176c4cdcdfce210fc50c4/.test-infra/jenkins/job_InferenceBenchmarkTests_Python.groovy#L30) to populate the necessary values for data retrieval. ### Different ways to avoid false positive change points @@ -76,8 +79,35 @@ setting `num_runs_in_change_point_window=7` will achieve it. ## Register a test for performance alerts -If a new test needs to be registered for the performance alerting tool, please add the required test parameters to the -config file. +If a new test needs to be registered for the performance alerting tool, + +- You can either add it to the config file that is already present. +- You can define your own yaml file and call the [perf_analysis.run()](https://github.com/apache/beam/blob/a46bc12a256dcaa3ae2cc9e5d6fdcaa82b59738b/sdks/python/apache_beam/testing/analyzers/perf_analysis.py#L152) method. + + +## Integrating the Perf Alert Tool with a Custom BigQuery Schema + +By default, the Perf Alert Tool retrieves metrics from the `apache-beam-testing` BigQuery projects. All performance and load tests within Beam utilize a standard [schema](https://github.com/apache/beam/blob/a7e12db9b5977c4a7b13554605c0300389a3d6da/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py#L70) for metrics publication. 
The tool inherently recognizes and operates with this schema when extracting metrics from BigQuery tables. + +To fetch data from a BigQuery dataset that does not follow Apache Beam's default setup, inherit from the `MetricsFetcher` class and implement the abstract method `fetch_metric_data`. This method should return a tuple of the desired metric values and the timestamps at which those values were published. + +``` +from apache_beam.testing.analyzers import perf_analysis +config_file_path = "path/to/your_config.yaml" +my_metric_fetcher = MyMetricsFetcher() # inherited from MetricsFetcher +perf_analysis.run(big_query_metrics_fetcher=my_metric_fetcher, config_file_path=config_file_path) +``` + +**Note**: The metric values and timestamps should be sorted by timestamp in ascending order. + +### Configuring GitHub Parameters + +Out of the box, the performance alert tool targets the `apache/beam` repository when raising issues. If you wish to utilize this tool for another repository, you'll need to pre-set a couple of environment variables: + +- `REPO_OWNER`: Represents the owner of the repository. (e.g., `apache`) +- `REPO_NAME`: Specifies the repository name itself. (e.g., `beam`) + +Before initiating the tool, also ensure that the `GITHUB_TOKEN` is set to an authenticated GitHub token. This permits the tool to generate GitHub issues whenever performance alerts arise. 
## Triage performance alert issues diff --git a/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py b/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py index e1f20baa50a6..82758be8f180 100644 --- a/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py +++ b/sdks/python/apache_beam/testing/analyzers/github_issues_utils.py @@ -34,8 +34,8 @@ 'A Github Personal Access token is required ' 'to create Github Issues.') -_BEAM_GITHUB_REPO_OWNER = 'apache' -_BEAM_GITHUB_REPO_NAME = 'beam' +_GITHUB_REPO_OWNER = os.environ.get('REPO_OWNER', 'apache') +_GITHUB_REPO_NAME = os.environ.get('REPO_NAME', 'beam') # Adding GitHub Rest API version to the header to maintain version stability. # For more information, please look at # https://github.blog/2022-11-28-to-infinity-and-beyond-enabling-the-future-of-githubs-rest-api-with-api-versioning/ # pylint: disable=line-too-long @@ -77,10 +77,10 @@ def create_issue( Tuple containing GitHub issue number and issue URL. """ url = "https://api.github.com/repos/{}/{}/issues".format( - _BEAM_GITHUB_REPO_OWNER, _BEAM_GITHUB_REPO_NAME) + _GITHUB_REPO_OWNER, _GITHUB_REPO_NAME) data = { - 'owner': _BEAM_GITHUB_REPO_OWNER, - 'repo': _BEAM_GITHUB_REPO_NAME, + 'owner': _GITHUB_REPO_OWNER, + 'repo': _GITHUB_REPO_NAME, 'title': title, 'body': description, 'labels': [_AWAITING_TRIAGE_LABEL, _PERF_ALERT_LABEL] @@ -108,20 +108,20 @@ def comment_on_issue(issue_number: int, issue, and the comment URL. 
""" url = 'https://api.github.com/repos/{}/{}/issues/{}'.format( - _BEAM_GITHUB_REPO_OWNER, _BEAM_GITHUB_REPO_NAME, issue_number) + _GITHUB_REPO_OWNER, _GITHUB_REPO_NAME, issue_number) open_issue_response = requests.get( url, json.dumps({ - 'owner': _BEAM_GITHUB_REPO_OWNER, - 'repo': _BEAM_GITHUB_REPO_NAME, + 'owner': _GITHUB_REPO_OWNER, + 'repo': _GITHUB_REPO_NAME, 'issue_number': issue_number }, default=str), headers=_HEADERS).json() if open_issue_response['state'] == 'open': data = { - 'owner': _BEAM_GITHUB_REPO_OWNER, - 'repo': _BEAM_GITHUB_REPO_NAME, + 'owner': _GITHUB_REPO_OWNER, + 'repo': _GITHUB_REPO_NAME, 'body': comment_description, issue_number: issue_number, } @@ -134,13 +134,14 @@ def comment_on_issue(issue_number: int, def add_awaiting_triage_label(issue_number: int): url = 'https://api.github.com/repos/{}/{}/issues/{}/labels'.format( - _BEAM_GITHUB_REPO_OWNER, _BEAM_GITHUB_REPO_NAME, issue_number) + _GITHUB_REPO_OWNER, _GITHUB_REPO_NAME, issue_number) requests.post( url, json.dumps({'labels': [_AWAITING_TRIAGE_LABEL]}), headers=_HEADERS) def get_issue_description( - test_name: str, + test_id: str, + test_name: Optional[str], metric_name: str, timestamps: List[pd.Timestamp], metric_values: List, @@ -167,10 +168,13 @@ def get_issue_description( description = [] - description.append(_ISSUE_DESCRIPTION_TEMPLATE.format(test_name, metric_name)) + description.append(_ISSUE_DESCRIPTION_TEMPLATE.format(test_id, metric_name)) - description.append(("`Test description:` " + - f'{test_description}') if test_description else '') + if test_name: + description.append(("`test_name:` " + f'{test_name}')) + + if test_description: + description.append(("`Test description:` " + f'{test_description}')) description.append('```') diff --git a/sdks/python/apache_beam/testing/analyzers/perf_analysis.py b/sdks/python/apache_beam/testing/analyzers/perf_analysis.py index 7f1ffbb944e9..c86ecb2c4e20 100644 --- a/sdks/python/apache_beam/testing/analyzers/perf_analysis.py +++ 
b/sdks/python/apache_beam/testing/analyzers/perf_analysis.py @@ -23,7 +23,6 @@ import argparse import logging import os -import uuid from datetime import datetime from datetime import timezone from typing import Any @@ -33,9 +32,10 @@ import pandas as pd from apache_beam.testing.analyzers import constants +from apache_beam.testing.analyzers.perf_analysis_utils import BigQueryMetricsFetcher from apache_beam.testing.analyzers.perf_analysis_utils import GitHubIssueMetaData +from apache_beam.testing.analyzers.perf_analysis_utils import MetricsFetcher from apache_beam.testing.analyzers.perf_analysis_utils import create_performance_alert -from apache_beam.testing.analyzers.perf_analysis_utils import fetch_metric_data from apache_beam.testing.analyzers.perf_analysis_utils import find_latest_change_point_index from apache_beam.testing.analyzers.perf_analysis_utils import get_existing_issues_data from apache_beam.testing.analyzers.perf_analysis_utils import is_change_point_in_valid_window @@ -43,10 +43,10 @@ from apache_beam.testing.analyzers.perf_analysis_utils import publish_issue_metadata_to_big_query from apache_beam.testing.analyzers.perf_analysis_utils import read_test_config from apache_beam.testing.analyzers.perf_analysis_utils import validate_config -from apache_beam.testing.load_tests.load_test_metrics_utils import BigQueryMetricsFetcher -def run_change_point_analysis(params, test_name, big_query_metrics_fetcher): +def run_change_point_analysis( + params, test_id, big_query_metrics_fetcher: MetricsFetcher): """ Args: params: Dict containing parameters to run change point analysis. @@ -56,14 +56,21 @@ def run_change_point_analysis(params, test_name, big_query_metrics_fetcher): Returns: bool indicating if a change point is observed and alerted on GitHub. 
""" - logging.info("Running change point analysis for test %s" % test_name) + logging.info("Running change point analysis for test ID %s" % test_id) if not validate_config(params.keys()): raise ValueError( f"Please make sure all these keys {constants._PERF_TEST_KEYS} " - f"are specified for the {test_name}") + f"are specified for the {test_id}") metric_name = params['metric_name'] + # test_name will be used to query a single test from + # multiple tests in a single BQ table. Right now, the default + # assumption is that all the test have an individual BQ table + # but this might not be case for other tests(such as IO tests where + # a single BQ tables stores all the data) + test_name = params.get('test_name', None) + min_runs_between_change_points = ( constants._DEFAULT_MIN_RUNS_BETWEEN_CHANGE_POINTS) if 'min_runs_between_change_points' in params: @@ -74,15 +81,18 @@ def run_change_point_analysis(params, test_name, big_query_metrics_fetcher): if 'num_runs_in_change_point_window' in params: num_runs_in_change_point_window = params['num_runs_in_change_point_window'] - metric_values, timestamps = fetch_metric_data( - params=params, - big_query_metrics_fetcher=big_query_metrics_fetcher + metric_values, timestamps = big_query_metrics_fetcher.fetch_metric_data( + project=params['project'], + metrics_dataset=params['metrics_dataset'], + metrics_table=params['metrics_table'], + metric_name=params['metric_name'], + test_name=test_name ) change_point_index = find_latest_change_point_index( metric_values=metric_values) if not change_point_index: - logging.info("Change point is not detected for the test %s" % test_name) + logging.info("Change point is not detected for the test ID %s" % test_id) return False # since timestamps are ordered in ascending order and # num_runs_in_change_point_window refers to the latest runs, @@ -92,11 +102,11 @@ def run_change_point_analysis(params, test_name, big_query_metrics_fetcher): if not 
is_change_point_in_valid_window(num_runs_in_change_point_window, latest_change_point_run): logging.info( - 'Performance regression/improvement found for the test: %s. ' + 'Performance regression/improvement found for the test ID: %s. ' 'on metric %s. Since the change point run %s ' 'lies outside the num_runs_in_change_point_window distance: %s, ' 'alert is not raised.' % ( - test_name, + test_id, metric_name, latest_change_point_run + 1, num_runs_in_change_point_window)) @@ -106,8 +116,7 @@ def run_change_point_analysis(params, test_name, big_query_metrics_fetcher): last_reported_issue_number = None issue_metadata_table_name = f'{params.get("metrics_table")}_{metric_name}' existing_issue_data = get_existing_issues_data( - table_name=issue_metadata_table_name, - big_query_metrics_fetcher=big_query_metrics_fetcher) + table_name=issue_metadata_table_name) if existing_issue_data is not None: existing_issue_timestamps = existing_issue_data[ @@ -124,20 +133,21 @@ def run_change_point_analysis(params, test_name, big_query_metrics_fetcher): min_runs_between_change_points=min_runs_between_change_points) if is_alert: issue_number, issue_url = create_performance_alert( - metric_name, test_name, timestamps, + metric_name, test_id, timestamps, metric_values, change_point_index, params.get('labels', None), last_reported_issue_number, test_description = params.get('test_description', None), + test_name = test_name ) issue_metadata = GitHubIssueMetaData( issue_timestamp=pd.Timestamp( datetime.now().replace(tzinfo=timezone.utc)), # BQ doesn't allow '.' 
in table name - test_name=test_name.replace('.', '_'), + test_id=test_id.replace('.', '_'), + test_name=test_name, metric_name=metric_name, - test_id=uuid.uuid4().hex, change_point=metric_values[change_point_index], issue_number=issue_number, issue_url=issue_url, @@ -149,7 +159,10 @@ def run_change_point_analysis(params, test_name, big_query_metrics_fetcher): return is_alert -def run(config_file_path: Optional[str] = None) -> None: +def run( + big_query_metrics_fetcher: MetricsFetcher = BigQueryMetricsFetcher(), + config_file_path: Optional[str] = None, +) -> None: """ run is the entry point to run change point analysis on test metric data, which is read from config file, and if there is a performance @@ -169,12 +182,10 @@ def run(config_file_path: Optional[str] = None) -> None: tests_config: Dict[str, Dict[str, Any]] = read_test_config(config_file_path) - big_query_metrics_fetcher = BigQueryMetricsFetcher() - - for test_name, params in tests_config.items(): + for test_id, params in tests_config.items(): run_change_point_analysis( params=params, - test_name=test_name, + test_id=test_id, big_query_metrics_fetcher=big_query_metrics_fetcher) diff --git a/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py b/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py index 094cd9c47ec0..9c7921300d9d 100644 --- a/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py +++ b/sdks/python/apache_beam/testing/analyzers/perf_analysis_test.py @@ -32,6 +32,7 @@ from apache_beam.io.filesystems import FileSystems from apache_beam.testing.analyzers import constants from apache_beam.testing.analyzers import github_issues_utils + from apache_beam.testing.analyzers.perf_analysis_utils import BigQueryMetricsFetcher from apache_beam.testing.analyzers.perf_analysis_utils import is_change_point_in_valid_window from apache_beam.testing.analyzers.perf_analysis_utils import is_perf_alert from apache_beam.testing.analyzers.perf_analysis_utils import e_divisive @@ 
-41,18 +42,18 @@ from apache_beam.testing.analyzers.perf_analysis_utils import validate_config from apache_beam.testing.load_tests import load_test_metrics_utils except ImportError as e: - analysis = None # type: ignore + raise unittest.SkipTest('Missing dependencies to run perf analysis tests.') # mock methods. -def get_fake_data_with_no_change_point(**kwargs): +def get_fake_data_with_no_change_point(*args, **kwargs): num_samples = 20 metric_values = [1] * num_samples timestamps = list(range(num_samples)) return metric_values, timestamps -def get_fake_data_with_change_point(**kwargs): +def get_fake_data_with_change_point(*args, **kwargs): # change point will be at index 13. num_samples = 20 metric_values = [0] * 12 + [3] + [4] * 7 @@ -69,10 +70,6 @@ def get_existing_issue_data(**kwargs): }]) -@unittest.skipIf( - analysis is None, - 'Missing dependencies. ' - 'Test dependencies are missing for the Analyzer.') class TestChangePointAnalysis(unittest.TestCase): def setUp(self) -> None: self.single_change_point_series = [0] * 10 + [1] * 10 @@ -151,18 +148,20 @@ def test_duplicate_change_points_are_not_valid_alerts(self): min_runs_between_change_points=min_runs_between_change_points) self.assertFalse(is_alert) - @mock.patch( - 'apache_beam.testing.analyzers.perf_analysis.fetch_metric_data', + @mock.patch.object( + BigQueryMetricsFetcher, + 'fetch_metric_data', get_fake_data_with_no_change_point) def test_no_alerts_when_no_change_points(self): is_alert = analysis.run_change_point_analysis( params=self.params, - test_name=self.test_id, - big_query_metrics_fetcher=None) + test_id=self.test_id, + big_query_metrics_fetcher=BigQueryMetricsFetcher()) self.assertFalse(is_alert) - @mock.patch( - 'apache_beam.testing.analyzers.perf_analysis.fetch_metric_data', + @mock.patch.object( + BigQueryMetricsFetcher, + 'fetch_metric_data', get_fake_data_with_change_point) @mock.patch( 'apache_beam.testing.analyzers.perf_analysis.get_existing_issues_data', @@ -178,12 +177,13 @@ def 
test_no_alerts_when_no_change_points(self): def test_alert_on_data_with_change_point(self, *args): is_alert = analysis.run_change_point_analysis( params=self.params, - test_name=self.test_id, - big_query_metrics_fetcher=None) + test_id=self.test_id, + big_query_metrics_fetcher=BigQueryMetricsFetcher()) self.assertTrue(is_alert) - @mock.patch( - 'apache_beam.testing.analyzers.perf_analysis.fetch_metric_data', + @mock.patch.object( + BigQueryMetricsFetcher, + 'fetch_metric_data', get_fake_data_with_change_point) @mock.patch( 'apache_beam.testing.analyzers.perf_analysis.get_existing_issues_data', @@ -198,8 +198,8 @@ def test_alert_on_data_with_change_point(self, *args): def test_alert_on_data_with_reported_change_point(self, *args): is_alert = analysis.run_change_point_analysis( params=self.params, - test_name=self.test_id, - big_query_metrics_fetcher=None) + test_id=self.test_id, + big_query_metrics_fetcher=BigQueryMetricsFetcher()) self.assertFalse(is_alert) def test_change_point_has_anomaly_marker_in_gh_description(self): @@ -208,7 +208,8 @@ def test_change_point_has_anomaly_marker_in_gh_description(self): change_point_index = find_latest_change_point_index(metric_values) description = github_issues_utils.get_issue_description( - test_name=self.test_id, + test_id=self.test_id, + test_name=self.params.get('test_name', None), test_description=self.params['test_description'], metric_name=self.params['metric_name'], metric_values=metric_values, diff --git a/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py b/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py index 0a559fc4beeb..f9604c490fc0 100644 --- a/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py +++ b/sdks/python/apache_beam/testing/analyzers/perf_analysis_utils.py @@ -14,11 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import abc import logging from dataclasses import asdict from dataclasses import dataclass from statistics import median -from typing import Any from typing import Dict from typing import List from typing import Optional @@ -28,11 +28,11 @@ import pandas as pd import yaml from google.api_core import exceptions +from google.cloud import bigquery from apache_beam.testing.analyzers import constants from apache_beam.testing.analyzers import github_issues_utils from apache_beam.testing.load_tests import load_test_metrics_utils -from apache_beam.testing.load_tests.load_test_metrics_utils import BigQueryMetricsFetcher from apache_beam.testing.load_tests.load_test_metrics_utils import BigQueryMetricsPublisher from signal_processing_algorithms.energy_statistics.energy_statistics import e_divisive @@ -59,9 +59,7 @@ def is_change_point_in_valid_window( return num_runs_in_change_point_window > latest_change_point_run -def get_existing_issues_data( - table_name: str, big_query_metrics_fetcher: BigQueryMetricsFetcher -) -> Optional[pd.DataFrame]: +def get_existing_issues_data(table_name: str) -> Optional[pd.DataFrame]: """ Finds the most recent GitHub issue created for the test_name. If no table found with name=test_name, return (None, None) @@ -73,12 +71,14 @@ def get_existing_issues_data( LIMIT 10 """ try: - df = big_query_metrics_fetcher.fetch(query=query) + client = bigquery.Client() + query_job = client.query(query=query) + existing_issue_data = query_job.result().to_dataframe() except exceptions.NotFound: # If no table found, that means this is first performance regression # on the current test+metric. 
return None - return df + return existing_issue_data def is_perf_alert( @@ -123,33 +123,6 @@ def validate_config(keys): return constants._PERF_TEST_KEYS.issubset(keys) -def fetch_metric_data( - params: Dict[str, Any], big_query_metrics_fetcher: BigQueryMetricsFetcher -) -> Tuple[List[Union[int, float]], List[pd.Timestamp]]: - """ - Args: - params: Dict containing keys required to fetch data from a data source. - big_query_metrics_fetcher: A BigQuery metrics fetcher for fetch metrics. - Returns: - Tuple[List[Union[int, float]], List[pd.Timestamp]]: Tuple containing list - of metric_values and list of timestamps. Both are sorted in ascending - order wrt timestamps. - """ - query = f""" - SELECT * - FROM {params['project']}.{params['metrics_dataset']}.{params['metrics_table']} - WHERE CONTAINS_SUBSTR(({load_test_metrics_utils.METRICS_TYPE_LABEL}), '{params['metric_name']}') - ORDER BY {load_test_metrics_utils.SUBMIT_TIMESTAMP_LABEL} DESC - LIMIT {constants._NUM_DATA_POINTS_TO_RUN_CHANGE_POINT_ANALYSIS} - """ - metric_data: pd.DataFrame = big_query_metrics_fetcher.fetch(query=query) - metric_data.sort_values( - by=[load_test_metrics_utils.SUBMIT_TIMESTAMP_LABEL], inplace=True) - return ( - metric_data[load_test_metrics_utils.VALUE_LABEL].tolist(), - metric_data[load_test_metrics_utils.SUBMIT_TIMESTAMP_LABEL].tolist()) - - def find_change_points(metric_values: List[Union[float, int]]): return e_divisive(metric_values) @@ -175,7 +148,7 @@ def find_latest_change_point_index(metric_values: List[Union[float, int]]): def publish_issue_metadata_to_big_query(issue_metadata, table_name): """ - Published issue_metadata to BigQuery with table name=test_name. + Published issue_metadata to BigQuery with table name. 
""" bq_metrics_publisher = BigQueryMetricsPublisher( project_name=constants._BQ_PROJECT_NAME, @@ -190,18 +163,21 @@ def publish_issue_metadata_to_big_query(issue_metadata, table_name): def create_performance_alert( metric_name: str, - test_name: str, + test_id: str, timestamps: List[pd.Timestamp], metric_values: List[Union[int, float]], change_point_index: int, labels: List[str], existing_issue_number: Optional[int], - test_description: Optional[str] = None) -> Tuple[int, str]: + test_description: Optional[str] = None, + test_name: Optional[str] = None, +) -> Tuple[int, str]: """ Creates performance alert on GitHub issues and returns GitHub issue number and issue URL. """ description = github_issues_utils.get_issue_description( + test_id=test_id, test_name=test_name, test_description=test_description, metric_name=metric_name, @@ -213,7 +189,7 @@ def create_performance_alert( issue_number, issue_url = github_issues_utils.report_change_point_on_issues( title=github_issues_utils._ISSUE_TITLE_TEMPLATE.format( - test_name, metric_name + test_id, metric_name ), description=description, labels=labels, @@ -253,3 +229,55 @@ def filter_change_points_by_median_threshold( if relative_change > threshold: valid_change_points.append(idx) return valid_change_points + + +class MetricsFetcher(metaclass=abc.ABCMeta): + @abc.abstractmethod + def fetch_metric_data( + self, + *, + project, + metrics_dataset, + metrics_table, + metric_name, + test_name=None) -> Tuple[List[Union[int, float]], List[pd.Timestamp]]: + """ + Define SQL query and fetch the timestamp values and metric values + from BigQuery tables. + """ + raise NotImplementedError + + +class BigQueryMetricsFetcher(MetricsFetcher): + def fetch_metric_data( + self, + *, + project, + metrics_dataset, + metrics_table, + metric_name, + test_name=None, + ) -> Tuple[List[Union[int, float]], List[pd.Timestamp]]: + """ + Args: + params: Dict containing keys required to fetch data from a data source. 
+ Returns: + Tuple[List[Union[int, float]], List[pd.Timestamp]]: Tuple containing list + of metric_values and list of timestamps. Both are sorted in ascending + order wrt timestamps. + """ + query = f""" + SELECT * + FROM {project}.{metrics_dataset}.{metrics_table} + WHERE CONTAINS_SUBSTR(({load_test_metrics_utils.METRICS_TYPE_LABEL}), '{metric_name}') + ORDER BY {load_test_metrics_utils.SUBMIT_TIMESTAMP_LABEL} DESC + LIMIT {constants._NUM_DATA_POINTS_TO_RUN_CHANGE_POINT_ANALYSIS} + """ + client = bigquery.Client() + query_job = client.query(query=query) + metric_data = query_job.result().to_dataframe() + metric_data.sort_values( + by=[load_test_metrics_utils.SUBMIT_TIMESTAMP_LABEL], inplace=True) + return ( + metric_data[load_test_metrics_utils.VALUE_LABEL].tolist(), + metric_data[load_test_metrics_utils.SUBMIT_TIMESTAMP_LABEL].tolist()) diff --git a/sdks/python/apache_beam/testing/analyzers/tests_config.yaml b/sdks/python/apache_beam/testing/analyzers/tests_config.yaml index f808f5e41d74..ec9cfe6f1ac0 100644 --- a/sdks/python/apache_beam/testing/analyzers/tests_config.yaml +++ b/sdks/python/apache_beam/testing/analyzers/tests_config.yaml @@ -16,7 +16,7 @@ # # for the unique key to define a test, please use the following format: -# {test_name}-{metric_name} +# {test_id}-{metric_name} pytorch_image_classification_benchmarks-resnet152-mean_inference_batch_latency_micro_secs: test_description: diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index 92a5f68351fe..01db2c114efb 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -38,7 +38,6 @@ from typing import Optional from typing import Union -import pandas as pd import requests from requests.auth import HTTPBasicAuth @@ -650,13 +649,3 @@ def __init__(self): def process(self, element): yield 
self.timestamp_val_fn( element, self.timestamp_fn(micros=int(self.time_fn() * 1000000))) - - -class BigQueryMetricsFetcher: - def __init__(self): - self.client = bigquery.Client() - - def fetch(self, query) -> pd.DataFrame: - query_job = self.client.query(query=query) - result = query_job.result() - return result.to_dataframe() diff --git a/sdks/python/scripts/run_integration_test.sh b/sdks/python/scripts/run_integration_test.sh index 6ad592080ae2..5ac3627a0960 100755 --- a/sdks/python/scripts/run_integration_test.sh +++ b/sdks/python/scripts/run_integration_test.sh @@ -79,6 +79,7 @@ SUITE="" COLLECT_MARKERS= REQUIREMENTS_FILE="" ARCH="" +PY_VERSION="" # Default test (pytest) options. # Run WordCountIT.test_wordcount_it by default if no test options are @@ -169,6 +170,11 @@ case $key in shift # past argument shift # past value ;; + --py_version) + PY_VERSION="$2" + shift # past argument + shift # past value + ;; *) # unknown option echo "Unknown option: $1" exit 1 @@ -242,6 +248,9 @@ if [[ -z $PIPELINE_OPTS ]]; then if [[ "$ARCH" == "ARM" ]]; then opts+=("--machine_type=t2a-standard-1") + + IMAGE_NAME="beam_python${PY_VERSION}_sdk" + opts+=("--sdk_container_image=us.gcr.io/$PROJECT/$USER/$IMAGE_NAME:$MULTIARCH_TAG") fi if [[ ! 
-z "$KMS_KEY_NAME" ]]; then diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 7766cf3a377c..a713b82400e7 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -144,7 +144,9 @@ task postCommitIT { } task postCommitArmIT { + def pyversion = "${project.ext.pythonVersion.replace('.', '')}" dependsOn 'initializeForDataflowJob' + dependsOn ":sdks:python:container:py${pyversion}:docker" doLast { def testOpts = basicPytestOpts + ["--numprocesses=8", "--dist=loadfile"] @@ -153,6 +155,7 @@ task postCommitArmIT { "sdk_location": project.ext.sdkLocation, "suite": "postCommitIT-df${pythonVersionSuffix}", "collect": "it_postcommit", + "py_version": project.ext.pythonVersion, "arch": "ARM" ] def cmdArgs = mapToArgString(argMap) diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle index 208a1d9d39ca..bc4aa99c79b4 100644 --- a/sdks/python/test-suites/tox/py38/build.gradle +++ b/sdks/python/test-suites/tox/py38/build.gradle @@ -153,6 +153,8 @@ task archiveFilesToLint(type: Zip) { include "**/*.md" include "**/build.gradle" include '**/build.gradle.kts' + exclude '**/build/**' // intermediate build directory + exclude 'website/www/site/themes/docsy/**' // fork to google/docsy exclude "**/node_modules/*" exclude "**/.gogradle/*" } diff --git a/website/www/site/content/en/blog/beamquest.md b/website/www/site/content/en/blog/beamquest.md index eea893bf8227..dde6376b4077 100644 --- a/website/www/site/content/en/blog/beamquest.md +++ b/website/www/site/content/en/blog/beamquest.md @@ -34,6 +34,6 @@ Individuals aren’t the only ones who can benefit from completing this quest - Data Processing is a key part of AI/ML workflows. Given the recent advancements in artificial intelligence, now’s the time to jump into the world of data processing! 
Get started on your journey [here](https://www.cloudskillsboost.google/quests/310). -We are currently offering this quest **FREE OF CHARGE** until **July 8, 2023** for the **first 2,000** people. To obtain your badge for **FREE**, use the [Access Code](https://www.cloudskillsboost.google/catalog?qlcampaign=1h-swiss-19), create an account, and search ["Getting Started with Apache Beam"](https://www.cloudskillsboost.google/quests/310). +We are currently offering this quest **FREE OF CHARGE**. To obtain your badge for **FREE**, use the [Access Code](https://www.cloudskillsboost.google/catalog?qlcampaign=1h-swiss-19), create an account, and search ["Getting Started with Apache Beam"](https://www.cloudskillsboost.google/quests/310). If the code does not work, please email [dev@beam.apache.org](mailto:dev@beam.apache.org) to obtain a free code. PS: Once you earn your badge, please [share it on social media](https://support.google.com/qwiklabs/answer/9222527?hl=en&sjid=14905615709060962899-NA)!