diff --git a/examples/notebooks/beam-ml/mltransform_basic.ipynb b/examples/notebooks/beam-ml/mltransform_basic.ipynb
new file mode 100644
index 000000000000..820bc3400b58
--- /dev/null
+++ b/examples/notebooks/beam-ml/mltransform_basic.ipynb
@@ -0,0 +1,679 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License"
+ ],
+ "metadata": {
+ "id": "34gTXZ7BIArp"
+ },
+ "id": "34gTXZ7BIArp",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Preprocess data with MLTransform\n",
+ "\n",
+ "
\n"
+ ],
+ "metadata": {
+ "id": "0n0YAd-0KQyi"
+ },
+ "id": "0n0YAd-0KQyi"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d3b81cf2-8603-42bd-995e-9e14631effd0",
+ "metadata": {
+ "id": "d3b81cf2-8603-42bd-995e-9e14631effd0"
+ },
+ "source": [
+ "This notebook demonstrates how to use `MLTransform` to preprocess your data for machine learning models. `MLTransform` is a `PTransform` that wraps multiple Apache Beam data processing transforms. As a result, `MLTransform` gives you the ability to preprocess different types of data in multiple ways with one transform.\n",
+ "\n",
+ "This notebook uses data processing transforms defined in the [apache_beam/ml/transforms/tft](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.tft.html) module."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f0097dbd-2657-4cbe-a334-e0401816db01",
+ "metadata": {
+ "id": "f0097dbd-2657-4cbe-a334-e0401816db01"
+ },
+ "source": [
+ "## Import the requried modules\n",
+ "\n",
+ "To use `MLTransfrom`, install `tensorflow_transform` and the Apache Beam SDK version 2.50.0 or later.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install tensorflow_transform --quiet\n",
+ "!pip install apache_beam>=2.50.0 --quiet"
+ ],
+ "metadata": {
+ "id": "MRWkC-n2DmjM"
+ },
+ "id": "MRWkC-n2DmjM",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "88ddd3a4-3643-4731-b99e-a5d697fbc165",
+ "metadata": {
+ "id": "88ddd3a4-3643-4731-b99e-a5d697fbc165"
+ },
+ "outputs": [],
+ "source": [
+ "import apache_beam as beam\n",
+ "from apache_beam.ml.transforms.base import MLTransform\n",
+ "from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary\n",
+ "from apache_beam.options.pipeline_options import PipelineOptions\n",
+ "from apache_beam.ml.transforms.utils import ArtifactsFetcher"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Artifacts are additional data elements created by data transformations. Examples of artifacts are the `minimum` and `maximum` values from a `ScaleTo01` transformation, or the `mean` and `variance` from a `ScaleToZScore` transformation. For more information about artifacts, see [Artifacts](https://beam.apache.org/documentation/ml/preprocess-data/#artifacts).\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "90nXXc_A4Bmf"
+ },
+ "id": "90nXXc_A4Bmf"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bdabbc57-ec98-4113-b37e-61962f488d61",
+ "metadata": {
+ "id": "bdabbc57-ec98-4113-b37e-61962f488d61"
+ },
+ "outputs": [],
+ "source": [
+ "# Store artifacts generated by MLTransform.\n",
+ "# Each MLTransform instance requires an empty artifact location.\n",
+ "# This method deletes and refreshes the artifact location for each example.\n",
+ "artifact_location = './my_artifacts'\n",
+ "def delete_artifact_location(artifact_location):\n",
+ " import shutil\n",
+ " import os\n",
+ " if os.path.exists(artifact_location):\n",
+ " shutil.rmtree(artifact_location)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28b1719c-7287-4cec-870b-9fabc4c4a4ef",
+ "metadata": {
+ "id": "28b1719c-7287-4cec-870b-9fabc4c4a4ef"
+ },
+ "source": [
+ "## Compute and map the vocabulary\n",
+ "\n",
+ "\n",
+ "[ComputeAndApplyVocabulary](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.tft.html#apache_beam.ml.transforms.tft.ComputeAndApplyVocabulary) is a data processing transform that computes a unique vocabulary from a dataset and then maps each word or token to a distinct integer index. It facilitates transforming textual data into numerical representations for machine learning tasks.\n",
+ "\n",
+ "Use `ComputeAndApplyVocabulary` with `MLTransform`.\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "56d6d09a-8d34-444f-a1e4-a75624b36932",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "56d6d09a-8d34-444f-a1e4-a75624b36932",
+ "outputId": "2eb99e87-fb23-498c-ed08-775befa3a823"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Row(x=array([1, 0, 4]))\n",
+ "Row(x=array([1, 0, 6, 2, 3, 5]))\n"
+ ]
+ }
+ ],
+ "source": [
+ "delete_artifact_location(artifact_location)\n",
+ "\n",
+ "data = [\n",
+ " {'x': ['I', 'love', 'pie']},\n",
+ " {'x': ['I', 'love', 'going', 'to', 'the', 'park']}\n",
+ "]\n",
+ "options = PipelineOptions()\n",
+ "with beam.Pipeline(options=options) as p:\n",
+ " data = (\n",
+ " p\n",
+ " | 'CreateData' >> beam.Create(data)\n",
+ " | 'MLTransform' >> MLTransform(write_artifact_location=artifact_location).with_transform(ComputeAndApplyVocabulary(columns=['x']))\n",
+ " | 'PrintResults' >> beam.Map(print)\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1e133002-7229-459d-8e3c-b41f4d65e76d",
+ "metadata": {
+ "id": "1e133002-7229-459d-8e3c-b41f4d65e76d"
+ },
+ "source": [
+ "### Fetch vocabulary artifacts\n",
+ "\n",
+ "This example generates a file with all the vocabulary in the dataset, referred to in `MLTransform` as an artifact. To fetch artifacts generated by the `ComputeAndApplyVocabulary` transform, use the `ArtifactsFetcher` class. This class fetches both a vocabulary list and a path to the vocabulary file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9c5fe46a-c718-4a82-bad8-aa091c0b0538",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "9c5fe46a-c718-4a82-bad8-aa091c0b0538",
+ "outputId": "cd8b6cf3-6093-4b1b-a063-ff327c090a92"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "['love', 'I', 'to', 'the', 'pie', 'park', 'going']\n",
+ "./my_artifacts/transform_fn/assets/compute_and_apply_vocab\n",
+ "7\n"
+ ]
+ }
+ ],
+ "source": [
+ "fetcher = ArtifactsFetcher(artifact_location=artifact_location)\n",
+ "# get vocab list\n",
+ "vocab_list = fetcher.get_vocab_list()\n",
+ "print(vocab_list)\n",
+ "# get vocab file path\n",
+ "vocab_file_path = fetcher.get_vocab_filepath()\n",
+ "print(vocab_file_path)\n",
+ "# get vocab size\n",
+ "vocab_size = fetcher.get_vocab_size()\n",
+ "print(vocab_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5f955f3d-3192-42f7-aa55-48249223418d",
+ "metadata": {
+ "id": "5f955f3d-3192-42f7-aa55-48249223418d"
+ },
+ "source": [
+ "## Use TD-IDF to weight terms\n",
+ "\n",
+ "[TF-IDF](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.tft.html#apache_beam.ml.transforms.tft.ComputeAndApplyVocabulary) (Term Frequency-Inverse Document Frequency) is a numerical statistic used in text processing to reflect how important a word is to a document in a collection or corpus. It balances the frequency of a word in a document against its frequency in the entire corpus, giving higher value to more specific terms.\n",
+ "\n",
+ "Use `TF-IDF` with `MLTransform`.\n",
+ "\n",
+ "1. Compute the vocabulary of the dataset by using `ComputeAndApplyVocabulary`.\n",
+ "2. Use the output of `ComputeAndApplyVocabulary` to calculate the `TF-IDF` weights.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8a8cb94b-57eb-4c4c-aa4c-22cf3193ea85",
+ "metadata": {
+ "id": "8a8cb94b-57eb-4c4c-aa4c-22cf3193ea85"
+ },
+ "outputs": [],
+ "source": [
+ "from apache_beam.ml.transforms.tft import TFIDF"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "970d7222-194e-460e-b698-a00f1fcafb95",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "970d7222-194e-460e-b698-a00f1fcafb95",
+ "outputId": "e87409ed-5e33-43fa-d3b6-a0c012636cef"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Row(x=array([1, 0, 4]), x_tfidf_weight=array([0.33333334, 0.33333334, 0.4684884 ], dtype=float32), x_vocab_index=array([0, 1, 4]))\n",
+ "Row(x=array([1, 0, 6, 2, 3, 5]), x_tfidf_weight=array([0.16666667, 0.16666667, 0.2342442 , 0.2342442 , 0.2342442 ,\n",
+ " 0.2342442 ], dtype=float32), x_vocab_index=array([0, 1, 2, 3, 5, 6]))\n"
+ ]
+ }
+ ],
+ "source": [
+ "data = [\n",
+ " {'x': ['I', 'love', 'pie']},\n",
+ " {'x': ['I', 'love', 'going', 'to', 'the', 'park']}\n",
+ "]\n",
+ "delete_artifact_location(artifact_location)\n",
+ "options = PipelineOptions()\n",
+ "with beam.Pipeline(options=options) as p:\n",
+ " data = (\n",
+ " p\n",
+ " | beam.Create(data)\n",
+ " | MLTransform(write_artifact_location=artifact_location\n",
+ " ).with_transform(ComputeAndApplyVocabulary(columns=['x'])\n",
+ " ).with_transform(TFIDF(columns=['x']))\n",
+ " )\n",
+ " _ = data | beam.Map(print)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7b1feb4f-bb0b-4f61-8349-e1ba411858cf",
+ "metadata": {
+ "id": "7b1feb4f-bb0b-4f61-8349-e1ba411858cf"
+ },
+ "source": [
+ "### TF-IDF output\n",
+ "\n",
+ "`TF-IDF` produces two output columns for a given input. For example, if you input `x`, the output column names in the dictionary are `x_vocab_index` and `x_tfidf_weight`.\n",
+ "\n",
+ "- `vocab_index`: indices of the words computed in the `ComputeAndApplyVocabulary` transform.\n",
+ "- `tfidif_weight`: the weight for each vocabulary index. The weight represents how important the word present at that `vocab_index` is to the document.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d3b5b9dd-ed35-460b-9fb3-0ffb5c3633db",
+ "metadata": {
+ "id": "d3b5b9dd-ed35-460b-9fb3-0ffb5c3633db"
+ },
+ "source": [
+ "## Scale the data\n",
+ "\n",
+ "The following examples show two ways to scale data:\n",
+ "\n",
+ "* Scale data between 0 and 1.\n",
+ "* Scale data using z-score.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3bd20692-6d14-4ece-a2e7-69a2a6fac5d4",
+ "metadata": {
+ "id": "3bd20692-6d14-4ece-a2e7-69a2a6fac5d4"
+ },
+ "source": [
+ "### Scale the data between 0 and 1\n",
+ "\n",
+ "Scale the data so that it's in the range of 0 and 1. To scale the data, the transform calculates `minimum` and `maximum` values on the whole dataset, and then performs the following calculation:\n",
+ "\n",
+ "`x = (x - x_min) / (x_max)`\n",
+ "\n",
+ "To scale the data, use the [ScaleTo01](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.tft.html#apache_beam.ml.transforms.tft.ScaleTo01) data processing transform in `MLTransform`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "841a8e1f-2f5b-4fd9-bb35-12a2393922de",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "841a8e1f-2f5b-4fd9-bb35-12a2393922de",
+ "outputId": "efcae38d-96f6-4394-e5f5-c36644d3a9ff"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Row(x=array([0. , 0.01010101, 0.02020202], dtype=float32), x_max=array([100.], dtype=float32), x_min=array([1.], dtype=float32))\n",
+ "Row(x=array([0.03030303, 0.04040404, 0.06060606], dtype=float32), x_max=array([100.], dtype=float32), x_min=array([1.], dtype=float32))\n",
+ "Row(x=array([0.09090909, 0.01010101, 0.09090909, 0.33333334, 1. ,\n",
+ " 0.53535354, 0.1919192 , 0.09090909, 0.01010101, 0.02020202,\n",
+ " 0.1010101 , 0.11111111], dtype=float32), x_max=array([100.], dtype=float32), x_min=array([1.], dtype=float32))\n"
+ ]
+ }
+ ],
+ "source": [
+ "delete_artifact_location(artifact_location)\n",
+ "\n",
+ "from apache_beam.ml.transforms.tft import ScaleTo01\n",
+ "data = [\n",
+ " {'x': [1, 2, 3]}, {'x': [4, 5, 7]}, {'x': [10, 2, 10, 34, 100, 54, 20, 10, 2, 3, 11, 12]}]\n",
+ "\n",
+ "with beam.Pipeline() as p:\n",
+ " _ = (\n",
+ " p\n",
+ " | 'CreateData' >> beam.Create(data)\n",
+ " | 'MLTransform' >> MLTransform(write_artifact_location=artifact_location).with_transform(ScaleTo01(columns=['x']))\n",
+ " | 'PrintResults' >> beam.Map(print)\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1838ecb-2168-45f8-bdf2-41ae0007cb71",
+ "metadata": {
+ "id": "b1838ecb-2168-45f8-bdf2-41ae0007cb71"
+ },
+ "source": [
+ "The output contains artifacts such as `x_max` and `x_min`, which represent the maximum and minimum values of the entire dataset.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Scale the data by using the z-score\n",
+ "\n",
+ "Scale to the data using the z-score\n",
+ "\n",
+ "Similar to `ScaleTo01`, use [ScaleToZScore](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.tft.html#apache_beam.ml.transforms.tft.ScaleToZScore) to scale the values by using the [z-score]([z-score](https://www.tensorflow.org/tfx/transform/api_docs/python/tft/scale_to_z_score#:~:text=Scaling%20to%20z%2Dscore%20subtracts%20out%20the%20mean%20and%20divides%20by%20standard%20deviation.%20Note%20that%20the%20standard%20deviation%20computed%20here%20is%20based%20on%20the%20biased%20variance%20(0%20delta%20degrees%20of%20freedom)%2C%20as%20computed%20by%20analyzers.var.).\n"
+ ],
+ "metadata": {
+ "id": "_bHdYkuF74Fe"
+ },
+ "id": "_bHdYkuF74Fe"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "delete_artifact_location(artifact_location)\n",
+ "\n",
+ "from apache_beam.ml.transforms.tft import ScaleToZScore\n",
+ "data = [\n",
+ " {'x': [1, 2, 3]}, {'x': [4, 5, 7]}, {'x': [10, 2, 10, 34, 100, 54, 20, 10, 2, 3, 11, 12]}]\n",
+ "\n",
+ "# delete_artifact_location(artifact_location)\n",
+ "with beam.Pipeline() as p:\n",
+ " _ = (\n",
+ " p\n",
+ " | 'CreateData' >> beam.Create(data)\n",
+ " | 'MLTransform' >> MLTransform(write_artifact_location=artifact_location).with_transform(ScaleToZScore(columns=['x']))\n",
+ " | 'PrintResults' >> beam.Map(print)\n",
+ " )\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "aHK6zdfE732A",
+ "outputId": "8b4f5082-35a2-42c4-9342-a77f99338e17"
+ },
+ "id": "aHK6zdfE732A",
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Row(x=array([-0.62608355, -0.5846515 , -0.54321957], dtype=float32), x_mean=array([16.11111], dtype=float32), x_var=array([582.5432], dtype=float32))\n",
+ "Row(x=array([-0.50178754, -0.46035555, -0.37749153], dtype=float32), x_mean=array([16.11111], dtype=float32), x_var=array([582.5432], dtype=float32))\n",
+ "Row(x=array([-0.25319555, -0.5846515 , -0.25319555, 0.7411725 , 3.4756844 ,\n",
+ " 1.5698125 , 0.16112447, -0.25319555, -0.5846515 , -0.54321957,\n",
+ " -0.21176355, -0.17033154], dtype=float32), x_mean=array([16.11111], dtype=float32), x_var=array([582.5432], dtype=float32))\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Use multiple transforms on a single MLTransform\n",
+ "\n",
+ "Apply the same transform on multiple columns. For example, columns `x` and\n",
+ "`y` require scaling by 0 and 1. For column `s`, compute vocabulary. You can use a single `MLTransform` for both of these tasks.\n",
+ "\n",
+ "When using multiple data processing transforms, either pass the transforms as chained transforms or directly as a list."
+ ],
+ "metadata": {
+ "id": "FNoWfyMR8JI-"
+ },
+ "id": "FNoWfyMR8JI-"
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Use multiple data processing transforms in a single MLTransform\n",
+ "\n",
+ "The following example shows multiple data processing transforms chained to `MLTransform`."
+ ],
+ "metadata": {
+ "id": "Mj6hd3jZ9-nr"
+ },
+ "id": "Mj6hd3jZ9-nr"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6e382cca-cfd3-4ac1-956a-16480603dd5b",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "6e382cca-cfd3-4ac1-956a-16480603dd5b",
+ "outputId": "7f185d92-ad91-4067-c11f-66150968ec97"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Row(s=array([1, 0, 4]), x=array([0. , 0.16666667, 0.33333334], dtype=float32), x_max=array([7.], dtype=float32), x_min=array([1.], dtype=float32), y=array([0. , 0.8910891, 1. ], dtype=float32), y_max=array([111.], dtype=float32), y_min=array([10.], dtype=float32))\n",
+ "Row(s=array([1, 0, 6, 2, 3, 5]), x=array([0.5 , 0.6666667, 1. ], dtype=float32), x_max=array([7.], dtype=float32), x_min=array([1.], dtype=float32), y=array([0.00990099, 0.10891089, 0.3960396 ], dtype=float32), y_max=array([111.], dtype=float32), y_min=array([10.], dtype=float32))\n"
+ ]
+ }
+ ],
+ "source": [
+ "delete_artifact_location(artifact_location)\n",
+ "\n",
+ "from apache_beam.ml.transforms.tft import ScaleTo01\n",
+ "from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary\n",
+ "\n",
+ "data = [\n",
+ " {'x': [1, 2, 3], 'y': [10, 100, 111], 's': ['I', 'love', 'pie']},\n",
+ " {'x': [4, 5, 7], 'y': [11, 21, 50], 's': ['I', 'love', 'going', 'to', 'the', 'park']}\n",
+ "]\n",
+ "\n",
+ "# delete_artifact_location(artifact_location)\n",
+ "with beam.Pipeline() as p:\n",
+ " _ = (\n",
+ " p\n",
+ " | 'CreateData' >> beam.Create(data)\n",
+ " | 'MLTransform' >> MLTransform(write_artifact_location=artifact_location).with_transform(\n",
+ " ScaleTo01(columns=['x', 'y'])).with_transform(ComputeAndApplyVocabulary(columns=['s']))\n",
+ " | 'PrintResults' >> beam.Map(print)\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The following example shows multiple data processing transforms passed in as a list to `MLTransform`."
+ ],
+ "metadata": {
+ "id": "IIrL13uEG3mH"
+ },
+ "id": "IIrL13uEG3mH"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "delete_artifact_location(artifact_location)\n",
+ "\n",
+ "from apache_beam.ml.transforms.tft import ScaleTo01\n",
+ "from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary\n",
+ "\n",
+ "data = [\n",
+ " {'x': [1, 2, 3], 'y': [10, 100, 111], 's': ['I', 'love', 'pie']},\n",
+ " {'x': [4, 5, 7], 'y': [11, 21, 50], 's': ['I', 'love', 'going', 'to', 'the', 'park']}\n",
+ "]\n",
+ "\n",
+ "transforms = [\n",
+ " ScaleTo01(columns=['x', 'y']),\n",
+ " ComputeAndApplyVocabulary(columns=['s'])\n",
+ "]\n",
+ "\n",
+ "with beam.Pipeline() as p:\n",
+ " _ = (\n",
+ " p\n",
+ " | 'CreateData' >> beam.Create(data)\n",
+ " | 'MLTransform' >> MLTransform(write_artifact_location=artifact_location,\n",
+ " transforms=transforms)\n",
+ " | 'PrintResults' >> beam.Map(print)\n",
+ " )"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "equV7ptY-FKL",
+ "outputId": "9c3f9461-31e9-41de-cc5d-96e7ff5a3600"
+ },
+ "id": "equV7ptY-FKL",
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Row(s=array([1, 0, 4]), x=array([0. , 0.16666667, 0.33333334], dtype=float32), x_max=array([7.], dtype=float32), x_min=array([1.], dtype=float32), y=array([0. , 0.8910891, 1. ], dtype=float32), y_max=array([111.], dtype=float32), y_min=array([10.], dtype=float32))\n",
+ "Row(s=array([1, 0, 6, 2, 3, 5]), x=array([0.5 , 0.6666667, 1. ], dtype=float32), x_max=array([7.], dtype=float32), x_min=array([1.], dtype=float32), y=array([0.00990099, 0.10891089, 0.3960396 ], dtype=float32), y_max=array([111.], dtype=float32), y_min=array([10.], dtype=float32))\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### MLTransform for inference workloads\n",
+ "\n",
+ "The previous examples show how to preprocess data for model training. This example uses the same preprocessing steps on the inference data. By using the same steps on the inference data, you can maintain consistent results.\n",
+ "\n",
+ "Preprocess the data going into the inference by using the same preprocessing steps used on the data prior to training. To do this with `MLTransform`, pass the artifact location from the previous transforms to the parameter `read_artifact_location`. `MLTransform` uses the values and artifacts produced in the previous steps. You don't need to provide the transforms, because they are saved with the artifacts in the artifact location.\n"
+ ],
+ "metadata": {
+ "id": "kcnQSwkA-eSA"
+ },
+ "id": "kcnQSwkA-eSA"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "data = [\n",
+ " {'x': [2], 'y': [59, 91, 85], 's': ['love']},\n",
+ " {'x': [4, 5, 7], 'y': [111, 26, 30], 's': ['I', 'love', 'parks', 'and', 'dogs']}\n",
+ "]\n",
+ "\n",
+ "with beam.Pipeline() as p:\n",
+ " _ = (\n",
+ " p\n",
+ " | 'CreateData' >> beam.Create(data)\n",
+ " | 'MLTransform' >> MLTransform(read_artifact_location=artifact_location)\n",
+ " | 'PrintResults' >> beam.Map(print)\n",
+ " )"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "m0HpQ-Ff-Xmz",
+ "outputId": "1631e0f6-ee58-4c90-f90d-0e183aaaf3c2"
+ },
+ "id": "m0HpQ-Ff-Xmz",
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Row(s=array([0]), x=array([0.16666667], dtype=float32), x_max=array([7.], dtype=float32), x_min=array([1.], dtype=float32), y=array([0.48514852, 0.8019802 , 0.7425743 ], dtype=float32), y_max=array([111.], dtype=float32), y_min=array([10.], dtype=float32))\n",
+ "Row(s=array([ 1, 0, -1, -1, -1]), x=array([0.5 , 0.6666667, 1. ], dtype=float32), x_max=array([7.], dtype=float32), x_min=array([1.], dtype=float32), y=array([1. , 0.15841584, 0.1980198 ], dtype=float32), y_max=array([111.], dtype=float32), y_min=array([10.], dtype=float32))\n"
+ ]
+ }
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.5"
+ },
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file