diff --git a/examples/notebooks/beam-ml/README.md b/examples/notebooks/beam-ml/README.md
index 77bf3fc99f15..a1fe7ab19f51 100644
--- a/examples/notebooks/beam-ml/README.md
+++ b/examples/notebooks/beam-ml/README.md
@@ -77,6 +77,7 @@ This section contains the following example notebooks.
### Multi-model pipelines
* [Ensemble model using an image captioning and ranking](https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_multi_model.ipynb)
+* [Run ML inference with multiple differently-trained models](https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/per_key_models.ipynb)
### Model Evaluation
diff --git a/examples/notebooks/beam-ml/per_key_models.ipynb b/examples/notebooks/beam-ml/per_key_models.ipynb
new file mode 100644
index 000000000000..b529449555d0
--- /dev/null
+++ b/examples/notebooks/beam-ml/per_key_models.ipynb
@@ -0,0 +1,608 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License"
+ ],
+ "metadata": {
+ "id": "OsFaZscKSPvo"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Run ML inference with multiple differently-trained models\n",
+ "\n",
+ "
\n"
+ ],
+ "metadata": {
+ "id": "ZUSiAR62SgO8"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Running inference with multiple differently-trained models performing the same task is useful in many scenarios, including the following examples:\n",
+ "\n",
+ "* You want to compare the performance of multiple different models.\n",
+ "* You have models trained on different datasets that you want to use conditionally based on additional metadata.\n",
+ "\n",
+ "In Apache Beam, the recommended way to run inference is to use the `RunInference` transform. By using a `KeyedModelHandler`, you can efficiently run inference with O(100s) of models without having to manage memory yourself.\n",
+ "\n",
+ "This notebook demonstrates how to use a `KeyedModelHandler` to run inference in an Apache Beam pipeline with multiple different models on a per-key basis. This notebook uses pretrained pipelines from Hugging Face. Before continuing with this notebook, it is recommended that you walk through the [beginner RunInference notebook](https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_pytorch_tensorflow_sklearn.ipynb)."
+ ],
+ "metadata": {
+ "id": "ZAVOrrW2An1n"
+ }
+ },
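+ {
+ "cell_type": "markdown",
+ "source": [
+ "At a high level, a `KeyedModelHandler` maps groups of keys to the ordinary model handlers that serve them, and `RunInference` then routes each keyed element to the model registered for its key. The following sketch shows that shape with hypothetical handlers (`model_handler_a` and `model_handler_b`); the rest of this notebook builds a working version with two real Hugging Face models:\n",
+ "\n",
+ "```python\n",
+ "from apache_beam.ml.inference.base import KeyedModelHandler, KeyModelMapping\n",
+ "\n",
+ "# model_handler_a and model_handler_b are hypothetical stand-ins for any\n",
+ "# concrete ModelHandler, such as HuggingFacePipelineModelHandler.\n",
+ "keyed_mh = KeyedModelHandler([\n",
+ "    KeyModelMapping(['key_a'], model_handler_a),\n",
+ "    KeyModelMapping(['key_b'], model_handler_b),\n",
+ "])\n",
+ "\n",
+ "# Input elements are (key, example) tuples: RunInference(keyed_mh) runs the\n",
+ "# model from model_handler_a on elements keyed 'key_a' and the model from\n",
+ "# model_handler_b on elements keyed 'key_b'.\n",
+ "```"
+ ],
+ "metadata": {}
+ },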
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Install dependencies\n",
+ "\n",
+ "First, install both Apache Beam and the dependencies needed by Hugging Face."
+ ],
+ "metadata": {
+ "id": "_fNyheQoDgGt"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "B-ENznuJqArA",
+ "outputId": "f72963fc-82db-4d0d-9225-07f6b501e256"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ ""
+ ]
+ }
+ ],
+ "source": [
+ "# Note that this notebook currently installs from Beam head since this feature hasn't been released yet.\n",
+ "# It will be released with version 2.51.0, at which point you can install with the following command:\n",
+ "# !pip install apache_beam[gcp]>=2.51.0 --quiet\n",
+ "!git clone https://github.com/apache/beam\n",
+ "!pip install -r beam/sdks/python/build-requirements.txt\n",
+ "!pip install -e ./beam/sdks/python[gcp]\n",
+ "!pip install torch --quiet\n",
+ "!pip install transformers --quiet\n",
+ "\n",
+ "# To use the newly installed versions, restart the runtime.\n",
+ "exit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from typing import Dict\n",
+ "from typing import Iterable\n",
+ "from typing import Tuple\n",
+ "\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "import apache_beam as beam\n",
+ "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+ "from apache_beam.ml.inference.base import KeyModelMapping\n",
+ "from apache_beam.ml.inference.base import PredictionResult\n",
+ "from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler\n",
+ "from apache_beam.ml.inference.base import RunInference"
+ ],
+ "metadata": {
+ "id": "wUmBEglvsOYW"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Define the model configurations\n",
+ "\n",
+ "A model handler is the Apache Beam method used to define the configuration needed to load and invoke models. Because this example uses two models, we define two model handlers, one for each model. Because both models are incapsulated within Hugging Face pipelines, we use the model handler `HuggingFacePipelineModelHandler`.\n",
+ "\n",
+ "In this notebook, we load the models using Hugging Face and run them against an example. The models produce different outputs."
+ ],
+ "metadata": {
+ "id": "uEqljVgCD7hx"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "distilbert_mh = HuggingFacePipelineModelHandler('text-classification', model=\"distilbert-base-uncased-finetuned-sst-2-english\")\n",
+ "roberta_mh = HuggingFacePipelineModelHandler('text-classification', model=\"roberta-large-mnli\")\n",
+ "\n",
+ "distilbert_pipe = pipeline('text-classification', model=\"distilbert-base-uncased-finetuned-sst-2-english\")\n",
+ "roberta_large_pipe = pipeline(model=\"roberta-large-mnli\")"
+ ],
+ "metadata": {
+ "id": "v2NJT5ZcxgH5",
+ "outputId": "3924d72e-5c49-477d-c50f-6d9098f5a4b2"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/629 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "b7cb51663677434ca42de6b5e6f37420"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading model.safetensors: 0%| | 0.00/268M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "3702756019854683a9dea9f8af0a29d0"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/48.0 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "52b9fdb51d514c2e8b9fa5813972ab01"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/232k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "eca24b7b7b1847c1aed6aa59a44ed63a"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/688 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "4d4cfe9a0ca54897aa991420bee01ff9"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading model.safetensors: 0%| | 0.00/1.43G [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "aee85cd919d24125acff1663fba0b47c"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)olve/main/vocab.json: 0%| | 0.00/899k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "0af8ad4eed2d49878fa83b5828d58a96"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)olve/main/merges.txt: 0%| | 0.00/456k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "1ed943a51c704ab7a72101b5b6182772"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)/main/tokenizer.json: 0%| | 0.00/1.36M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "5b1dcbb533894267b184fd591d8ccdbc"
+ }
+ },
+ "metadata": {}
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "distilbert_pipe(\"This restaurant is awesome\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "H3nYX9thy8ec",
+ "outputId": "826e3285-24b9-47a8-d2a6-835543fdcae7"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[{'label': 'POSITIVE', 'score': 0.9998743534088135}]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "roberta_large_pipe(\"This restaurant is awesome\")\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "IIfc94ODyjUg",
+ "outputId": "94ec8afb-ebfb-47ce-9813-48358741bc6b"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[{'label': 'NEUTRAL', 'score': 0.7313134670257568}]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 4
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Define the examples\n",
+ "\n",
+ "Next, define examples to input into the pipeline. The examples include their correct classifications."
+ ],
+ "metadata": {
+ "id": "yd92MC7YEsTf"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "examples = [\n",
+ " (\"This restaurant is awesome\", \"positive\"),\n",
+ " (\"This restaurant is bad\", \"negative\"),\n",
+ " (\"I feel fine\", \"neutral\"),\n",
+ " (\"I love chocolate\", \"positive\"),\n",
+ "]"
+ ],
+ "metadata": {
+ "id": "5HAziWEavQws"
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "To feed the examples into RunInference, you need distinct keys that can map to the model. In this case, to make it possible to extract the actual sentiment of the example later, define keys in the form `-`."
+ ],
+ "metadata": {
+ "id": "r6GXL5PLFBY7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class FormatExamples(beam.DoFn):\n",
+ " \"\"\"\n",
+ " Map each example to a tuple of ('-', 'example').\n",
+ " We use these keys to map our elements to the correct models.\n",
+ " \"\"\"\n",
+ " def process(self, element: Tuple[str, str]) -> Iterable[Tuple[str, str]]:\n",
+ " yield (f'distilbert-{element[1]}', element[0])\n",
+ " yield (f'roberta-{element[1]}', element[0])"
+ ],
+ "metadata": {
+ "id": "p2uVwws8zRpg"
+ },
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Use the formatted keys to define a `KeyedModelHandler` that maps keys to the `ModelHandler` used for those keys. The `KeyedModelHandler` method lets you define an optional `max_models_per_worker_hint`, which limits the number of models that can be held in a single worker process at one time. If you're worried about your worker running out of memory, use this option. For more information about managing memory, see [Use a keyed ModelHandler](https://beam.apache.org/documentation/sdks/python-machine-learning/index.html#use-a-keyed-modelhandler)."
+ ],
+ "metadata": {
+ "id": "IP65_5nNGIb8"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "per_key_mhs = [\n",
+ " KeyModelMapping(['distilbert-positive', 'distilbert-neutral', 'distilbert-negative'], distilbert_mh),\n",
+ " KeyModelMapping(['roberta-positive', 'roberta-neutral', 'roberta-negative'], roberta_mh)\n",
+ "]\n",
+ "mh = KeyedModelHandler(per_key_mhs, max_models_per_worker_hint=2)"
+ ],
+ "metadata": {
+ "id": "DZpfjeGL2hMG"
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Postprocess the results\n",
+ "\n",
+ "The `RunInference` transform returns a Tuple containing:\n",
+ "* the original key\n",
+ "* a `PredictionResult` object containing the original example and the inference.\n",
+ "Use those outputs to extract the relevant data. Then, to compare each model's prediction, group this data by the original example."
+ ],
+ "metadata": {
+ "id": "_a4ZmnD5FSeG"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class ExtractResults(beam.DoFn):\n",
+ " \"\"\"\n",
+ " Extract the relevant data from the PredictionResult object.\n",
+ " \"\"\"\n",
+ " def process(self, element: Tuple[str, PredictionResult]) -> Iterable[Tuple[str, Dict[str, str]]]:\n",
+ " actual_sentiment = element[0].split('-')[1]\n",
+ " model = element[0].split('-')[0]\n",
+ " result = element[1]\n",
+ " example = result.example\n",
+ " predicted_sentiment = result.inference[0]['label']\n",
+ "\n",
+ " yield (example, {'model': model, 'actual_sentiment': actual_sentiment, 'predicted_sentiment': predicted_sentiment})"
+ ],
+ "metadata": {
+ "id": "FOwFNQA053TG"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Finally, print the results produced by each model."
+ ],
+ "metadata": {
+ "id": "JVnv4gGbFohk"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class PrintResults(beam.DoFn):\n",
+ " \"\"\"\n",
+ " Print the results produced by each model along with the actual sentiment.\n",
+ " \"\"\"\n",
+ " def process(self, element: Tuple[str, Iterable[Dict[str, str]]]):\n",
+ " example = element[0]\n",
+ " actual_sentiment = element[1][0]['actual_sentiment']\n",
+ " predicted_sentiment_1 = element[1][0]['predicted_sentiment']\n",
+ " model_1 = element[1][0]['model']\n",
+ " predicted_sentiment_2 = element[1][1]['predicted_sentiment']\n",
+ " model_2 = element[1][1]['model']\n",
+ "\n",
+ " if model_1 == 'distilbert':\n",
+ " distilbert_prediction = predicted_sentiment_1\n",
+ " roberta_prediction = predicted_sentiment_2\n",
+ " else:\n",
+ " roberta_prediction = predicted_sentiment_1\n",
+ " distilbert_prediction = predicted_sentiment_2\n",
+ "\n",
+ " print(f'Example: {example}\\nActual Sentiment: {actual_sentiment}\\n'\n",
+ " f'Distilbert Prediction: {distilbert_prediction}\\n'\n",
+ " f'Roberta Prediction: {roberta_prediction}\\n------------')"
+ ],
+ "metadata": {
+ "id": "kUQJNYOa9Q5-"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run the pipeline\n",
+ "\n",
+ "Put together all of the pieces to run a single Apache Beam pipeline."
+ ],
+ "metadata": {
+ "id": "-LrpmM2PGAkf"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with beam.Pipeline() as beam_pipeline:\n",
+ "\n",
+ " formatted_examples = (\n",
+ " beam_pipeline\n",
+ " | \"ReadExamples\" >> beam.Create(examples)\n",
+ " | \"FormatExamples\" >> beam.ParDo(FormatExamples()))\n",
+ " inferences = (\n",
+ " formatted_examples\n",
+ " | \"Run Inference\" >> RunInference(mh)\n",
+ " | \"ExtractResults\" >> beam.ParDo(ExtractResults())\n",
+ " | \"GroupByExample\" >> beam.GroupByKey()\n",
+ " )\n",
+ "\n",
+ " inferences | beam.ParDo(PrintResults())\n",
+ "\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 463
+ },
+ "id": "B9Wti3XH0Iqe",
+ "outputId": "528ad732-ecf8-4877-ab6a-badad7944fed"
+ },
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Example: This restaurant is awesome\n",
+ "Actual Sentiment: positive\n",
+ "Distilbert Prediction: POSITIVE\n",
+ "Roberta Prediction: NEUTRAL\n",
+ "------------\n",
+ "Example: This restaurant is bad\n",
+ "Actual Sentiment: negative\n",
+ "Distilbert Prediction: NEGATIVE\n",
+ "Roberta Prediction: NEUTRAL\n",
+ "------------\n",
+ "Example: I love chocolate\n",
+ "Actual Sentiment: positive\n",
+ "Distilbert Prediction: POSITIVE\n",
+ "Roberta Prediction: NEUTRAL\n",
+ "------------\n",
+ "Example: I feel fine\n",
+ "Actual Sentiment: neutral\n",
+ "Distilbert Prediction: POSITIVE\n",
+ "Roberta Prediction: ENTAILMENT\n",
+ "------------\n"
+ ]
+ }
+ ]
+ }
+ ]
+}