{
}
project.ext.applyDockerRunNature = {
- project.apply plugin: "com.palantir.docker-run"
+ project.apply plugin: BeamDockerRunPlugin
}
/** ***********************************************************************************************/
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy
index 061ccf27cce2..97d96e6cf1eb 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/VendorJavaPlugin.groovy
@@ -126,7 +126,7 @@ artifactId=${project.name}
}
config.exclusions.each { exclude it }
- classifier = null
+ archiveClassifier = null
mergeServiceFiles()
zip64 true
exclude "META-INF/INDEX.LIST"
diff --git a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb
index 67fe51af1253..9cbab0a14178 100644
--- a/examples/notebooks/beam-ml/automatic_model_refresh.ipynb
+++ b/examples/notebooks/beam-ml/automatic_model_refresh.ipynb
@@ -1,605 +1,530 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "cells": [{
- "cell_type": "code",
- "source": [
- "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
- "\n",
- "# Licensed to the Apache Software Foundation (ASF) under one\n",
- "# or more contributor license agreements. See the NOTICE file\n",
- "# distributed with this work for additional information\n",
- "# regarding copyright ownership. The ASF licenses this file\n",
- "# to you under the Apache License, Version 2.0 (the\n",
- "# \"License\"); you may not use this file except in compliance\n",
- "# with the License. You may obtain a copy of the License at\n",
- "#\n",
- "# http://www.apache.org/licenses/LICENSE-2.0\n",
- "#\n",
- "# Unless required by applicable law or agreed to in writing,\n",
- "# software distributed under the License is distributed on an\n",
- "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
- "# KIND, either express or implied. See the License for the\n",
- "# specific language governing permissions and limitations\n",
- "# under the License"
- ],
- "metadata": {
- "cellView": "form",
- "id": "OsFaZscKSPvo"
- },
- "execution_count": null,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Update ML models in running pipelines\n",
- "\n",
- "\n"
- ],
- "metadata": {
- "id": "ZUSiAR62SgO8"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline.\n",
- "You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a `ModelHandler` configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the `WatchFilePattern`, or by configuring a custom side input `PCollection` that defines the logic for the model update.\n",
- "\n",
- "The pipeline in this notebook uses a RunInference `PTransform` with TensorFlow machine learning (ML) models to run inference on images. To update the model, it uses a side input `PCollection` that emits `ModelMetadata`.\n",
- "For more information about side inputs, see the [Side inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs) section in the Apache Beam Programming Guide.\n",
- "\n",
- "This example uses `WatchFilePattern` as a side input. `WatchFilePattern` is used to watch for file updates that match the `file_pattern` based on timestamps. It emits the latest `ModelMetadata`, which is used in the RunInference `PTransform` to automatically update the ML model without stopping the Apache Beam pipeline.\n"
- ],
- "metadata": {
- "id": "tBtqF5UpKJNZ"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Before you begin\n",
- "Install the dependencies required to run this notebook.\n",
- "\n",
- "To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later."
- ],
- "metadata": {
- "id": "SPuXFowiTpWx"
- }
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "id": "1RyTYsFEIOlA",
- "outputId": "0e6b88a7-82d8-4d94-951c-046a9b8b7abb",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }],
- "source": [
- "!pip install apache_beam[gcp]>=2.46.0 --quiet\n",
- "!pip install tensorflow\n",
- "!pip install tensorflow_hub"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "# Imports required for the notebook.\n",
- "import logging\n",
- "import time\n",
- "from typing import Iterable\n",
- "from typing import Tuple\n",
- "\n",
- "import apache_beam as beam\n",
- "from apache_beam.examples.inference.tensorflow_imagenet_segmentation import PostProcessor\n",
- "from apache_beam.examples.inference.tensorflow_imagenet_segmentation import read_image\n",
- "from apache_beam.ml.inference.base import PredictionResult\n",
- "from apache_beam.ml.inference.base import RunInference\n",
- "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n",
- "from apache_beam.ml.inference.utils import WatchFilePattern\n",
- "from apache_beam.options.pipeline_options import GoogleCloudOptions\n",
- "from apache_beam.options.pipeline_options import PipelineOptions\n",
- "from apache_beam.options.pipeline_options import SetupOptions\n",
- "from apache_beam.options.pipeline_options import StandardOptions\n",
- "from apache_beam.transforms.periodicsequence import PeriodicImpulse\n",
- "import numpy\n",
- "from PIL import Image\n",
- "import tensorflow as tf"
- ],
- "metadata": {
- "id": "Rs4cwwNrIV9H"
- },
- "execution_count": 2,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "code",
- "source": [
- "# Authenticate to your Google Cloud account.\n",
- "from google.colab import auth\n",
- "auth.authenticate_user()"
- ],
- "metadata": {
- "id": "jAKpPcmmGm03"
- },
- "execution_count": 3,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Configure the runner\n",
- "\n",
- "This pipeline uses the Dataflow Runner. To run the pipeline, you need to complete the following tasks:\n",
- "\n",
- "* Ensure that you have all the required permissions to run the pipeline on Dataflow.\n",
- "* Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.\n",
- "\n",
- "In the following code, replace `BUCKET_NAME` with the the name of your Cloud Storage bucket."
- ],
- "metadata": {
- "id": "ORYNKhH3WQyP"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "options = PipelineOptions()\n",
- "options.view_as(StandardOptions).streaming = True\n",
- "\n",
- "# Provide required pipeline options for the Dataflow Runner.\n",
- "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n",
- "\n",
- "# Set the project to the default project in your current Google Cloud environment.\n",
- "options.view_as(GoogleCloudOptions).project = 'your-project'\n",
- "\n",
- "# Set the Google Cloud region that you want to run Dataflow in.\n",
- "options.view_as(GoogleCloudOptions).region = 'us-central1'\n",
- "\n",
- "# IMPORTANT: Replace BUCKET_NAME with the the name of your Cloud Storage bucket.\n",
- "dataflow_gcs_location = \"gs://BUCKET_NAME/tmp/\"\n",
- "\n",
- "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n",
- "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n",
- "\n",
- "# The Dataflow temp location. This location is used to store temporary files or intermediate results before outputting to the sink.\n",
- "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n",
- "\n"
- ],
- "metadata": {
- "id": "wWjbnq6X-4uE"
- },
- "execution_count": 4,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies."
- ],
- "metadata": {
- "id": "HTJV8pO2Wcw4"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# In a requirements file, define the dependencies required for the pipeline.\n",
- "deps_required_for_pipeline = ['tensorflow>=2.12.0', 'tensorflow-hub>=0.10.0', 'Pillow>=9.0.0']\n",
- "requirements_file_path = './requirements.txt'\n",
- "# Write the dependencies to the requirements file.\n",
- "with open(requirements_file_path, 'w') as f:\n",
- " for dep in deps_required_for_pipeline:\n",
- " f.write(dep + '\\n')\n",
- "\n",
- "# Install the pipeline dependencies on Dataflow.\n",
- "options.view_as(SetupOptions).requirements_file = requirements_file_path"
- ],
- "metadata": {
- "id": "lEy4PkluWbdm"
- },
- "execution_count": 5,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Use the TensorFlow model handler\n",
- " This example uses `TFModelHandlerTensor` as the model handler and the `resnet_101` model trained on [ImageNet](https://www.image-net.org/).\n",
- "\n",
- " Download the model from [Google Cloud Storage](https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels.h5) (link downloads the model), and place it in the directory that you want to use to update your model.\n",
- "\n",
- "In the following code, replace `BUCKET_NAME` with the the name of your Cloud Storage bucket."
- ],
- "metadata": {
- "id": "_AUNH_GJk_NE"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "model_handler = TFModelHandlerTensor(\n",
- " model_uri=\"gs://BUCKET_NAME/resnet101_weights_tf_dim_ordering_tf_kernels.h5\")"
- ],
- "metadata": {
- "id": "kkSnsxwUk-Sp"
- },
- "execution_count": 6,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Preprocess images\n",
- "\n",
- "Use `preprocess_image` to run the inference, read the image, and convert the image to a TensorFlow tensor."
- ],
- "metadata": {
- "id": "tZH0r0sL-if5"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "def preprocess_image(image_name, image_dir):\n",
- " img = tf.keras.utils.get_file(image_name, image_dir + image_name)\n",
- " img = Image.open(img).resize((224, 224))\n",
- " img = numpy.array(img) / 255.0\n",
- " img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)\n",
- " return img_tensor"
- ],
- "metadata": {
- "id": "dU5imgTt-8Ne"
- },
- "execution_count": 7,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "code",
- "source": [
- "class PostProcessor(beam.DoFn):\n",
- " \"\"\"Process the PredictionResult to get the predicted label.\n",
- " Returns predicted label.\n",
- " \"\"\"\n",
- " def process(self, element: PredictionResult) -> Iterable[Tuple[str, str]]:\n",
- " predicted_class = numpy.argmax(element.inference, axis=-1)\n",
- " labels_path = tf.keras.utils.get_file(\n",
- " 'ImageNetLabels.txt',\n",
- " 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt' # pylint: disable=line-too-long\n",
- " )\n",
- " imagenet_labels = numpy.array(open(labels_path).read().splitlines())\n",
- " predicted_class_name = imagenet_labels[predicted_class]\n",
- " yield predicted_class_name.title(), element.model_id"
- ],
- "metadata": {
- "id": "6V5tJxO6-gyt"
- },
- "execution_count": 8,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "code",
- "source": [
- "# Define the pipeline object.\n",
- "pipeline = beam.Pipeline(options=options)"
- ],
- "metadata": {
- "id": "GpdKk72O_NXT",
- "outputId": "bcbaa8a6-0408-427a-de9e-78a6a7eefd7b",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 400
- }
- },
- "execution_count": 9,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Next, review the pipeline steps and examine the code.\n",
- "\n",
- "### Pipeline steps\n"
- ],
- "metadata": {
- "id": "elZ53uxc_9Hv"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "1. Create a `PeriodicImpulse` transform, which emits output every `n` seconds. The `PeriodicImpulse` transform generates an infinite sequence of elements with a given runtime interval.\n",
- "\n",
- " In this example, `PeriodicImpulse` mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use `PeriodicImpulse` to output elements at `m` intervals.\n",
- "To learn more about `PeriodicImpulse`, see the [`PeriodicImpulse` code](https://github.com/apache/beam/blob/9c52e0594d6f0e59cd17ee005acfb41da508e0d5/sdks/python/apache_beam/transforms/periodicsequence.py#L150)."
- ],
- "metadata": {
- "id": "305tkV2sAD-S"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "start_timestamp = time.time() # start timestamp of the periodic impulse\n",
- "end_timestamp = start_timestamp + 60 * 20 # end timestamp of the periodic impulse (will run for 20 minutes).\n",
- "main_input_fire_interval = 60 # interval in seconds at which the main input PCollection is emitted.\n",
- "side_input_fire_interval = 60 # interval in seconds at which the side input PCollection is emitted.\n",
- "\n",
- "periodic_impulse = (\n",
- " pipeline\n",
- " | \"MainInputPcoll\" >> PeriodicImpulse(\n",
- " start_timestamp=start_timestamp,\n",
- " stop_timestamp=end_timestamp,\n",
- " fire_interval=main_input_fire_interval))"
- ],
- "metadata": {
- "id": "vUFStz66_Tbb",
- "outputId": "39f2704b-021e-4d41-fce3-a2fac90a5bad",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 133
- }
- },
- "execution_count": 10,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "2. To read and preprocess the images, use the `read_image` function. This example uses `Cat-with-beanie.jpg` for all inferences.\n",
- "\n",
- " **Note**: Image used for prediction is licensed in CC-BY. The creator is listed in the [LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt) file."
- ],
- "metadata": {
- "id": "8-sal2rFAxP2"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "![download.png]()"
- ],
- "metadata": {
- "id": "gW4cE8bhXS-d"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "image_data = (periodic_impulse | beam.Map(lambda x: \"Cat-with-beanie.jpg\")\n",
- " | \"ReadImage\" >> beam.Map(lambda image_name: read_image(\n",
- " image_name=image_name, image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))"
- ],
- "metadata": {
- "id": "dGg11TpV_aV6",
- "outputId": "a57e8197-6756-4fd8-a664-f51ef2fea730",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 204
- }
- },
- "execution_count": 11,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "3. Pass the images to the RunInference `PTransform`. RunInference takes `model_handler` and `model_metadata_pcoll` as input parameters.\n",
- " * `model_metadata_pcoll` is a side input `PCollection` to the RunInference `PTransform`. This side input is used to update the `model_uri` in the `model_handler` without needing to stop the Apache Beam pipeline\n",
- " * Use `WatchFilePattern` as side input to watch a `file_pattern` matching `.h5` files. In this case, the `file_pattern` is `'gs://BUCKET_NAME/*.h5'`.\n",
- "\n"
- ],
- "metadata": {
- "id": "eB0-ewd-BCKE"
- }
- },
- {
- "cell_type": "code",
- "source": [
- " # The side input used to watch for the .h5 file and update the model_uri of the TFModelHandlerTensor.\n",
- "file_pattern = 'gs://BUCKET_NAME/*.h5'\n",
- "side_input_pcoll = (\n",
- " pipeline\n",
- " | \"WatchFilePattern\" >> WatchFilePattern(file_pattern=file_pattern,\n",
- " interval=side_input_fire_interval,\n",
- " stop_timestamp=end_timestamp))\n",
- "inferences = (\n",
- " image_data\n",
- " | \"ApplyWindowing\" >> beam.WindowInto(beam.window.FixedWindows(10))\n",
- " | \"RunInference\" >> RunInference(model_handler=model_handler,\n",
- " model_metadata_pcoll=side_input_pcoll))"
- ],
- "metadata": {
- "id": "_AjvvexJ_hUq",
- "outputId": "291fcc38-0abb-4b11-f840-4a850097a56f",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 133
- }
- },
- "execution_count": 12,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "4. Post-process the `PredictionResult` object.\n",
- "When the inference is complete, RunInference outputs a `PredictionResult` object that contains the fields `example`, `inference`, and `model_id`. The `model_id` field identifies the model used to run the inference. The `PostProcessor` returns the predicted label and the model ID used to run the inference on the predicted label."
- ],
- "metadata": {
- "id": "lTA4wRWNDVis"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "post_processor = (\n",
- " inferences\n",
- " | \"PostProcessResults\" >> beam.ParDo(PostProcessor())\n",
- " | \"LogResults\" >> beam.Map(logging.info))"
- ],
- "metadata": {
- "id": "9TB76fo-_vZJ",
- "outputId": "3e12d482-1bdf-4136-fbf7-9d5bb4bb62c3",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 222
- }
- },
- "execution_count": 13,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Watch for the model update\n",
- "\n",
- "After the pipeline starts processing data and when you see output emitted from the RunInference `PTransform`, upload a `resnet152` model saved in `.h5` format to a Google Cloud Storage bucket location that matches the `file_pattern` you defined earlier. You can [download a copy of the model](https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels.h5) (link downloads the model). RunInference uses `WatchFilePattern` as a side input to update the `model_uri` of `TFModelHandlerTensor`."
- ],
- "metadata": {
- "id": "wYp-mBHHjOjA"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Run the pipeline\n",
- "\n",
- "Use the following code to run the pipeline."
- ],
- "metadata": {
- "id": "_ty03jDnKdKR"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# Run the pipeline.\n",
- "result = pipeline.run().wait_until_finish()"
- ],
- "metadata": {
- "id": "wd0VJLeLEWBU",
- "outputId": "3489c891-05d2-4739-d693-1899cfe78859",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 186
- }
- },
- "execution_count": 14,
- "outputs": [{
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "\n"
- ]
- }]
- }
- ]
-}
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "OsFaZscKSPvo"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Update ML models in running pipelines\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "ZUSiAR62SgO8"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This notebook demonstrates how to perform automatic model updates without stopping your Apache Beam pipeline.\n",
+ "You can use side inputs to update your model in real time, even while the Apache Beam pipeline is running. The side input is passed in a `ModelHandler` configuration object. You can update the model either by leveraging one of Apache Beam's provided patterns, such as the `WatchFilePattern`, or by configuring a custom side input `PCollection` that defines the logic for the model update.\n",
+ "\n",
+ "The pipeline in this notebook uses a RunInference `PTransform` with TensorFlow machine learning (ML) models to run inference on images. To update the model, it uses a side input `PCollection` that emits `ModelMetadata`.\n",
+ "For more information about side inputs, see the [Side inputs](https://beam.apache.org/documentation/programming-guide/#side-inputs) section in the Apache Beam Programming Guide.\n",
+ "\n",
+ "This example uses `WatchFilePattern` as a side input. `WatchFilePattern` is used to watch for file updates that match the `file_pattern` based on timestamps. It emits the latest `ModelMetadata`, which is used in the RunInference `PTransform` to automatically update the ML model without stopping the Apache Beam pipeline.\n"
+ ],
+ "metadata": {
+ "id": "tBtqF5UpKJNZ"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Before you begin\n",
+ "Install the dependencies required to run this notebook.\n",
+ "\n",
+ "To use RunInference with side inputs for automatic model updates, use Apache Beam version 2.46.0 or later."
+ ],
+ "metadata": {
+ "id": "SPuXFowiTpWx"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1RyTYsFEIOlA"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install apache_beam[gcp]>=2.46.0 --quiet\n",
+ "!pip install tensorflow --quiet\n",
+ "!pip install tensorflow_hub --quiet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Imports required for the notebook.\n",
+ "import logging\n",
+ "import time\n",
+ "from typing import Iterable\n",
+ "from typing import Tuple\n",
+ "\n",
+ "import apache_beam as beam\n",
+ "from apache_beam.ml.inference.base import PredictionResult\n",
+ "from apache_beam.ml.inference.base import RunInference\n",
+ "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n",
+ "from apache_beam.ml.inference.utils import WatchFilePattern\n",
+ "from apache_beam.options.pipeline_options import GoogleCloudOptions\n",
+ "from apache_beam.options.pipeline_options import PipelineOptions\n",
+ "from apache_beam.options.pipeline_options import SetupOptions\n",
+ "from apache_beam.options.pipeline_options import StandardOptions\n",
+ "from apache_beam.options.pipeline_options import WorkerOptions\n",
+ "from apache_beam.transforms.periodicsequence import PeriodicImpulse\n",
+ "import numpy\n",
+ "from PIL import Image\n",
+ "import tensorflow as tf"
+ ],
+ "metadata": {
+ "id": "Rs4cwwNrIV9H"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Authenticate to your Google Cloud account.\n",
+ "def auth_to_colab():\n",
+ " from google.colab import auth\n",
+ " auth.authenticate_user()\n",
+ "\n",
+ "auth_to_colab()"
+ ],
+ "metadata": {
+ "id": "jAKpPcmmGm03"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Configure the runner\n",
+ "\n",
+ "This pipeline uses the Dataflow Runner. To run the pipeline, you need to complete the following tasks:\n",
+ "\n",
+ "* Ensure that you have all the required permissions to run the pipeline on Dataflow.\n",
+ "* Configure the pipeline options for the pipeline to run on Dataflow. Make sure the pipeline is using streaming mode.\n",
+ "\n",
+ "In the following code, replace `BUCKET_NAME` with the the name of your Cloud Storage bucket."
+ ],
+ "metadata": {
+ "id": "ORYNKhH3WQyP"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "options = PipelineOptions()\n",
+ "options.view_as(StandardOptions).streaming = True\n",
+ "\n",
+ "BUCKET_NAME = '' # Replace with your bucket name.\n",
+ "\n",
+ "# Provide required pipeline options for the Dataflow Runner.\n",
+ "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n",
+ "\n",
+ "# Set the project to the default project in your current Google Cloud environment.\n",
+ "options.view_as(GoogleCloudOptions).project = ''\n",
+ "\n",
+ "# Set the Google Cloud region that you want to run Dataflow in.\n",
+ "options.view_as(GoogleCloudOptions).region = 'us-central1'\n",
+ "\n",
+ "# IMPORTANT: Replace BUCKET_NAME with the the name of your Cloud Storage bucket.\n",
+ "dataflow_gcs_location = \"gs://%s/dataflow\" % BUCKET_NAME\n",
+ "\n",
+ "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n",
+ "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n",
+ "\n",
+ "\n",
+ "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n",
+ "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n",
+ "\n",
+ "# The Dataflow temp location. This location is used to store temporary files or intermediate results before outputting to the sink.\n",
+ "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n",
+ "\n",
+ "options.view_as(SetupOptions).save_main_session = True\n",
+ "\n",
+ "# Launching Dataflow with only one worker might result in processing delays due to\n",
+ "# initial input processing. This could further postpone the side input model updates.\n",
+ "# To expedite the model update process, it's recommended to set num_workers>1.\n",
+ "# https://github.com/apache/beam/issues/28776\n",
+ "options.view_as(WorkerOptions).num_workers = 5"
+ ],
+ "metadata": {
+ "id": "wWjbnq6X-4uE"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Install the `tensorflow` and `tensorflow_hub` dependencies on Dataflow. Use the `requirements_file` pipeline option to pass these dependencies."
+ ],
+ "metadata": {
+ "id": "HTJV8pO2Wcw4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# In a requirements file, define the dependencies required for the pipeline.\n",
+ "!printf 'tensorflow>=2.12.0\\ntensorflow_hub>=0.10.0\\nPillow>=9.0.0' > ./requirements.txt\n",
+ "# Install the pipeline dependencies on Dataflow.\n",
+ "options.view_as(SetupOptions).requirements_file = './requirements.txt'"
+ ],
+ "metadata": {
+ "id": "lEy4PkluWbdm"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Use the TensorFlow model handler\n",
+ " This example uses `TFModelHandlerTensor` as the model handler and the `resnet_101` model trained on [ImageNet](https://www.image-net.org/).\n",
+ "\n",
+ "\n",
+ "For DataflowRunner, the model needs to be stored remote location accessible by the Beam pipeline. So we will download `ResNet101` model and upload it to the GCS location.\n"
+ ],
+ "metadata": {
+ "id": "_AUNH_GJk_NE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = tf.keras.applications.resnet.ResNet101()\n",
+ "model.save('resnet101_weights_tf_dim_ordering_tf_kernels.keras')\n",
+ "# After saving the model locally, upload the model to GCS bucket and provide that gcs bucket `URI` as `model_uri` to the `TFModelHandler`\n",
+ "# Replace `BUCKET_NAME` value with actual bucket name.\n",
+ "!gsutil cp resnet101_weights_tf_dim_ordering_tf_kernels.keras gs:///dataflow/resnet101_weights_tf_dim_ordering_tf_kernels.keras"
+ ],
+ "metadata": {
+ "id": "ibkWiwVNvyrn"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model_handler = TFModelHandlerTensor(\n",
+ " model_uri=dataflow_gcs_location + \"/resnet101_weights_tf_dim_ordering_tf_kernels.keras\")"
+ ],
+ "metadata": {
+ "id": "kkSnsxwUk-Sp"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Preprocess images\n",
+ "\n",
+ "Use `preprocess_image` to run the inference, read the image, and convert the image to a TensorFlow tensor."
+ ],
+ "metadata": {
+ "id": "tZH0r0sL-if5"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def preprocess_image(image_name, image_dir):\n",
+ " img = tf.keras.utils.get_file(image_name, image_dir + image_name)\n",
+ " img = Image.open(img).resize((224, 224))\n",
+ " img = numpy.array(img) / 255.0\n",
+ " img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)\n",
+ " return img_tensor"
+ ],
+ "metadata": {
+ "id": "dU5imgTt-8Ne"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class PostProcessor(beam.DoFn):\n",
+ " \"\"\"Process the PredictionResult to get the predicted label.\n",
+ " Returns predicted label.\n",
+ " \"\"\"\n",
+ " def process(self, element: PredictionResult) -> Iterable[Tuple[str, str]]:\n",
+ " predicted_class = numpy.argmax(element.inference, axis=-1)\n",
+ " labels_path = tf.keras.utils.get_file(\n",
+ " 'ImageNetLabels.txt',\n",
+ " 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt' # pylint: disable=line-too-long\n",
+ " )\n",
+ " imagenet_labels = numpy.array(open(labels_path).read().splitlines())\n",
+ " predicted_class_name = imagenet_labels[predicted_class]\n",
+ " yield predicted_class_name.title(), element.model_id"
+ ],
+ "metadata": {
+ "id": "6V5tJxO6-gyt"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Define the pipeline object.\n",
+ "pipeline = beam.Pipeline(options=options)"
+ ],
+ "metadata": {
+ "id": "GpdKk72O_NXT"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Next, review the pipeline steps and examine the code.\n",
+ "\n",
+ "### Pipeline steps\n"
+ ],
+ "metadata": {
+ "id": "elZ53uxc_9Hv"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "1. Create a `PeriodicImpulse` transform, which emits output every `n` seconds. The `PeriodicImpulse` transform generates an infinite sequence of elements with a given runtime interval.\n",
+ "\n",
+ " In this example, `PeriodicImpulse` mimics the Pub/Sub source. Because the inputs in a streaming pipeline arrive in intervals, use `PeriodicImpulse` to output elements at `m` intervals.\n",
+ "To learn more about `PeriodicImpulse`, see the [`PeriodicImpulse` code](https://github.com/apache/beam/blob/9c52e0594d6f0e59cd17ee005acfb41da508e0d5/sdks/python/apache_beam/transforms/periodicsequence.py#L150)."
+ ],
+ "metadata": {
+ "id": "305tkV2sAD-S"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "start_timestamp = time.time() # start timestamp of the periodic impulse\n",
+ "end_timestamp = start_timestamp + 60 * 20 # end timestamp of the periodic impulse (will run for 20 minutes).\n",
+ "main_input_fire_interval = 60 # interval in seconds at which the main input PCollection is emitted.\n",
+ "side_input_fire_interval = 60 # interval in seconds at which the side input PCollection is emitted.\n",
+ "\n",
+ "periodic_impulse = (\n",
+ " pipeline\n",
+ " | \"MainInputPcoll\" >> PeriodicImpulse(\n",
+ " start_timestamp=start_timestamp,\n",
+ " stop_timestamp=end_timestamp,\n",
+ " fire_interval=main_input_fire_interval))"
+ ],
+ "metadata": {
+ "id": "vUFStz66_Tbb"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "2. To read and preprocess the images, use the `preprocess_image` function. This example uses `Cat-with-beanie.jpg` for all inferences.\n",
+ "\n",
+ " **Note**: Image used for prediction is licensed in CC-BY. The creator is listed in the [LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt) file."
+ ],
+ "metadata": {
+ "id": "8-sal2rFAxP2"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "![download.png]()"
+ ],
+ "metadata": {
+ "id": "gW4cE8bhXS-d"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "image_data = (periodic_impulse | beam.Map(lambda x: \"Cat-with-beanie.jpg\")\n",
+ " | \"ReadImage\" >> beam.Map(lambda image_name: preprocess_image(\n",
+ " image_name=image_name, image_dir='https://storage.googleapis.com/apache-beam-samples/image_captioning/')))"
+ ],
+ "metadata": {
+ "id": "dGg11TpV_aV6"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "3. Pass the images to the RunInference `PTransform`. RunInference takes `model_handler` and `model_metadata_pcoll` as input parameters.\n",
+ " * `model_metadata_pcoll` is a side input `PCollection` to the RunInference `PTransform`. This side input is used to update the `model_uri` in the `model_handler` without needing to stop the Apache Beam pipeline\n",
+ " * Use `WatchFilePattern` as side input to watch a `file_pattern` matching `.keras` files. In this case, the `file_pattern` is `'gs://BUCKET_NAME/dataflow/*keras'`.\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "eB0-ewd-BCKE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ " # The side input used to watch for the .keras file and update the model_uri of the TFModelHandlerTensor.\n",
+ "file_pattern = dataflow_gcs_location + '/*.keras'\n",
+ "side_input_pcoll = (\n",
+ " pipeline\n",
+ " | \"WatchFilePattern\" >> WatchFilePattern(file_pattern=file_pattern,\n",
+ " interval=side_input_fire_interval,\n",
+ " stop_timestamp=end_timestamp))\n",
+ "inferences = (\n",
+ " image_data\n",
+ " | \"ApplyWindowing\" >> beam.WindowInto(beam.window.FixedWindows(10))\n",
+ " | \"RunInference\" >> RunInference(model_handler=model_handler,\n",
+ " model_metadata_pcoll=side_input_pcoll))"
+ ],
+ "metadata": {
+ "id": "_AjvvexJ_hUq"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "4. Post-process the `PredictionResult` object.\n",
+ "When the inference is complete, RunInference outputs a `PredictionResult` object that contains the fields `example`, `inference`, and `model_id`. The `model_id` field identifies the model used to run the inference. The `PostProcessor` returns the predicted label and the model ID used to run the inference on the predicted label."
+ ],
+ "metadata": {
+ "id": "lTA4wRWNDVis"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "post_processor = (\n",
+ " inferences\n",
+ " | \"PostProcessResults\" >> beam.ParDo(PostProcessor())\n",
+ " | \"LogResults\" >> beam.Map(logging.info))"
+ ],
+ "metadata": {
+ "id": "9TB76fo-_vZJ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Watch for the model update\n",
+ "\n",
+ "After the pipeline starts processing data and when you see output emitted from the RunInference `PTransform`, upload a `resnet152` model saved in `.keras` format to a Google Cloud Storage bucket location that matches the `file_pattern` you defined earlier.\n"
+ ],
+ "metadata": {
+ "id": "wYp-mBHHjOjA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = tf.keras.applications.resnet.ResNet152()\n",
+ "model.save('resnet152_weights_tf_dim_ordering_tf_kernels.keras')\n",
+ "# Replace the `BUCKET_NAME` with the actual bucket name.\n",
+ "!gsutil cp resnet152_weights_tf_dim_ordering_tf_kernels.keras gs:///resnet152_weights_tf_dim_ordering_tf_kernels.keras"
+ ],
+ "metadata": {
+ "id": "FpUfNBSWH9Xy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run the pipeline\n",
+ "\n",
+ "Use the following code to run the pipeline."
+ ],
+ "metadata": {
+ "id": "_ty03jDnKdKR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Run the pipeline.\n",
+ "result = pipeline.run().wait_until_finish()"
+ ],
+ "metadata": {
+ "id": "wd0VJLeLEWBU"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/examples/notebooks/healthcare/beam_nlp.ipynb b/examples/notebooks/healthcare/beam_nlp.ipynb
index 5106aaa607d9..4ba4a5e0a739 100644
--- a/examples/notebooks/healthcare/beam_nlp.ipynb
+++ b/examples/notebooks/healthcare/beam_nlp.ipynb
@@ -146,7 +146,7 @@
{
"cell_type": "markdown",
"source": [
- "Then, download [this raw CSV file](https://https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading."
+ "Then, download [this raw CSV file](https://github.com/socd06/medical-nlp/blob/master/data/test.csv), and then upload it into Colab. You should be able to view this file (*test.csv*) in the \"Files\" tab in Colab after uploading."
],
"metadata": {
"id": "1IArtEm8QuCR"
diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar
index afba109285af..7f93135c49b7 100644
Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties
index 4e86b9270786..ac72c34e8acc 100644
--- a/gradle/wrapper/gradle-wrapper.properties
+++ b/gradle/wrapper/gradle-wrapper.properties
@@ -1,6 +1,7 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.3-bin.zip
networkTimeout=10000
+validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
diff --git a/gradlew b/gradlew
index 65dcd68d65c8..0adc8e1a5321 100755
--- a/gradlew
+++ b/gradlew
@@ -83,10 +83,8 @@ done
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
-APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
-
-# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
-DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
+APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
@@ -133,10 +131,13 @@ location of your Java installation."
fi
else
JAVACMD=java
- which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+ if ! command -v java >/dev/null 2>&1
+ then
+ die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
+ fi
fi
# Increase the maximum file descriptors if we can.
@@ -144,7 +145,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
- # shellcheck disable=SC3045
+ # shellcheck disable=SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
@@ -152,7 +153,7 @@ if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
- # shellcheck disable=SC3045
+ # shellcheck disable=SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
@@ -197,6 +198,10 @@ if "$cygwin" || "$msys" ; then
done
fi
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in
diff --git a/playground/kafka-emulator/build.gradle b/playground/kafka-emulator/build.gradle
index 486a232f9b99..2d3f70aa9883 100644
--- a/playground/kafka-emulator/build.gradle
+++ b/playground/kafka-emulator/build.gradle
@@ -24,11 +24,11 @@ plugins {
applyJavaNature(exportJavadoc: false, publish: false)
distZip {
- archiveName "${baseName}.zip"
+ archiveFileName = "${archiveBaseName}.zip"
}
distTar {
- archiveName "${baseName}.tar"
+ archiveFileName = "${archiveBaseName}.tar"
}
dependencies {
diff --git a/sdks/go.mod b/sdks/go.mod
index e17427227eba..d817ae549857 100644
--- a/sdks/go.mod
+++ b/sdks/go.mod
@@ -28,7 +28,7 @@ require (
cloud.google.com/go/datastore v1.14.0
cloud.google.com/go/profiler v0.3.1
cloud.google.com/go/pubsub v1.33.0
- cloud.google.com/go/spanner v1.49.0
+ cloud.google.com/go/spanner v1.50.0
cloud.google.com/go/storage v1.33.0
github.com/aws/aws-sdk-go-v2 v1.21.0
github.com/aws/aws-sdk-go-v2/config v1.18.43
@@ -67,8 +67,8 @@ require (
)
require (
- github.com/fsouza/fake-gcs-server v1.47.4
- golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
+ github.com/fsouza/fake-gcs-server v1.47.5
+ golang.org/x/exp v0.0.0-20230807204917-050eac23e9de
)
require (
@@ -88,7 +88,7 @@ require (
cloud.google.com/go v0.110.7 // indirect
cloud.google.com/go/compute v1.23.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
- cloud.google.com/go/iam v1.1.1 // indirect
+ cloud.google.com/go/iam v1.1.2 // indirect
cloud.google.com/go/longrunning v0.5.1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
diff --git a/sdks/go.sum b/sdks/go.sum
index 71c1c4545c89..9f43e9a53abc 100644
--- a/sdks/go.sum
+++ b/sdks/go.sum
@@ -26,8 +26,8 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
cloud.google.com/go/datastore v1.14.0 h1:Mq0ApTRdLW3/dyiw+DkjTk0+iGIUvkbzaC8sfPwWTH4=
cloud.google.com/go/datastore v1.14.0/go.mod h1:GAeStMBIt9bPS7jMJA85kgkpsMkvseWWXiaHya9Jes8=
-cloud.google.com/go/iam v1.1.1 h1:lW7fzj15aVIXYHREOqjRBV9PsH0Z6u8Y46a1YGvQP4Y=
-cloud.google.com/go/iam v1.1.1/go.mod h1:A5avdyVL2tCppe4unb0951eI9jreack+RJ0/d+KUZOU=
+cloud.google.com/go/iam v1.1.2 h1:gacbrBdWcoVmGLozRuStX45YKvJtzIjJdAolzUs1sm4=
+cloud.google.com/go/iam v1.1.2/go.mod h1:A5avdyVL2tCppe4unb0951eI9jreack+RJ0/d+KUZOU=
cloud.google.com/go/kms v1.15.0 h1:xYl5WEaSekKYN5gGRyhjvZKM22GVBBCzegGNVPy+aIs=
cloud.google.com/go/longrunning v0.5.1 h1:Fr7TXftcqTudoyRJa113hyaqlGdiBQkp0Gq7tErFDWI=
cloud.google.com/go/longrunning v0.5.1/go.mod h1:spvimkwdz6SPWKEt/XBij79E9fiTkHSQl/fRUUQJYJc=
@@ -38,8 +38,8 @@ cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+
cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA=
cloud.google.com/go/pubsub v1.33.0 h1:6SPCPvWav64tj0sVX/+npCBKhUi/UjJehy9op/V3p2g=
cloud.google.com/go/pubsub v1.33.0/go.mod h1:f+w71I33OMyxf9VpMVcZbnG5KSUkCOUHYpFd5U1GdRc=
-cloud.google.com/go/spanner v1.49.0 h1:+HY8C4uztU7XyLz3xMi/LCXdetLEOExhvRFJu2NiVXM=
-cloud.google.com/go/spanner v1.49.0/go.mod h1:eGj9mQGK8+hkgSVbHNQ06pQ4oS+cyc4tXXd6Dif1KoM=
+cloud.google.com/go/spanner v1.50.0 h1:QrJFOpaxCXdXF+GkiruLz642PHxkdj68PbbnLw3O2Zw=
+cloud.google.com/go/spanner v1.50.0/go.mod h1:eGj9mQGK8+hkgSVbHNQ06pQ4oS+cyc4tXXd6Dif1KoM=
cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos=
cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
@@ -195,8 +195,8 @@ github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoD
github.com/frankban/quicktest v1.2.2/go.mod h1:Qh/WofXFeiAFII1aEBu529AtJo6Zg2VHscnEsbBnJ20=
github.com/frankban/quicktest v1.11.3 h1:8sXhOn0uLys67V8EsXLc6eszDs8VXWxL3iRvebPhedY=
github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k=
-github.com/fsouza/fake-gcs-server v1.47.4 h1:gfBhBxEra20/Om02cvcyL8EnekV8KDb01Yffjat6AKQ=
-github.com/fsouza/fake-gcs-server v1.47.4/go.mod h1:vqUZbI12uy9IkRQ54Q4p5AniQsSiUq8alO9Nv2egMmA=
+github.com/fsouza/fake-gcs-server v1.47.5 h1:o+wL01s01j/2OdkIaduDogXw2bZveq9TFb8f+BqEHtM=
+github.com/fsouza/fake-gcs-server v1.47.5/go.mod h1:PhN8F1rHAOCL5jWyXcw8nPfLfHnka6D9fT7ctL9nbkA=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
@@ -348,7 +348,7 @@ github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcs
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
-github.com/minio/minio-go/v7 v7.0.61 h1:87c+x8J3jxQ5VUGimV9oHdpjsAvy3fhneEBKuoKEVUI=
+github.com/minio/minio-go/v7 v7.0.63 h1:GbZ2oCvaUdgT5640WJOpyDhhDxvknAJU2/T3yurwcbQ=
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
github.com/moby/patternmatcher v0.5.0 h1:YCZgJOeULcxLw1Q+sVR636pmS7sPEn1Qo2iAN6M7DBo=
github.com/moby/patternmatcher v0.5.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
@@ -497,8 +497,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
-golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
-golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
+golang.org/x/exp v0.0.0-20230807204917-050eac23e9de h1:l5Za6utMv/HsBWWqzt4S8X17j+kt1uVETUX5UFhn2rE=
+golang.org/x/exp v0.0.0-20230807204917-050eac23e9de/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
diff --git a/sdks/java/container/common.gradle b/sdks/java/container/common.gradle
index bf4c122ca91f..cc427494ed6e 100644
--- a/sdks/java/container/common.gradle
+++ b/sdks/java/container/common.gradle
@@ -63,6 +63,8 @@ task copyDockerfileDependencies(type: Copy) {
task copySdkHarnessLauncher(type: Copy) {
dependsOn ":sdks:java:container:downloadCloudProfilerAgent"
+ // if licenses are required, they should be present before this task run.
+ mustRunAfter ":sdks:java:container:pullLicenses"
from configurations.sdkHarnessLauncher
into "build/target"
diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java
index b853ab792e08..fd49b759fd6d 100644
--- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java
+++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/GceMetadataUtil.java
@@ -30,40 +30,60 @@
import org.apache.http.params.BasicHttpParams;
import org.apache.http.params.HttpConnectionParams;
import org.apache.http.params.HttpParams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/** */
public class GceMetadataUtil {
private static final String BASE_METADATA_URL = "http://metadata/computeMetadata/v1/";
+ private static final Logger LOG = LoggerFactory.getLogger(GceMetadataUtil.class);
+
static String fetchMetadata(String key) {
+ String requestUrl = BASE_METADATA_URL + key;
int timeoutMillis = 5000;
final HttpParams httpParams = new BasicHttpParams();
HttpConnectionParams.setConnectionTimeout(httpParams, timeoutMillis);
- HttpClient client = new DefaultHttpClient(httpParams);
- HttpGet request = new HttpGet(BASE_METADATA_URL + key);
- request.setHeader("Metadata-Flavor", "Google");
-
+ String ret = "";
try {
+ HttpClient client = new DefaultHttpClient(httpParams);
+
+ HttpGet request = new HttpGet(requestUrl);
+ request.setHeader("Metadata-Flavor", "Google");
+
HttpResponse response = client.execute(request);
- if (response.getStatusLine().getStatusCode() != 200) {
- // May mean its running on a non DataflowRunner, in which case it's perfectly normal.
- return "";
+ if (response.getStatusLine().getStatusCode() == 200) {
+ InputStream in = response.getEntity().getContent();
+ try (final Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
+ ret = CharStreams.toString(reader);
+ }
}
- InputStream in = response.getEntity().getContent();
- try (final Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
- return CharStreams.toString(reader);
- }
- } catch (IOException e) {
- // May mean its running on a non DataflowRunner, in which case it's perfectly normal.
+ } catch (IOException ignored) {
}
- return "";
+
+ // The return value can be an empty string, which may mean it's running on a non DataflowRunner.
+ LOG.debug("Fetched GCE Metadata at '{}' and got '{}'", requestUrl, ret);
+
+ return ret;
+ }
+
+ private static String fetchVmInstanceMetadata(String instanceMetadataKey) {
+ return GceMetadataUtil.fetchMetadata("instance/" + instanceMetadataKey);
}
private static String fetchCustomGceMetadata(String customMetadataKey) {
- return GceMetadataUtil.fetchMetadata("instance/attributes/" + customMetadataKey);
+ return GceMetadataUtil.fetchVmInstanceMetadata("attributes/" + customMetadataKey);
}
public static String fetchDataflowJobId() {
return GceMetadataUtil.fetchCustomGceMetadata("job_id");
}
+
+ public static String fetchDataflowJobName() {
+ return GceMetadataUtil.fetchCustomGceMetadata("job_name");
+ }
+
+ public static String fetchDataflowWorkerId() {
+ return GceMetadataUtil.fetchVmInstanceMetadata("id");
+ }
}
diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle
index 560b27aae162..c4a508680186 100644
--- a/sdks/java/io/google-cloud-platform/build.gradle
+++ b/sdks/java/io/google-cloud-platform/build.gradle
@@ -202,10 +202,8 @@ task integrationTest(type: Test, dependsOn: processTestResources) {
exclude '**/BigQueryIOReadIT.class'
exclude '**/BigQueryIOStorageQueryIT.class'
exclude '**/BigQueryIOStorageReadIT.class'
- exclude '**/BigQueryIOStorageReadTableRowIT.class'
exclude '**/BigQueryIOStorageWriteIT.class'
exclude '**/BigQueryToTableIT.class'
- exclude '**/BigQueryIOJsonTest.class'
maxParallelForks 4
classpath = sourceSets.test.runtimeClasspath
@@ -244,6 +242,48 @@ task integrationTestKms(type: Test) {
}
}
+/*
+ Integration tests for BigQueryIO that run on BigQuery's early rollout region (us-east7)
+ with the intended purpose of catching breaking changes from new BigQuery releases.
+ If these tests fail here but not in `Java_GCP_IO_Direct`, there may be a new BigQuery change
+ that is breaking the connector. If this is the case, we should verify with the appropriate
+ BigQuery infrastructure API team.
+
+ To test in a BigQuery location, we just need to create our datasets in that location.
+ */
+task bigQueryEarlyRolloutIntegrationTest(type: Test, dependsOn: processTestResources) {
+ group = "Verification"
+ def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing'
+ def gcpTempRoot = project.findProperty('gcpTempRoot') ?: 'gs://temp-storage-for-bigquery-day0-tests'
+ systemProperty "beamTestPipelineOptions", JsonOutput.toJson([
+ "--runner=DirectRunner",
+ "--project=${gcpProject}",
+ "--tempRoot=${gcpTempRoot}",
+ "--bigQueryLocation=us-east7",
+ ])
+
+ outputs.upToDateWhen { false }
+
+ // export and direct read
+ include '**/BigQueryToTableIT.class'
+ include '**/BigQueryIOJsonIT.class'
+ include '**/BigQueryIOStorageReadTableRowIT.class'
+ // storage write api
+ include '**/StorageApiDirectWriteProtosIT.class'
+ include '**/StorageApiSinkFailedRowsIT.class'
+ include '**/StorageApiSinkRowUpdateIT.class'
+ include '**/StorageApiSinkSchemaUpdateIT.class'
+ include '**/TableRowToStorageApiProtoIT.class'
+ // file loads
+ include '**/BigQuerySchemaUpdateOptionsIT.class'
+ include '**/BigQueryTimePartitioningClusteringIT.class'
+ include '**/FileLoadsStreamingIT.class'
+
+ maxParallelForks 4
+ classpath = sourceSets.test.runtimeClasspath
+ testClassesDirs = sourceSets.test.output.classesDirs
+}
+
// path(s) for Cloud Spanner related classes
def spannerIncludes = [
'**/org/apache/beam/sdk/io/gcp/spanner/**',
@@ -267,8 +307,8 @@ task spannerCodeCoverageReport(type: JacocoReport, dependsOn: test) {
sourceDirectories.setFrom(files(project.sourceSets.main.allSource.srcDirs))
executionData.setFrom(file("${buildDir}/jacoco/test.exec"))
reports {
- html.enabled true
- html.destination file("${buildDir}/reports/jacoco/spanner/")
+ html.getRequired().set(true)
+ html.getOutputLocation().set(file("${buildDir}/reports/jacoco/spanner/"))
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java
index ee64a7ab9ddb..1893418dedb3 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOMetadata.java
@@ -28,8 +28,15 @@ final class BigQueryIOMetadata {
private @Nullable String beamJobId;
- private BigQueryIOMetadata(@Nullable String beamJobId) {
+ private @Nullable String beamJobName;
+
+ private @Nullable String beamWorkerId;
+
+ private BigQueryIOMetadata(
+ @Nullable String beamJobId, @Nullable String beamJobName, @Nullable String beamWorkerId) {
this.beamJobId = beamJobId;
+ this.beamJobName = beamJobName;
+ this.beamWorkerId = beamWorkerId;
}
private static final Pattern VALID_CLOUD_LABEL_PATTERN =
@@ -41,17 +48,24 @@ private BigQueryIOMetadata(@Nullable String beamJobId) {
*/
public static BigQueryIOMetadata create() {
String dataflowJobId = GceMetadataUtil.fetchDataflowJobId();
+ String dataflowJobName = GceMetadataUtil.fetchDataflowJobName();
+ String dataflowWorkerId = GceMetadataUtil.fetchDataflowWorkerId();
+
// If a Dataflow job id is returned on GCE metadata. Then it means
// this program is running on a Dataflow GCE VM.
- boolean isDataflowRunner = dataflowJobId != null && !dataflowJobId.isEmpty();
+ boolean isDataflowRunner = !dataflowJobId.isEmpty();
String beamJobId = null;
+ String beamJobName = null;
+ String beamWorkerId = null;
if (isDataflowRunner) {
if (BigQueryIOMetadata.isValidCloudLabel(dataflowJobId)) {
beamJobId = dataflowJobId;
+ beamJobName = dataflowJobName;
+ beamWorkerId = dataflowWorkerId;
}
}
- return new BigQueryIOMetadata(beamJobId);
+ return new BigQueryIOMetadata(beamJobId, beamJobName, beamWorkerId);
}
public Map addAdditionalJobLabels(Map jobLabels) {
@@ -68,6 +82,20 @@ public Map addAdditionalJobLabels(Map jobLabels)
return this.beamJobId;
}
+ /*
+ * Returns the beam job name. Can be null if it is not running on Dataflow.
+ */
+ public @Nullable String getBeamJobName() {
+ return this.beamJobName;
+ }
+
+ /*
+ * Returns the beam worker id. Can be null if it is not running on Dataflow.
+ */
+ public @Nullable String getBeamWorkerId() {
+ return this.beamWorkerId;
+ }
+
/**
* Returns true if label_value is a valid cloud label string. This function can return false in
* cases where the label value is valid. However, it will not return true in a case where the
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
index 3d4565cb086e..1b6cc555511d 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryServicesImpl.java
@@ -1364,6 +1364,15 @@ public StreamAppendClient getStreamAppendClient(
.setChannelsPerCpu(2)
.build();
+ String traceId =
+ String.format(
+ "Dataflow:%s:%s:%s",
+ bqIOMetadata.getBeamJobName() == null
+ ? options.getJobName()
+ : bqIOMetadata.getBeamJobName(),
+ bqIOMetadata.getBeamJobId() == null ? "" : bqIOMetadata.getBeamJobId(),
+ bqIOMetadata.getBeamWorkerId() == null ? "" : bqIOMetadata.getBeamWorkerId());
+
StreamWriter streamWriter =
StreamWriter.newBuilder(streamName, newWriteClient)
.setExecutorProvider(
@@ -1374,11 +1383,7 @@ public StreamAppendClient getStreamAppendClient(
.setEnableConnectionPool(useConnectionPool)
.setMaxInflightRequests(storageWriteMaxInflightRequests)
.setMaxInflightBytes(storageWriteMaxInflightBytes)
- .setTraceId(
- "Dataflow:"
- + (bqIOMetadata.getBeamJobId() != null
- ? bqIOMetadata.getBeamJobId()
- : options.getJobName()))
+ .setTraceId(traceId)
.build();
return new StreamAppendClient() {
private int pins = 0;
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java
index 3574c12ee3a9..4d8095c1879d 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TestBigQueryOptions.java
@@ -24,10 +24,17 @@
/** {@link TestPipelineOptions} for {@link TestBigQuery}. */
public interface TestBigQueryOptions extends TestPipelineOptions, BigQueryOptions, GcpOptions {
+ String BIGQUERY_EARLY_ROLLOUT_REGION = "us-east7";
@Description("Dataset used in the integration tests. Default is integ_test")
@Default.String("integ_test")
String getTargetDataset();
void setTargetDataset(String value);
+
+ @Description("Region to perform BigQuery operations in.")
+ @Default.String("")
+ String getBigQueryLocation();
+
+ void setBigQueryLocation(String location);
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java
index 1ed9ed6cb6c3..f1ff827fc633 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/common/GcpIoPipelineOptionsRegistrar.java
@@ -20,6 +20,7 @@
import com.google.auto.service.AutoService;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions;
+import org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions;
import org.apache.beam.sdk.io.gcp.firestore.FirestoreOptions;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubOptions;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -36,6 +37,7 @@ public Iterable> getPipelineOptions() {
.add(BigQueryOptions.class)
.add(PubsubOptions.class)
.add(FirestoreOptions.class)
+ .add(TestBigQueryOptions.class)
.build();
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java
index b21fdd669596..0e9476e6a226 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/BigqueryClient.java
@@ -292,6 +292,21 @@ private QueryResponse getTypedTableRows(QueryResponse response) {
public List queryUnflattened(
String query, String projectId, boolean typed, boolean useStandardSql)
throws IOException, InterruptedException {
+ return queryUnflattened(query, projectId, typed, useStandardSql, null);
+ }
+
+ /**
+ * Performs a query without flattening results. May choose a location (GCP region) to perform this
+ * operation in.
+ */
+ @Nonnull
+ public List queryUnflattened(
+ String query,
+ String projectId,
+ boolean typed,
+ boolean useStandardSql,
+ @Nullable String location)
+ throws IOException, InterruptedException {
Random rnd = new Random(System.currentTimeMillis());
String temporaryDatasetId =
String.format("_dataflow_temporary_dataset_%s_%s", System.nanoTime(), rnd.nextInt(1000000));
@@ -302,9 +317,11 @@ public List queryUnflattened(
.setDatasetId(temporaryDatasetId)
.setTableId(temporaryTableId);
- createNewDataset(projectId, temporaryDatasetId);
+ createNewDataset(projectId, temporaryDatasetId, null, location);
createNewTable(
- projectId, temporaryDatasetId, new Table().setTableReference(tempTableReference));
+ projectId,
+ temporaryDatasetId,
+ new Table().setTableReference(tempTableReference).setLocation(location));
JobConfigurationQuery jcQuery =
new JobConfigurationQuery()
@@ -325,6 +342,7 @@ public List queryUnflattened(
bqClient
.jobs()
.getQueryResults(projectId, insertedJob.getJobReference().getJobId())
+ .setLocation(location)
.execute();
} while (!qResponse.getJobComplete());
@@ -395,6 +413,18 @@ public void createNewDataset(String projectId, String datasetId)
public void createNewDataset(
String projectId, String datasetId, @Nullable Long defaultTableExpirationMs)
throws IOException, InterruptedException {
+ createNewDataset(projectId, datasetId, defaultTableExpirationMs, null);
+ }
+
+ /**
+ * Creates a new dataset with defaultTableExpirationMs and in a specified location (GCP region).
+ */
+ public void createNewDataset(
+ String projectId,
+ String datasetId,
+ @Nullable Long defaultTableExpirationMs,
+ @Nullable String location)
+ throws IOException, InterruptedException {
Sleeper sleeper = Sleeper.DEFAULT;
BackOff backoff = BackOffAdapter.toGcpBackOff(BACKOFF_FACTORY.backoff());
IOException lastException = null;
@@ -410,7 +440,8 @@ public void createNewDataset(
projectId,
new Dataset()
.setDatasetReference(new DatasetReference().setDatasetId(datasetId))
- .setDefaultTableExpirationMs(defaultTableExpirationMs))
+ .setDefaultTableExpirationMs(defaultTableExpirationMs)
+ .setLocation(location))
.execute();
if (response != null) {
LOG.info("Successfully created new dataset : " + response.getId());
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java
index 692a12c0f4a7..d355d6bb9336 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.io.gcp.bigquery;
+import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION;
+
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
@@ -52,7 +54,13 @@ public class BigQueryIOStorageQueryIT {
"1G", 11110839L,
"1T", 11110839000L);
- private static final String DATASET_ID = "big_query_storage";
+ private static final String DATASET_ID =
+ TestPipeline.testingPipelineOptions()
+ .as(TestBigQueryOptions.class)
+ .getBigQueryLocation()
+ .equals(BIGQUERY_EARLY_ROLLOUT_REGION)
+ ? "big_query_storage_day0"
+ : "big_query_storage";
private static final String TABLE_PREFIX = "storage_read_";
private BigQueryIOStorageQueryOptions options;
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java
index 570938470b9d..b4f6ddb76f72 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.gcp.bigquery;
+import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION;
import static org.junit.Assert.assertEquals;
import com.google.cloud.bigquery.storage.v1.DataFormat;
@@ -65,7 +66,13 @@ public class BigQueryIOStorageReadIT {
"1T", 11110839000L,
"multi_field", 11110839L);
- private static final String DATASET_ID = "big_query_storage";
+ private static final String DATASET_ID =
+ TestPipeline.testingPipelineOptions()
+ .as(TestBigQueryOptions.class)
+ .getBigQueryLocation()
+ .equals(BIGQUERY_EARLY_ROLLOUT_REGION)
+ ? "big_query_storage_day0"
+ : "big_query_storage";
private static final String TABLE_PREFIX = "storage_read_";
private BigQueryIOStorageReadOptions options;
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java
index 734c3af2c4d4..35e2676c70ef 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.io.gcp.bigquery;
+import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION;
+
import com.google.api.services.bigquery.model.TableRow;
import java.util.HashSet;
import java.util.Set;
@@ -52,7 +54,13 @@
@RunWith(JUnit4.class)
public class BigQueryIOStorageReadTableRowIT {
- private static final String DATASET_ID = "big_query_import_export";
+ private static final String DATASET_ID =
+ TestPipeline.testingPipelineOptions()
+ .as(TestBigQueryOptions.class)
+ .getBigQueryLocation()
+ .equals(BIGQUERY_EARLY_ROLLOUT_REGION)
+ ? "big_query_import_export_day0"
+ : "big_query_import_export";
private static final String TABLE_PREFIX = "parallel_read_table_row_";
private BigQueryIOStorageReadTableRowOptions options;
@@ -67,12 +75,11 @@ public interface BigQueryIOStorageReadTableRowOptions
void setInputTable(String table);
}
- private static class TableRowToKVPairFn extends SimpleFunction> {
+ private static class TableRowToKVPairFn extends SimpleFunction> {
@Override
- public KV apply(TableRow input) {
- CharSequence sampleString = (CharSequence) input.get("sample_string");
- String key = sampleString != null ? sampleString.toString() : "null";
- return KV.of(key, BigQueryHelpers.toJsonString(input));
+ public KV apply(TableRow input) {
+ Integer rowId = Integer.parseInt((String) input.get("id"));
+ return KV.of(rowId, BigQueryHelpers.toJsonString(input));
}
}
@@ -87,7 +94,7 @@ private void setUpTestEnvironment(String tableName) {
private static void runPipeline(BigQueryIOStorageReadTableRowOptions pipelineOptions) {
Pipeline pipeline = Pipeline.create(pipelineOptions);
- PCollection> jsonTableRowsFromExport =
+ PCollection> jsonTableRowsFromExport =
pipeline
.apply(
"ExportTable",
@@ -96,7 +103,7 @@ private static void runPipeline(BigQueryIOStorageReadTableRowOptions pipelineOpt
.withMethod(Method.EXPORT))
.apply("MapExportedRows", MapElements.via(new TableRowToKVPairFn()));
- PCollection> jsonTableRowsFromDirectRead =
+ PCollection> jsonTableRowsFromDirectRead =
pipeline
.apply(
"DirectReadTable",
@@ -108,16 +115,16 @@ private static void runPipeline(BigQueryIOStorageReadTableRowOptions pipelineOpt
final TupleTag exportTag = new TupleTag<>();
final TupleTag directReadTag = new TupleTag<>();
- PCollection>> unmatchedRows =
+ PCollection>> unmatchedRows =
KeyedPCollectionTuple.of(exportTag, jsonTableRowsFromExport)
.and(directReadTag, jsonTableRowsFromDirectRead)
.apply(CoGroupByKey.create())
.apply(
ParDo.of(
- new DoFn, KV>>() {
+ new DoFn, KV>>() {
@ProcessElement
- public void processElement(ProcessContext c) throws Exception {
- KV element = c.element();
+ public void processElement(ProcessContext c) {
+ KV element = c.element();
// Add all the exported rows for the key to a collection.
Set uniqueRows = new HashSet<>();
@@ -147,20 +154,20 @@ public void processElement(ProcessContext c) throws Exception {
}
@Test
- public void testBigQueryStorageReadTableRow1() throws Exception {
- setUpTestEnvironment("1");
+ public void testBigQueryStorageReadTableRow100() {
+ setUpTestEnvironment("100");
runPipeline(options);
}
@Test
- public void testBigQueryStorageReadTableRow10k() throws Exception {
- setUpTestEnvironment("10k");
+ public void testBigQueryStorageReadTableRow1k() {
+ setUpTestEnvironment("1K");
runPipeline(options);
}
@Test
- public void testBigQueryStorageReadTableRow100k() throws Exception {
- setUpTestEnvironment("100k");
+ public void testBigQueryStorageReadTableRow10k() {
+ setUpTestEnvironment("10K");
runPipeline(options);
}
}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java
index fc3ce0be4b69..d061898d55c7 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageWriteIT.java
@@ -26,11 +26,11 @@
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import java.io.IOException;
+import java.security.SecureRandom;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.gcp.testing.BigqueryClient;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
@@ -43,6 +43,8 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Duration;
import org.joda.time.Instant;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -60,24 +62,37 @@ private enum WriteMode {
AT_LEAST_ONCE
}
- private String project;
- private static final String DATASET_ID = "big_query_storage";
+ private static String project;
+ private static final String DATASET_ID =
+ "big_query_storage_write_it_"
+ + System.currentTimeMillis()
+ + "_"
+ + new SecureRandom().nextInt(32);
private static final String TABLE_PREFIX = "storage_write_";
- private BigQueryOptions bqOptions;
+ private static TestBigQueryOptions bqOptions;
private static final BigqueryClient BQ_CLIENT = new BigqueryClient("BigQueryStorageIOWriteIT");
+ @BeforeClass
+ public static void setup() throws Exception {
+ bqOptions = TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class);
+ project = bqOptions.as(GcpOptions.class).getProject();
+ // Create one BQ dataset for all test cases.
+ BQ_CLIENT.createNewDataset(project, DATASET_ID, null, bqOptions.getBigQueryLocation());
+ }
+
+ @AfterClass
+ public static void cleanup() {
+ BQ_CLIENT.deleteDataset(project, DATASET_ID);
+ }
+
private void setUpTestEnvironment(WriteMode writeMode) {
- PipelineOptionsFactory.register(BigQueryOptions.class);
- bqOptions = TestPipeline.testingPipelineOptions().as(BigQueryOptions.class);
- bqOptions.setProject(TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject());
bqOptions.setUseStorageWriteApi(true);
if (writeMode == WriteMode.AT_LEAST_ONCE) {
bqOptions.setUseStorageWriteApiAtLeastOnce(true);
}
bqOptions.setNumStorageWriteApiStreams(2);
bqOptions.setStorageWriteApiTriggeringFrequencySec(1);
- project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
}
static class FillRowFn extends DoFn {
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java
index 611c691dca12..833a0a0829c7 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java
@@ -87,7 +87,11 @@ public class BigQuerySchemaUpdateOptionsIT {
@BeforeClass
public static void setupTestEnvironment() throws Exception {
project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
- BQ_CLIENT.createNewDataset(project, BIG_QUERY_DATASET_ID);
+ BQ_CLIENT.createNewDataset(
+ project,
+ BIG_QUERY_DATASET_ID,
+ null,
+ TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation());
}
@AfterClass
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java
index 3ceb6f0966b7..da5f396e8d89 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTimePartitioningClusteringIT.java
@@ -24,9 +24,11 @@
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import com.google.api.services.bigquery.model.TimePartitioning;
+import java.security.SecureRandom;
import java.util.Arrays;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.gcp.testing.BigqueryClient;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
@@ -38,8 +40,10 @@
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.ValueInSingleWindow;
import org.checkerframework.checker.nullness.qual.Nullable;
+import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
+import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -49,7 +53,15 @@
public class BigQueryTimePartitioningClusteringIT {
private static final String WEATHER_SAMPLES_TABLE =
"apache-beam-testing.samples.weather_stations";
- private static final String DATASET_NAME = "BigQueryTimePartitioningIT";
+
+ private static String project;
+ private static final BigqueryClient BQ_CLIENT =
+ new BigqueryClient("BigQueryTimePartitioningClusteringIT");
+ private static final String DATASET_NAME =
+ "BigQueryTimePartitioningIT_"
+ + System.currentTimeMillis()
+ + "_"
+ + new SecureRandom().nextInt(32);
private static final TimePartitioning TIME_PARTITIONING =
new TimePartitioning().setField("date").setType("DAY");
private static final Clustering CLUSTERING =
@@ -64,6 +76,16 @@ public class BigQueryTimePartitioningClusteringIT {
private Bigquery bqClient;
private BigQueryClusteringITOptions options;
+ @BeforeClass
+ public static void setupTestEnvironment() throws Exception {
+ project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
+ BQ_CLIENT.createNewDataset(
+ project,
+ DATASET_NAME,
+ null,
+ TestPipeline.testingPipelineOptions().as(TestBigQueryOptions.class).getBigQueryLocation());
+ }
+
@Before
public void setUp() {
PipelineOptionsFactory.register(BigQueryClusteringITOptions.class);
@@ -72,6 +94,11 @@ public void setUp() {
bqClient = BigqueryClient.getNewBigqueryClient(options.getAppName());
}
+ @AfterClass
+ public static void cleanup() {
+ BQ_CLIENT.deleteDataset(project, DATASET_NAME);
+ }
+
/** Customized PipelineOptions for BigQueryClustering Integration Test. */
public interface BigQueryClusteringITOptions
extends TestPipelineOptions, ExperimentalOptions, BigQueryOptions {
@@ -110,8 +137,7 @@ public ClusteredDestinations(String tableName) {
@Override
public TableDestination getDestination(ValueInSingleWindow element) {
- return new TableDestination(
- String.format("%s.%s", DATASET_NAME, tableName), null, TIME_PARTITIONING, CLUSTERING);
+ return new TableDestination(tableName, null, TIME_PARTITIONING, CLUSTERING);
}
@Override
@@ -176,6 +202,7 @@ public void testE2EBigQueryClustering() throws Exception {
@Test
public void testE2EBigQueryClusteringTableFunction() throws Exception {
String tableName = "weather_stations_clustered_table_function_" + System.currentTimeMillis();
+ String destination = String.format("%s.%s", DATASET_NAME, tableName);
Pipeline p = Pipeline.create(options);
@@ -185,11 +212,7 @@ public void testE2EBigQueryClusteringTableFunction() throws Exception {
BigQueryIO.writeTableRows()
.to(
(ValueInSingleWindow vsw) ->
- new TableDestination(
- String.format("%s.%s", DATASET_NAME, tableName),
- null,
- TIME_PARTITIONING,
- CLUSTERING))
+ new TableDestination(destination, null, TIME_PARTITIONING, CLUSTERING))
.withClustering()
.withSchema(SCHEMA)
.withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
@@ -206,6 +229,7 @@ public void testE2EBigQueryClusteringTableFunction() throws Exception {
public void testE2EBigQueryClusteringDynamicDestinations() throws Exception {
String tableName =
"weather_stations_clustered_dynamic_destinations_" + System.currentTimeMillis();
+ String destination = String.format("%s.%s", DATASET_NAME, tableName);
Pipeline p = Pipeline.create(options);
@@ -213,7 +237,7 @@ public void testE2EBigQueryClusteringDynamicDestinations() throws Exception {
.apply(ParDo.of(new KeepStationNumberAndConvertDate()))
.apply(
BigQueryIO.writeTableRows()
- .to(new ClusteredDestinations(tableName))
+ .to(new ClusteredDestinations(destination))
.withClustering()
.withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
.withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE));
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java
index d6b7f8e16412..1abe7752b2e0 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryToTableIT.java
@@ -46,7 +46,6 @@
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.Validation;
import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.testing.TestPipelineOptions;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.WithKeys;
@@ -214,7 +213,7 @@ private void verifyStandardQueryRes(String outputTable) throws Exception {
}
/** Customized PipelineOption for BigQueryToTable Pipeline. */
- public interface BigQueryToTableOptions extends TestPipelineOptions, ExperimentalOptions {
+ public interface BigQueryToTableOptions extends TestBigQueryOptions, ExperimentalOptions {
@Description("The BigQuery query to be used for creating the source")
@Validation.Required
@@ -252,9 +251,11 @@ public interface BigQueryToTableOptions extends TestPipelineOptions, Experimenta
@BeforeClass
public static void setupTestEnvironment() throws Exception {
PipelineOptionsFactory.register(BigQueryToTableOptions.class);
- project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
+ BigQueryToTableOptions options =
+ TestPipeline.testingPipelineOptions().as(BigQueryToTableOptions.class);
+ project = options.as(GcpOptions.class).getProject();
// Create one BQ dataset for all test cases.
- BQ_CLIENT.createNewDataset(project, BIG_QUERY_DATASET_ID);
+ BQ_CLIENT.createNewDataset(project, BIG_QUERY_DATASET_ID, null, options.getBigQueryLocation());
// Create table and insert data for new type query test cases.
BQ_CLIENT.createNewTable(
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java
index 012afed6fb43..678708062b8d 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/FileLoadsStreamingIT.java
@@ -106,11 +106,16 @@ public static Iterable