From d4a007155774842e77487a0c6b7695d3c72905c0 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 22 Apr 2024 18:22:36 -0400 Subject: [PATCH 01/18] Vllm first pass [wip] --- .../inference/test_resources/vllm.dockerfile | 28 ++ .../ml/inference/vllm_inference.py | 254 ++++++++++++++++++ .../ml/inference/vllm_inference_it_test.py | 0 3 files changed, 282 insertions(+) create mode 100644 sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile create mode 100644 sdks/python/apache_beam/ml/inference/vllm_inference.py create mode 100644 sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile new file mode 100644 index 000000000000..eb66ad982102 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -0,0 +1,28 @@ +# Used for any vLLM integration test + +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 + +RUN apt update +RUN apt install software-properties-common -y +RUN add-apt-repository ppa:deadsnakes/ppa +RUN apt update +RUN apt install python3.10 python3.10-venv python3.10-dev -y +RUN rm /usr/bin/python3 +RUN ln -s python3.10 /usr/bin/python3 +RUN python3 --version +RUN apt-get install -y curl +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && pip install --upgrade pip + +RUN pip install --no-cache-dir apache-beam[gcp]==2.55.0 openai vllm + +RUN apt install libcairo2-dev pkg-config python3-dev -y +RUN pip install pycairo + +# Verify that there are no conflicting dependencies. +RUN pip check + +# Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. +COPY --from=apache/beam_python3.10_sdk:2.55.0 /opt/apache/beam /opt/apache/beam + +# Set the entrypoint to Apache Beam SDK worker launcher. +ENTRYPOINT [ "/opt/apache/beam/boot" ] diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py new file mode 100644 index 000000000000..407839f6317a --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -0,0 +1,254 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# pytype: skip-file + +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.utils import subprocess_server +from dataclasses import dataclass +from openai import OpenAI +import logging +import threading +import time +import subprocess +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Tuple + +try: + import vllm # type: ignore unused-import + logging.info('vllm module successfully imported.') +except ModuleNotFoundError: + msg = 'vllm module was not found. This is ok as long as the specified ' \ + 'runner has vllm dependencies installed.' + logging.warning(msg) + +__all__ = [ + 'OpenAIChatMessage', + 'VLLMCompletionsModelHandler', + 'VLLMChatModelHandler', +] + + +@dataclass(frozen=True) +class OpenAIChatMessage(): + """" + Dataclass containing previous chat messages in conversation. + Role is the entity that sent the message (either 'user' or 'system'). + Content is the contents of the message. + """ + role: str + content: str + + +def start_process(cmd) -> Tuple[subprocess.Popen, int]: + port, = subprocess_server.pick_port(None) + cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd] # pylint: disable=not-an-iterable + logging.info("Starting service with %s", str(cmd).replace("',", "'")) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + # Emit the output of this command as info level logging. + def log_stdout(): + line = process.stdout.readline() + while line: + # The log obtained from stdout is bytes, decode it into string. + # Remove newline via rstrip() to not print an empty line. + logging.info(line.decode(errors='backslashreplace').rstrip()) + line = process.stdout.readline() + + t = threading.Thread(target=log_stdout) + t.daemon = True + t.start() + return process, port + + +def getVLLMClient(port) -> OpenAI: + openai_api_key = "EMPTY" + openai_api_base = f"http://localhost:{port}/v1" + return OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + +class _VLLMModelServer(): + def __init__(self, model_name): + self._model_name = model_name + self._server_started = False + self._server_process = None + self._server_port = None + + self.start_server() + + def start_server(self, retries=3): + if not self._server_started: + self._server_process, self._server_port = start_process([ + 'python', + '-m', + 'vllm.entrypoints.openai.api_server', + '--model', + self._model_name, + '--port', + '{{PORT}}', + ]) + + client = getVLLMClient(self._server_port) + while self._server_process.poll() is None: + try: + models = client.models.list().data + logging.info('models: %s' % models) + if len(models) > 0: + self._server_started = True + return + except: + pass + # Sleep while bringing up the process + time.sleep(5) + + if retries == 0: + raise Exception( + 'Failed to start vLLM server, process exited with code %s. ' + 'See worker logs to determine cause.' + % self._server_process.poll()) + else: + self.start_server(retries - 1) + + def get_server_port(self) -> int: + if not self._server_started: + self.start_server() + return self._server_port + + +class VLLMCompletionsModelHandler(ModelHandler[str, + PredictionResult, + _VLLMModelServer]): + def __init__( + self, + model_name: str, + ): + """Implementation of the ModelHandler interface for vLLM using text as + input. 
+ + Example Usage:: + + pcoll | RunInference(VLLMModelHandler(model_name='facebook/opt-125m')) + + Args: + model_name: The vLLM model. See + https://docs.vllm.ai/en/latest/models/supported_models.html for + supported models. + """ + self._model_name = model_name + self._env_vars = {} + + def load_model(self) -> _VLLMModelServer: + return _VLLMModelServer(self._model_name) + + def run_inference( + self, + batch: Sequence[str], + model: _VLLMModelServer, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Runs inferences on a batch of text strings. + + Args: + batch: A sequence of examples as text strings. + model: A _VLLMModelServer containing info for connecting to the server. + inference_args: Any additional arguments for an inference. + + Returns: + An Iterable of type PredictionResult. + """ + client = getVLLMClient(model.get_server_port()) + predictions = [] + for prompt in batch: + completion = client.completions.create( + model=self._model_name, prompt=prompt, **inference_args) + predictions.append(completion) + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def share_model_across_processes(self) -> bool: + return True + + def should_skip_batching(self) -> bool: + # Batching does not help since vllm is already doing dynamic batching and + # each request is sent one by one anyways + return True + + +class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage], + PredictionResult, + _VLLMModelServer]): + def __init__( + self, + model_name: str, + ): + """ Implementation of the ModelHandler interface for vLLM using previous + messages as input. + + Example Usage:: + + pcoll | RunInference(VLLMModelHandler(model_name='facebook/opt-125m')) + + Args: + model_name: The vLLM model. See + https://docs.vllm.ai/en/latest/models/supported_models.html for + supported models. + """ + self._model_name = model_name + self._env_vars = {} + + def load_model(self) -> _VLLMModelServer: + return _VLLMModelServer(self._model_name) + + def run_inference( + self, + batch: Sequence[Sequence[OpenAIChatMessage]], + model: _VLLMModelServer, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Runs inferences on a batch of text strings. + + Args: + batch: A sequence of examples as OpenAI messages. + model: A _VLLMModelServer containing info for connecting to the spun up server. + inference_args: Any additional arguments for an inference. + + Returns: + An Iterable of type PredictionResult. 
+ """ + client = getVLLMClient(model.get_server_port()) + predictions = [] + for messages in batch: + completion = client.completions.create( + model=self._model_name, messages=messages, **inference_args) + predictions.append(completion) + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def share_model_across_processes(self) -> bool: + return True + + def should_skip_batching(self) -> bool: + # Batching does not help since vllm is already doing dynamic batching and + # each request is sent one by one anyways + return True diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py b/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py new file mode 100644 index 000000000000..e69de29bb2d1 From 76235ff9559dc3a41d11ad13b0381802aad82476 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Wed, 24 Apr 2024 13:33:39 -0400 Subject: [PATCH 02/18] Example for integration tests wip --- .../inference/vllm_text_completion.py | 141 ++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 sdks/python/apache_beam/examples/inference/vllm_text_completion.py diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py new file mode 100644 index 000000000000..b1fae7ddad63 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -0,0 +1,141 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" A sample pipeline using the RunInference API to classify images of flowers. +This pipeline reads an already-processes representation of an image of +sunflowers and sends it to a deployed Vertex AI model endpoint, then +returns the predictions from the classifier model. The model and image +are from the Hello Image Data Vertex AI tutorial (see +https://cloud.google.com/vertex-ai/docs/tutorials/image-recognition-custom +for more information.) 
+""" + +import argparse +import logging +from typing import Iterable + +import apache_beam as beam +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage, VLLMCompletionsModelHandler, VLLMChatModelHandler +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult + +COMPLETION_EXAMPLES = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "John cena is", +] + +CHAT_EXAMPLES = [ + [ + OpenAIChatMessage(role='user', content='What is an example of a type of penguin?'), + OpenAIChatMessage(role='system', content='An emperor penguin is a type of penguin.'), + OpenAIChatMessage(role='user', content='Tell me about them') + ], + [ + OpenAIChatMessage(role='user', content='What colors are in the rainbow?'), + OpenAIChatMessage(role='system', content='Red, orange, yellow, green, blue, indigo, and violet are colors in the rainbow.'), + OpenAIChatMessage(role='user', content='Do other colors ever appear?') + ], + [ + OpenAIChatMessage(role='user', content='Who is the president of the United States?') + ], + [ + OpenAIChatMessage(role='user', content='What state is Fargo in?'), + OpenAIChatMessage(role='system', content='Fargo is in North Dakota.'), + OpenAIChatMessage(role='user', content='How many people live there?'), + OpenAIChatMessage(role='system', content='Approximately 130,000 people live in Fargo, North Dakota.'), + OpenAIChatMessage(role='user', content='What is Fargo known for?'), + ], + [ + OpenAIChatMessage(role='user', content='How many fish are in the ocean?'), + ], +] + +def parse_known_args(argv): + """Parses args for the workflow.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--model', + dest='model', + type=str, + required=False, + default='facebook/opt-125m', + help='LLM to use for task') + parser.add_argument( + '--output', + dest='output', + type=str, + required=True, + help='Path to save output predictions.') + parser.add_argument( + '--chat', + dest='chat', + type=bool, + required=False, + default=False, + help='Whether to use chat model handler and examples') + return parser.parse_known_args(argv) + + +class PostProcessor(beam.DoFn): + def process(self, element: PredictionResult) -> Iterable[str]: + yield element.example + ": " + element.inference + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + """ + Args: + argv: Command line arguments defined for this example. + save_main_session: Used for internal testing. + test_pipeline: Used for internal testing. 
+ """ + known_args, pipeline_args = parse_known_args(argv) + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + + model_handler = VLLMCompletionsModelHandler(model_name=known_args.model) + input_examples = COMPLETION_EXAMPLES + + if known_args.chat: + model_handler = VLLMChatModelHandler(model_name=known_args.model) + input_examples = CHAT_EXAMPLES + + pipeline = test_pipeline + if not test_pipeline: + pipeline = beam.Pipeline(options=pipeline_options) + + examples = pipeline | "Create examples" >> beam.Create(input_examples) + predictions = examples | "RunInference" >> RunInference(model_handler) + process_output = predictions | "Process Predictions" >> beam.ParDo( + PostProcessor()) + _ = process_output | "WriteOutput" >> beam.io.WriteToText( + known_args.output, shard_name_template='', append_trailing_newlines=True) + + result = pipeline.run() + result.wait_until_finish() + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() From 63cea6044f0df01a1689c11bb59c2c90a38c2f4e Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Sun, 8 Sep 2024 20:20:31 -0400 Subject: [PATCH 03/18] Still wip --- .../apache_beam/examples/inference/README.md | 74 ++++++++++++++++ .../inference/vllm_text_completion.py | 13 ++- .../ml/inference/vllm_inference_it_test.py | 88 +++++++++++++++++++ 3 files changed, 168 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md index 3bb68440ed60..9d6c1314efeb 100644 --- a/sdks/python/apache_beam/examples/inference/README.md +++ b/sdks/python/apache_beam/examples/inference/README.md @@ -853,6 +853,7 @@ path/to/my/image2: dandelions (78) Each line represents a prediction of the flower type along with the confidence in that prediction. --- + ## Text classifcation with a Vertex AI LLM [`vertex_ai_llm_text_classification.py`](./vertex_ai_llm_text_classification.py) contains an implementation for a RunInference pipeline that performs image classification using a model hosted on Vertex AI (based on https://cloud.google.com/vertex-ai/docs/tutorials/image-recognition-custom). @@ -882,4 +883,77 @@ This writes the output to the output file with contents like: ``` Each line represents a tuple containing the example, a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field, and the endpoint id representing the model id. +--- + +## Text completion with vLLM + +[`vllm_text_completion.py`](./vllm_text_completion.py) contains an implementation for a RunInference pipeline that performs text completion using a local [vLLM](https://docs.vllm.ai/en/latest/) server. + +The pipeline reads in a set of text prompts or past messages, uses RunInference to spin up a local inference server and perform inference, and then writes the predictions to a text file. + +### Model for text completion + +To use this transform, you can use any [LLM supported by vLLM](https://docs.vllm.ai/en/latest/models/supported_models.html). + +### Running `vllm_text_completion.py` + +To run the text completion pipeline locally using the Facebook opt 125M model, use the following command. +```sh +python -m apache_beam.examples.inference.vllm_text_completion \ + --model "facebook/opt-125m" \ + --output 'path/to/output/file.txt' \ + <... 
aditional pipeline arguments to configure runner if not running in GPU environment ...> +``` + +You will either need to run this locally with a GPU accelerator or remotely on a runner that supports acceleration. +For example, you could run this on Dataflow with a GPU with the following command: + +```sh +python -m apache_beam.examples.inference.vllm_text_completion \ + --model "facebook/opt-125m" \ + --output 'gs://path/to/output/file.txt' \ + --runner dataflow \ + --project \ + --region us-central1 \ + --temp_location \ + --worker_harness_container_image "gcr.io/apache-beam-testing/beam-ml/vllm:latest" \ + --machine_type "n1-standard-4" \ + --dataflow_service_options "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx" \ + --staging_location +``` + +Make sure to enable the 5xx driver since vLLM only works with 5xx drivers, not 4xx. + +This writes the output to the output file with contents like: +``` +'Hello, my name is', PredictionResult(example={'prompt': 'Hello, my name is'}, inference=Completion(id='cmpl-5f5113a317c949309582b1966511ffc4', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text=' Joel, my dad is Anton Harriman and my wife is Lydia. ', stop_reason=None)], created=1714064548, model='facebook/opt-125m', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=6, total_tokens=22))}) +``` +Each line represents a tuple containing the example, a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field. + +You can also choose to run with chat examples by adding the `--chat` parameter: +```sh +python -m apache_beam.examples.inference.vllm_text_completion \ + --model "facebook/opt-125m" \ + --output 'path/to/output/file.txt' \ + --chat \ + <... aditional pipeline arguments to configure runner if not running in GPU environment ...> +``` + +This will configure the pipeline to run against a sequence of previous messages instead of a single text completion prompt. +For example, it might run against: + +``` +[ + OpenAIChatMessage(role='user', content='What is an example of a type of penguin?'), + OpenAIChatMessage(role='system', content='An emperor penguin is a type of penguin.'), + OpenAIChatMessage(role='user', content='Tell me about them') +], +``` + +and produce: + +``` +An emperor penguin is an adorable creature that lives in Antarctica. +``` + --- \ No newline at end of file diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index b1fae7ddad63..67b0c8d1a424 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -15,13 +15,12 @@ # limitations under the License. # -""" A sample pipeline using the RunInference API to classify images of flowers. -This pipeline reads an already-processes representation of an image of -sunflowers and sends it to a deployed Vertex AI model endpoint, then -returns the predictions from the classifier model. The model and image -are from the Hello Image Data Vertex AI tutorial (see -https://cloud.google.com/vertex-ai/docs/tutorials/image-recognition-custom -for more information.) +""" A sample pipeline using the RunInference API to interface with an LLM using +vLLM. 
Takes in a set of prompts or lists of previous messages and produces +responses using a model of choice. + +Requires a GPU runtime with vllm, openai, and apache-beam installed to run +correctly. """ import argparse diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py b/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py index e69de29bb2d1..e4e6bcea6d92 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""End-to-End test for vLLM Inference""" + +import logging +import unittest +import uuid + +import pytest + +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference import pytorch_inference_it_test +from apache_beam.testing.test_pipeline import TestPipeline + +from apache_beam.examples.inference import vllm_text_completion + + +@pytest.mark.uses_vllm +@pytest.mark.it_postcommit +@pytest.mark.timeout(1800) +class HuggingFaceInference(unittest.TestCase): + def test_vllm_text_completion(self): + test_pipeline = TestPipeline(is_integration_test=True) + # Path to text file containing some sentences + output_file_dir = 'gs://apache-beam-ml/testing/predictions' + output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) + + model_name = 'facebook/opt-125m' + + extra_opts = { + 'model': model_name, + 'output': output_file, + 'machine_type': 'n1-standard-4', + 'dataflow_service_options': ['worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'], + 'worker_harness_container_image': 'gcr.io/apache-beam-testing/beam-ml/vllm:latest' + } + vllm_text_completion.run( + test_pipeline.get_full_options_as_args(**extra_opts), + save_main_session=False) + + self.assertEqual(FileSystems().exists(output_file), True) + predictions = pytorch_inference_it_test.process_outputs( + filepath=output_file) + + self.assertEqual(len(predictions), 5, f'Expected 5 strings, received: {predictions}') + + def test_vllm_chat(self): + test_pipeline = TestPipeline(is_integration_test=True) + # Path to text file containing some sentences + output_file_dir = 'gs://apache-beam-ml/testing/predictions' + output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) + + model_name = 'facebook/opt-125m' + + extra_opts = { + 'model': model_name, + 'output': output_file, + 'chat': True + } + vllm_text_completion.run( + test_pipeline.get_full_options_as_args(**extra_opts), + save_main_session=False) + + self.assertEqual(FileSystems().exists(output_file), True) + predictions = pytorch_inference_it_test.process_outputs( + filepath=output_file) + + self.assertEqual(len(predictions), 5, f'Expected 5 strings, received: {predictions}') + + +if 
__name__ == '__main__': + logging.getLogger().setLevel(logging.DEBUG) + unittest.main() From e6014e04351468bcf16fb5c5d54da2af0a4f1597 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 10:30:50 -0400 Subject: [PATCH 04/18] Test changes --- .../inference/test_resources/vllm.dockerfile | 12 +-- .../ml/inference/vllm_inference_it_test.py | 88 ------------------- .../python/test-suites/dataflow/common.gradle | 34 +++++++ .../test-suites/dataflow/py312/build.gradle | 4 + 4 files changed, 45 insertions(+), 93 deletions(-) delete mode 100644 sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile index eb66ad982102..daf3ba02d44a 100644 --- a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -6,14 +6,16 @@ RUN apt update RUN apt install software-properties-common -y RUN add-apt-repository ppa:deadsnakes/ppa RUN apt update -RUN apt install python3.10 python3.10-venv python3.10-dev -y +RUN apt install python3.12 -y +RUN apt install python3.12-venv -y +RUN apt install python3.12-dev -y RUN rm /usr/bin/python3 -RUN ln -s python3.10 /usr/bin/python3 +RUN ln -s python3.12 /usr/bin/python3 RUN python3 --version RUN apt-get install -y curl -RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && pip install --upgrade pip +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 && pip install --upgrade pip -RUN pip install --no-cache-dir apache-beam[gcp]==2.55.0 openai vllm +RUN pip install --no-cache-dir apache-beam[gcp]==2.58.1 openai vllm RUN apt install libcairo2-dev pkg-config python3-dev -y RUN pip install pycairo @@ -22,7 +24,7 @@ RUN pip install pycairo RUN pip check # Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. -COPY --from=apache/beam_python3.10_sdk:2.55.0 /opt/apache/beam /opt/apache/beam +COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam # Set the entrypoint to Apache Beam SDK worker launcher. ENTRYPOINT [ "/opt/apache/beam/boot" ] diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py b/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py deleted file mode 100644 index e4e6bcea6d92..000000000000 --- a/sdks/python/apache_beam/ml/inference/vllm_inference_it_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -"""End-to-End test for vLLM Inference""" - -import logging -import unittest -import uuid - -import pytest - -from apache_beam.io.filesystems import FileSystems -from apache_beam.ml.inference import pytorch_inference_it_test -from apache_beam.testing.test_pipeline import TestPipeline - -from apache_beam.examples.inference import vllm_text_completion - - -@pytest.mark.uses_vllm -@pytest.mark.it_postcommit -@pytest.mark.timeout(1800) -class HuggingFaceInference(unittest.TestCase): - def test_vllm_text_completion(self): - test_pipeline = TestPipeline(is_integration_test=True) - # Path to text file containing some sentences - output_file_dir = 'gs://apache-beam-ml/testing/predictions' - output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) - - model_name = 'facebook/opt-125m' - - extra_opts = { - 'model': model_name, - 'output': output_file, - 'machine_type': 'n1-standard-4', - 'dataflow_service_options': ['worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'], - 'worker_harness_container_image': 'gcr.io/apache-beam-testing/beam-ml/vllm:latest' - } - vllm_text_completion.run( - test_pipeline.get_full_options_as_args(**extra_opts), - save_main_session=False) - - self.assertEqual(FileSystems().exists(output_file), True) - predictions = pytorch_inference_it_test.process_outputs( - filepath=output_file) - - self.assertEqual(len(predictions), 5, f'Expected 5 strings, received: {predictions}') - - def test_vllm_chat(self): - test_pipeline = TestPipeline(is_integration_test=True) - # Path to text file containing some sentences - output_file_dir = 'gs://apache-beam-ml/testing/predictions' - output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) - - model_name = 'facebook/opt-125m' - - extra_opts = { - 'model': model_name, - 'output': output_file, - 'chat': True - } - vllm_text_completion.run( - test_pipeline.get_full_options_as_args(**extra_opts), - save_main_session=False) - - self.assertEqual(FileSystems().exists(output_file), True) - predictions = pytorch_inference_it_test.process_outputs( - filepath=output_file) - - self.assertEqual(len(predictions), 5, f'Expected 5 strings, received: {predictions}') - - -if __name__ == '__main__': - logging.getLogger().setLevel(logging.DEBUG) - unittest.main() diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index e5d301ecbe14..612c158fa847 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -424,6 +424,34 @@ def tensorRTTests = tasks.create("tensorRTtests") { } } +def vllmTests = tasks.create("vllmtests") { + dependsOn 'installGcpTest' + dependsOn ':sdks:python:sdist' + doLast { + def testOpts = basicPytestOpts + def argMap = [ + "runner": "DataflowRunner", + "machine_type":"n1-standard-4", + // TODO(https://github.com/apache/beam/issues/22651): Build docker image for VLLM tests during Run time. + // This would also enable to use wheel "--sdk_location" as other tasks, and eliminate distTarBall dependency + // declaration for this project. 
+ // Right now, this is built from https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile + "sdk_container_image": "us.gcr.io/apache-beam-testing/python-postcommit-it/vllm:latest", + "sdk_location": files(configurations.distTarBall.files).singleFile, + "project": "apache-beam-testing", + "region": "us-central1", + "input": "gs://apache-beam-ml/testing/inputs/tensorrt_image_file_names.txt", + "output": "gs://apache-beam-ml/outputs/tensorrt_predictions.txt", + "disk_size_gb": 75 + ] + def cmdArgs = mapToArgString(argMap) + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" + } + } +} + // Vertex AI RunInference IT tests task vertexAIInferenceTest { dependsOn 'initializeForDataflowJob' @@ -521,6 +549,12 @@ project.tasks.register("inferencePostCommitIT") { ] } +project.tasks.register("inferencePostCommitITPy312") { + dependsOn = [ + 'vllmTests', + ] +} + // Create cross-language tasks for running tests against Java expansion service(s) def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' diff --git a/sdks/python/test-suites/dataflow/py312/build.gradle b/sdks/python/test-suites/dataflow/py312/build.gradle index ea2bacc02018..4dcb4838e2f4 100644 --- a/sdks/python/test-suites/dataflow/py312/build.gradle +++ b/sdks/python/test-suites/dataflow/py312/build.gradle @@ -22,3 +22,7 @@ applyPythonNature() // Required to setup a Python 3 virtualenv and task names. pythonVersion = '3.12' apply from: "../common.gradle" + +toxTask "testPy312vllm", "py312-vllm", "${posargs}" +test.dependsOn "testPy38onnx-113" +postCommitPyDep.dependsOn "testPy38onnx-113" \ No newline at end of file From bbcca465cead20bdfe4e51d98b09fb005a2268f2 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 17:22:41 +0000 Subject: [PATCH 05/18] Dockerfile improvements --- .../ml/inference/test_resources/vllm.dockerfile | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile index daf3ba02d44a..ea2e489967bc 100644 --- a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -6,6 +6,9 @@ RUN apt update RUN apt install software-properties-common -y RUN add-apt-repository ppa:deadsnakes/ppa RUN apt update + +ARG DEBIAN_FRONTEND=noninteractive + RUN apt install python3.12 -y RUN apt install python3.12-venv -y RUN apt install python3.12-dev -y @@ -15,14 +18,12 @@ RUN python3 --version RUN apt-get install -y curl RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 && pip install --upgrade pip -RUN pip install --no-cache-dir apache-beam[gcp]==2.58.1 openai vllm +RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.58.1 +RUN pip install openai vllm RUN apt install libcairo2-dev pkg-config python3-dev -y RUN pip install pycairo -# Verify that there are no conflicting dependencies. -RUN pip check - # Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. 
COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam From 5b234ced8bc0db3fa60af2f75abf6e007257329b Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 14:24:51 -0400 Subject: [PATCH 06/18] Remove bad change --- sdks/python/test-suites/dataflow/py312/build.gradle | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sdks/python/test-suites/dataflow/py312/build.gradle b/sdks/python/test-suites/dataflow/py312/build.gradle index 4dcb4838e2f4..ea2bacc02018 100644 --- a/sdks/python/test-suites/dataflow/py312/build.gradle +++ b/sdks/python/test-suites/dataflow/py312/build.gradle @@ -22,7 +22,3 @@ applyPythonNature() // Required to setup a Python 3 virtualenv and task names. pythonVersion = '3.12' apply from: "../common.gradle" - -toxTask "testPy312vllm", "py312-vllm", "${posargs}" -test.dependsOn "testPy38onnx-113" -postCommitPyDep.dependsOn "testPy38onnx-113" \ No newline at end of file From 61591590528960d7ff51858c2243212af6629bd5 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 14:29:52 -0400 Subject: [PATCH 07/18] Clean up test args --- sdks/python/test-suites/dataflow/common.gradle | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 612c158fa847..925dad58500a 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -440,15 +440,20 @@ def vllmTests = tasks.create("vllmtests") { "sdk_location": files(configurations.distTarBall.files).singleFile, "project": "apache-beam-testing", "region": "us-central1", - "input": "gs://apache-beam-ml/testing/inputs/tensorrt_image_file_names.txt", - "output": "gs://apache-beam-ml/outputs/tensorrt_predictions.txt", + "model": "facebook/opt-125m'", + "output": "gs://apache-beam-ml/outputs/vllm_predictions.txt", "disk_size_gb": 75 ] def cmdArgs = mapToArgString(argMap) + // Exec one version with and one version without the chat option exec { executable 'sh' args '-c', ". ${envdir}/bin/activate && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" } + exec { + executable 'sh' + args '-c', ". 
${envdir}/bin/activate && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" + } } } From 404a0fcaefe75631af7b3b247de6e0835f6592bd Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 14:31:51 -0400 Subject: [PATCH 08/18] clean up invocation --- .github/trigger_files/beam_PostCommit_Python.json | 2 +- build.gradle.kts | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 2934a91b84b1..30ee463ad4e9 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 1 + "modification": 2 } diff --git a/build.gradle.kts b/build.gradle.kts index e6295384b753..f775dfc319f8 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -543,6 +543,7 @@ tasks.register("python312PostCommit") { dependsOn(":sdks:python:test-suites:direct:py312:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py312:hdfsIntegrationTest") dependsOn(":sdks:python:test-suites:portable:py312:postCommitPy312") + dependsOn(":sdks:python:test-suites:dataflow:py312:inferencePostCommitITPy312") } tasks.register("portablePythonPreCommit") { From b407542f04de54023ad2040a848b2f120a6dc074 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 14:34:04 -0400 Subject: [PATCH 09/18] string fix --- sdks/python/test-suites/dataflow/common.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 925dad58500a..c7cf5323c069 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -424,7 +424,7 @@ def tensorRTTests = tasks.create("tensorRTtests") { } } -def vllmTests = tasks.create("vllmtests") { +def vllmTests = tasks.create("vllmTests") { dependsOn 'installGcpTest' dependsOn ':sdks:python:sdist' doLast { From a9c97c260358c60357bda85fb8b790856813f502 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 17:48:43 -0400 Subject: [PATCH 10/18] string fix --- sdks/python/test-suites/dataflow/common.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index c7cf5323c069..c6e0ce367267 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -440,7 +440,7 @@ def vllmTests = tasks.create("vllmTests") { "sdk_location": files(configurations.distTarBall.files).singleFile, "project": "apache-beam-testing", "region": "us-central1", - "model": "facebook/opt-125m'", + "model": "facebook/opt-125m", "output": "gs://apache-beam-ml/outputs/vllm_predictions.txt", "disk_size_gb": 75 ] From 5bd6ea8edb9134bdab85912f30937fdf98e3c126 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 20:39:54 -0400 Subject: [PATCH 11/18] clean up --- .../inference/vllm_text_completion.py | 27 ++++++++++++++----- .../inference/test_resources/vllm.dockerfile | 16 +++++++++++ .../ml/inference/vllm_inference.py | 9 +++---- sdks/python/setup.py | 1 + .../python/test-suites/dataflow/common.gradle | 4 +-- 5 files changed, 43 insertions(+), 14 deletions(-) diff --git 
a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 67b0c8d1a424..0559f0fb2456 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -45,30 +45,43 @@ CHAT_EXAMPLES = [ [ - OpenAIChatMessage(role='user', content='What is an example of a type of penguin?'), - OpenAIChatMessage(role='system', content='An emperor penguin is a type of penguin.'), + OpenAIChatMessage( + role='user', content='What is an example of a type of penguin?'), + OpenAIChatMessage( + role='system', content='An emperor penguin is a type of penguin.'), OpenAIChatMessage(role='user', content='Tell me about them') ], [ - OpenAIChatMessage(role='user', content='What colors are in the rainbow?'), - OpenAIChatMessage(role='system', content='Red, orange, yellow, green, blue, indigo, and violet are colors in the rainbow.'), + OpenAIChatMessage( + role='user', content='What colors are in the rainbow?'), + OpenAIChatMessage( + role='system', + content= + 'Red, orange, yellow, green, blue, indigo, and violet are colors in the rainbow.' + ), OpenAIChatMessage(role='user', content='Do other colors ever appear?') ], [ - OpenAIChatMessage(role='user', content='Who is the president of the United States?') + OpenAIChatMessage( + role='user', content='Who is the president of the United States?') ], [ OpenAIChatMessage(role='user', content='What state is Fargo in?'), OpenAIChatMessage(role='system', content='Fargo is in North Dakota.'), OpenAIChatMessage(role='user', content='How many people live there?'), - OpenAIChatMessage(role='system', content='Approximately 130,000 people live in Fargo, North Dakota.'), + OpenAIChatMessage( + role='system', + content='Approximately 130,000 people live in Fargo, North Dakota.' + ), OpenAIChatMessage(role='user', content='What is Fargo known for?'), ], [ - OpenAIChatMessage(role='user', content='How many fish are in the ocean?'), + OpenAIChatMessage( + role='user', content='How many fish are in the ocean?'), ], ] + def parse_known_args(argv): """Parses args for the workflow.""" parser = argparse.ArgumentParser() diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile index ea2e489967bc..5abbffdc5a2a 100644 --- a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -1,3 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # Used for any vLLM integration test FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 407839f6317a..7174dbfd063b 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -34,7 +34,7 @@ from typing import Tuple try: - import vllm # type: ignore unused-import + import vllm # pylint: disable=unused-import logging.info('vllm module successfully imported.') except ModuleNotFoundError: msg = 'vllm module was not found. This is ok as long as the specified ' \ @@ -119,7 +119,7 @@ def start_server(self, retries=3): if len(models) > 0: self._server_started = True return - except: + except: # pylint: disable=bare-except pass # Sleep while bringing up the process time.sleep(5) @@ -127,8 +127,7 @@ def start_server(self, retries=3): if retries == 0: raise Exception( 'Failed to start vLLM server, process exited with code %s. ' - 'See worker logs to determine cause.' - % self._server_process.poll()) + 'See worker logs to determine cause.' % self._server_process.poll()) else: self.start_server(retries - 1) @@ -231,7 +230,7 @@ def run_inference( Args: batch: A sequence of examples as OpenAI messages. - model: A _VLLMModelServer containing info for connecting to the spun up server. + model: A _VLLMModelServer for connecting to the spun up server. inference_args: Any additional arguments for an inference. Returns: diff --git a/sdks/python/setup.py b/sdks/python/setup.py index f50a3a07746f..6fc7239295f7 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -401,6 +401,7 @@ def get_portability_package_data(): # https://github.com/sphinx-doc/sphinx/issues/9727 'docutils==0.17.1', 'pandas<2.2.0', + 'openai' ], 'test': [ 'docstring-parser>=0.15,<1.0', diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index c6e0ce367267..2c832fc3075e 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -448,11 +448,11 @@ def vllmTests = tasks.create("vllmTests") { // Exec one version with and one version without the chat option exec { executable 'sh' - args '-c', ". ${envdir}/bin/activate && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" + args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" } exec { executable 'sh' - args '-c', ". ${envdir}/bin/activate && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" + args '-c', ". 
${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" } } } From c5bf4f9f1d8d713fe8141843c25d813b2d4eb06e Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 9 Sep 2024 21:46:04 -0400 Subject: [PATCH 12/18] lint --- .../apache_beam/examples/inference/vllm_text_completion.py | 2 +- sdks/python/apache_beam/ml/inference/vllm_inference.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 0559f0fb2456..26cf99d52c33 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -57,7 +57,7 @@ OpenAIChatMessage( role='system', content= - 'Red, orange, yellow, green, blue, indigo, and violet are colors in the rainbow.' + 'Red, orange, yellow, green, blue, indigo, and violet.' ), OpenAIChatMessage(role='user', content='Do other colors ever appear?') ], diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 7174dbfd063b..02875f5e6ffa 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -119,7 +119,7 @@ def start_server(self, retries=3): if len(models) > 0: self._server_started = True return - except: # pylint: disable=bare-except + except: # pylint: disable=bare-except pass # Sleep while bringing up the process time.sleep(5) From 000401dfa34c6e3ec7effde704200988549bcc21 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 20 Sep 2024 23:44:19 -0400 Subject: [PATCH 13/18] Get tests working with 5xx driver --- .../examples/inference/vllm_text_completion.py | 6 ++---- .../apache_beam/ml/inference/vllm_inference.py | 14 +++++++++++--- sdks/python/test-suites/dataflow/common.gradle | 4 ++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 26cf99d52c33..367094f13539 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -56,9 +56,7 @@ role='user', content='What colors are in the rainbow?'), OpenAIChatMessage( role='system', - content= - 'Red, orange, yellow, green, blue, indigo, and violet.' - ), + content='Red, orange, yellow, green, blue, indigo, and violet.'), OpenAIChatMessage(role='user', content='Do other colors ever appear?') ], [ @@ -110,7 +108,7 @@ def parse_known_args(argv): class PostProcessor(beam.DoFn): def process(self, element: PredictionResult) -> Iterable[str]: - yield element.example + ": " + element.inference + yield element.example + ": " + str(element.inference) def run( diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 02875f5e6ffa..f581ac85b9b8 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -126,8 +126,8 @@ def start_server(self, retries=3): if retries == 0: raise Exception( - 'Failed to start vLLM server, process exited with code %s. ' - 'See worker logs to determine cause.' 
% self._server_process.poll()) + "Failed to start vLLM server, process exited with code %s" % + self._server_process.poll()) else: self.start_server(retries - 1) @@ -179,6 +179,7 @@ def run_inference( An Iterable of type PredictionResult. """ client = getVLLMClient(model.get_server_port()) + inference_args = inference_args or {} predictions = [] for prompt in batch: completion = client.completions.create( @@ -192,6 +193,9 @@ def share_model_across_processes(self) -> bool: def should_skip_batching(self) -> bool: # Batching does not help since vllm is already doing dynamic batching and # each request is sent one by one anyways + # TODO(https://github.com/apache/beam/issues/32528): We should add support + # for taking in batches and doing a bunch of async calls. That will end up + # being more efficient when we can do in bundle batching. return True @@ -237,9 +241,10 @@ def run_inference( An Iterable of type PredictionResult. """ client = getVLLMClient(model.get_server_port()) + inference_args = inference_args or {} predictions = [] for messages in batch: - completion = client.completions.create( + completion = client.chat.completions.create( model=self._model_name, messages=messages, **inference_args) predictions.append(completion) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] @@ -250,4 +255,7 @@ def share_model_across_processes(self) -> bool: def should_skip_batching(self) -> bool: # Batching does not help since vllm is already doing dynamic batching and # each request is sent one by one anyways + # TODO(https://github.com/apache/beam/issues/32528): We should add support + # for taking in batches and doing a bunch of async calls. That will end up + # being more efficient when we can do in bundle batching. return True diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 2c832fc3075e..c9aac452d5de 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -448,11 +448,11 @@ def vllmTests = tasks.create("vllmTests") { // Exec one version with and one version without the chat option exec { executable 'sh' - args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" + args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" } exec { executable 'sh' - args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver'" + args '-c', ". 
${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" } } } From 19a018ba98a75b0615bbf2f21516e1e08ccf7c23 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Sat, 21 Sep 2024 11:22:55 -0400 Subject: [PATCH 14/18] cleanup --- .../examples/inference/vllm_text_completion.py | 4 +++- .../apache_beam/ml/inference/vllm_inference.py | 18 +++++++++++------- sdks/python/test-suites/dataflow/common.gradle | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 367094f13539..46a811f3f7b0 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -30,7 +30,9 @@ import apache_beam as beam from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference -from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage, VLLMCompletionsModelHandler, VLLMChatModelHandler +from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage +from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler +from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions from apache_beam.runners.runner import PipelineResult diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index f581ac85b9b8..a762ddab0bc8 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -17,15 +17,11 @@ # pytype: skip-file -from apache_beam.ml.inference.base import ModelHandler -from apache_beam.ml.inference.base import PredictionResult -from apache_beam.utils import subprocess_server -from dataclasses import dataclass -from openai import OpenAI import logging +import subprocess import threading import time -import subprocess +from dataclasses import dataclass from typing import Any from typing import Dict from typing import Iterable @@ -33,6 +29,11 @@ from typing import Sequence from typing import Tuple +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.utils import subprocess_server +from openai import OpenAI + try: import vllm # pylint: disable=unused-import logging.info('vllm module successfully imported.') @@ -244,8 +245,11 @@ def run_inference( inference_args = inference_args or {} predictions = [] for messages in batch: + formatted = [] + for message in messages: + formatted.append({"role": message.role, "content": message.content}) completion = client.chat.completions.create( - model=self._model_name, messages=messages, **inference_args) + model=self._model_name, messages=formatted, **inference_args) predictions.append(completion) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index c9aac452d5de..e4d6141a5573 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -452,7 +452,7 @@ def vllmTests = 
tasks.create("vllmTests") { } exec { executable 'sh' - args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" + args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat true --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" } } } From 7a173dcf94c50bbc66635dbe0db50351ebb63c6b Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 23 Sep 2024 10:18:09 -0400 Subject: [PATCH 15/18] Fixes, everything is now working --- .../apache_beam/examples/inference/README.md | 16 ++-- .../inference/vllm_text_completion.py | 21 +++-- .../ml/inference/vllm_inference.py | 85 ++++++++++++++----- .../python/test-suites/dataflow/common.gradle | 2 +- 4 files changed, 89 insertions(+), 35 deletions(-) diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md index 9d6c1314efeb..f9c5af436965 100644 --- a/sdks/python/apache_beam/examples/inference/README.md +++ b/sdks/python/apache_beam/examples/inference/README.md @@ -924,18 +924,24 @@ python -m apache_beam.examples.inference.vllm_text_completion \ Make sure to enable the 5xx driver since vLLM only works with 5xx drivers, not 4xx. -This writes the output to the output file with contents like: +This writes the output to the output file location with contents like: + ``` 'Hello, my name is', PredictionResult(example={'prompt': 'Hello, my name is'}, inference=Completion(id='cmpl-5f5113a317c949309582b1966511ffc4', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text=' Joel, my dad is Anton Harriman and my wife is Lydia. ', stop_reason=None)], created=1714064548, model='facebook/opt-125m', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=6, total_tokens=22))}) ``` Each line represents a tuple containing the example, a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field. -You can also choose to run with chat examples by adding the `--chat` parameter: +You can also choose to run with chat examples. Doing this requires 2 steps: + +1) Upload a [chat_template](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template) to a filestore which is accessible from your job's environment (e.g. a public Google Cloud Storage bucket). You can copy [this sample template](https://storage.googleapis.com/apache-beam-ml/additional_files/sample_chat_template.jinja) to get started. You can skip this step if using a model other than `facebook/opt-125m` and you know your model provides a chat template. +2) Add the `--chat true` and `--chat_template ` parameters: + ```sh python -m apache_beam.examples.inference.vllm_text_completion \ --model "facebook/opt-125m" \ - --output 'path/to/output/file.txt' \ - --chat \ + --output 'gs://path/to/output/file.txt' \ + --chat true \ + --chat_template gs://path/to/your/file \ <... 
aditional pipeline arguments to configure runner if not running in GPU environment ...> ``` @@ -950,7 +956,7 @@ For example, it might run against: ], ``` -and produce: +and produce the following result in your output file location: ``` An emperor penguin is an adorable creature that lives in Antarctica. diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 46a811f3f7b0..3cf7d04cb03e 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -50,14 +50,14 @@ OpenAIChatMessage( role='user', content='What is an example of a type of penguin?'), OpenAIChatMessage( - role='system', content='An emperor penguin is a type of penguin.'), + role='assistant', content='Emperor penguin is a type of penguin.'), OpenAIChatMessage(role='user', content='Tell me about them') ], [ OpenAIChatMessage( role='user', content='What colors are in the rainbow?'), OpenAIChatMessage( - role='system', + role='assistant', content='Red, orange, yellow, green, blue, indigo, and violet.'), OpenAIChatMessage(role='user', content='Do other colors ever appear?') ], @@ -67,10 +67,10 @@ ], [ OpenAIChatMessage(role='user', content='What state is Fargo in?'), - OpenAIChatMessage(role='system', content='Fargo is in North Dakota.'), + OpenAIChatMessage(role='assistant', content='It is in North Dakota.'), OpenAIChatMessage(role='user', content='How many people live there?'), OpenAIChatMessage( - role='system', + role='assistant', content='Approximately 130,000 people live in Fargo, North Dakota.' ), OpenAIChatMessage(role='user', content='What is Fargo known for?'), @@ -105,12 +105,19 @@ def parse_known_args(argv): required=False, default=False, help='Whether to use chat model handler and examples') + parser.add_argument( + '--chat_template', + dest='chat_template', + type=str, + required=False, + default=None, + help='Chat template to use for chat example.') return parser.parse_known_args(argv) class PostProcessor(beam.DoFn): def process(self, element: PredictionResult) -> Iterable[str]: - yield element.example + ": " + str(element.inference) + yield str(element.example) + ": " + str(element.inference) def run( @@ -129,7 +136,9 @@ def run( input_examples = COMPLETION_EXAMPLES if known_args.chat: - model_handler = VLLMChatModelHandler(model_name=known_args.model) + model_handler = VLLMChatModelHandler( + model_name=known_args.model, + chat_template_path=known_args.chat_template) input_examples = CHAT_EXAMPLES pipeline = test_pipeline diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index a762ddab0bc8..3616f12263a2 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -18,9 +18,11 @@ # pytype: skip-file import logging +import os import subprocess import threading import time +import uuid from dataclasses import dataclass from typing import Any from typing import Dict @@ -29,6 +31,7 @@ from typing import Sequence from typing import Tuple +from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.utils import subprocess_server @@ -92,8 +95,9 @@ def getVLLMClient(port) -> OpenAI: class _VLLMModelServer(): - def __init__(self, model_name): + def __init__(self, model_name: 
str, vllm_server_kwargs: Dict[str, str]): self._model_name = model_name + self._vllm_server_kwargs = vllm_server_kwargs self._server_started = False self._server_process = None self._server_port = None @@ -102,7 +106,7 @@ def __init__(self, model_name): def start_server(self, retries=3): if not self._server_started: - self._server_process, self._server_port = start_process([ + server_cmd = [ 'python', '-m', 'vllm.entrypoints.openai.api_server', @@ -110,7 +114,11 @@ def start_server(self, retries=3): self._model_name, '--port', '{{PORT}}', - ]) + ] + for k, v in self._vllm_server_kwargs.items(): + server_cmd.append(f'--{k}') + server_cmd.append(v) + self._server_process, self._server_port = start_process(server_cmd) client = getVLLMClient(self._server_port) while self._server_process.poll() is None: @@ -141,10 +149,9 @@ def get_server_port(self) -> int: class VLLMCompletionsModelHandler(ModelHandler[str, PredictionResult, _VLLMModelServer]): - def __init__( - self, - model_name: str, - ): + def __init__(self, + model_name: str, + vllm_server_kwargs: Optional[Dict[str, str]] = None): """Implementation of the ModelHandler interface for vLLM using text as input. @@ -156,16 +163,24 @@ def __init__( model_name: The vLLM model. See https://docs.vllm.ai/en/latest/models/supported_models.html for supported models. + vllm_server_kwargs: Any additional kwargs to be passed into your vllm + server when it is being created. Will be invoked using + `python -m vllm.entrypoints.openai.api_serverv + `. For example, you could pass + `{'echo': 'true'}` to prepend new messages with the previous message. + For a list of possible kwargs, see + https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api """ self._model_name = model_name + self._vllm_server_kwargs: Dict[str, str] = vllm_server_kwargs or {} self._env_vars = {} def load_model(self) -> _VLLMModelServer: - return _VLLMModelServer(self._model_name) + return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) def run_inference( self, - batch: Sequence[str], + batch: str, model: _VLLMModelServer, inference_args: Optional[Dict[str, Any]] = None ) -> Iterable[PredictionResult]: @@ -182,10 +197,9 @@ def run_inference( client = getVLLMClient(model.get_server_port()) inference_args = inference_args or {} predictions = [] - for prompt in batch: - completion = client.completions.create( - model=self._model_name, prompt=prompt, **inference_args) - predictions.append(completion) + completion = client.completions.create( + model=self._model_name, prompt=batch, **inference_args) + predictions.append(completion) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def share_model_across_processes(self) -> bool: @@ -206,7 +220,8 @@ class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage], def __init__( self, model_name: str, - ): + chat_template_path: Optional[str] = None, + vllm_server_kwargs: Dict[str, str] = None): """ Implementation of the ModelHandler interface for vLLM using previous messages as input. @@ -218,16 +233,41 @@ def __init__( model_name: The vLLM model. See https://docs.vllm.ai/en/latest/models/supported_models.html for supported models. + chat_template_path: Path to a chat template. This file must be accessible + from your runner's execution environment, so it is recommended to use + a cloud based file storage system (e.g. Google Cloud Storage). 
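As a side note on how the `vllm_server_kwargs` dictionary above is consumed, the following standalone sketch (not part of the patch) mirrors the flag-building loop in `_VLLMModelServer.start_server`; the `dtype` key is only an illustrative flag name, not something this change prescribes.

```python
# Standalone sketch: mirrors how _VLLMModelServer.start_server turns
# vllm_server_kwargs into CLI flags for vllm.entrypoints.openai.api_server.
# The 'dtype' kwarg is illustrative only.
vllm_server_kwargs = {'dtype': 'float16'}

server_cmd = [
    'python',
    '-m',
    'vllm.entrypoints.openai.api_server',
    '--model',
    'facebook/opt-125m',
    '--port',
    '{{PORT}}',  # substituted with a free port by start_process
]
for k, v in vllm_server_kwargs.items():
  server_cmd.append(f'--{k}')
  server_cmd.append(v)

# server_cmd now ends with ['--dtype', 'float16'], so keys should match the
# long-form flag names accepted by the vLLM OpenAI-compatible server.
```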
+ For info on chat templates, see: + https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template + vllm_server_kwargs: Any additional kwargs to be passed into your vllm + server when it is being created. Will be invoked using + `python -m vllm.entrypoints.openai.api_serverv + `. For example, you could pass + `{'echo': 'true'}` to prepend new messages with the previous message. + For a list of possible kwargs, see + https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api """ self._model_name = model_name + self._vllm_server_kwargs: Dict[str, str] = vllm_server_kwargs or {} self._env_vars = {} + self._chat_template_path = chat_template_path + self._chat_file = f'template-{uuid.uuid4().hex}.jinja' def load_model(self) -> _VLLMModelServer: - return _VLLMModelServer(self._model_name) + chat_template_contents = '' + if self._chat_template_path is not None: + local_chat_template_path = os.path.join(os.getcwd(), self._chat_file) + if not os.path.exists(local_chat_template_path): + with FileSystems.open(self._chat_template_path) as fin: + chat_template_contents = fin.read().decode() + with open(local_chat_template_path, 'a') as f: + f.write(chat_template_contents) + self._vllm_server_kwargs['chat_template'] = local_chat_template_path + + return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) def run_inference( self, - batch: Sequence[Sequence[OpenAIChatMessage]], + batch: Sequence[OpenAIChatMessage], model: _VLLMModelServer, inference_args: Optional[Dict[str, Any]] = None ) -> Iterable[PredictionResult]: @@ -244,13 +284,12 @@ def run_inference( client = getVLLMClient(model.get_server_port()) inference_args = inference_args or {} predictions = [] - for messages in batch: - formatted = [] - for message in messages: - formatted.append({"role": message.role, "content": message.content}) - completion = client.chat.completions.create( - model=self._model_name, messages=formatted, **inference_args) - predictions.append(completion) + formatted = [] + for message in batch: + formatted.append({"role": message.role, "content": message.content}) + completion = client.chat.completions.create( + model=self._model_name, messages=formatted, **inference_args) + predictions.append(completion) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def share_model_across_processes(self) -> bool: diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index e4d6141a5573..6bca904c1a64 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -452,7 +452,7 @@ def vllmTests = tasks.create("vllmTests") { } exec { executable 'sh' - args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat true --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" + args '-c', ". 
${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat true --chat_template 'gs://apache-beam-ml/additional_files/sample_chat_template.jinja' --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" } } } From 7703f536c3a83ee0daa94b74a4ba4d71680bf850 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 23 Sep 2024 11:04:39 -0400 Subject: [PATCH 16/18] Batching --- .../ml/inference/vllm_inference.py | 53 ++++++++----------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 3616f12263a2..f9c246ab47d3 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -149,9 +149,10 @@ def get_server_port(self) -> int: class VLLMCompletionsModelHandler(ModelHandler[str, PredictionResult, _VLLMModelServer]): - def __init__(self, - model_name: str, - vllm_server_kwargs: Optional[Dict[str, str]] = None): + def __init__( + self, + model_name: str, + vllm_server_kwargs: Optional[Dict[str, str]] = None): """Implementation of the ModelHandler interface for vLLM using text as input. @@ -180,7 +181,7 @@ def load_model(self) -> _VLLMModelServer: def run_inference( self, - batch: str, + batch: Sequence[str], model: _VLLMModelServer, inference_args: Optional[Dict[str, Any]] = None ) -> Iterable[PredictionResult]: @@ -197,20 +198,16 @@ def run_inference( client = getVLLMClient(model.get_server_port()) inference_args = inference_args or {} predictions = [] - completion = client.completions.create( - model=self._model_name, prompt=batch, **inference_args) - predictions.append(completion) - return [PredictionResult(x, y) for x, y in zip(batch, predictions)] - - def share_model_across_processes(self) -> bool: - return True - - def should_skip_batching(self) -> bool: - # Batching does not help since vllm is already doing dynamic batching and - # each request is sent one by one anyways # TODO(https://github.com/apache/beam/issues/32528): We should add support # for taking in batches and doing a bunch of async calls. That will end up # being more efficient when we can do in bundle batching. 
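Because `run_inference` unpacks `inference_args` straight into `client.completions.create`, generation parameters can be supplied from the pipeline. The sketch below is a hypothetical usage, assuming an environment where the vLLM dependencies are available; `max_tokens` and `temperature` are standard OpenAI Completions API arguments, not options added by this patch.

```python
# Hypothetical usage sketch: inference_args is forwarded unchanged to
# client.completions.create, so standard OpenAI Completions parameters
# such as max_tokens and temperature can be set per pipeline.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(['Hello, my name is'])
      | RunInference(
          VLLMCompletionsModelHandler(model_name='facebook/opt-125m'),
          inference_args={'max_tokens': 16, 'temperature': 0.7})
      | beam.Map(print))
```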
+ for prompt in batch: + completion = client.completions.create( + model=self._model_name, prompt=prompt, **inference_args) + predictions.append(completion) + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def share_model_across_processes(self) -> bool: return True @@ -267,7 +264,7 @@ def load_model(self) -> _VLLMModelServer: def run_inference( self, - batch: Sequence[OpenAIChatMessage], + batch: Sequence[Sequence[OpenAIChatMessage]], model: _VLLMModelServer, inference_args: Optional[Dict[str, Any]] = None ) -> Iterable[PredictionResult]: @@ -284,21 +281,17 @@ def run_inference( client = getVLLMClient(model.get_server_port()) inference_args = inference_args or {} predictions = [] - formatted = [] - for message in batch: - formatted.append({"role": message.role, "content": message.content}) - completion = client.chat.completions.create( - model=self._model_name, messages=formatted, **inference_args) - predictions.append(completion) - return [PredictionResult(x, y) for x, y in zip(batch, predictions)] - - def share_model_across_processes(self) -> bool: - return True - - def should_skip_batching(self) -> bool: - # Batching does not help since vllm is already doing dynamic batching and - # each request is sent one by one anyways # TODO(https://github.com/apache/beam/issues/32528): We should add support # for taking in batches and doing a bunch of async calls. That will end up # being more efficient when we can do in bundle batching. + for messages in batch: + formatted = [] + for message in messages: + formatted.append({"role": message.role, "content": message.content}) + completion = client.chat.completions.create( + model=self._model_name, messages=formatted, **inference_args) + predictions.append(completion) + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def share_model_across_processes(self) -> bool: return True From 93bbd49b810045e07991510af952ede0204840b6 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 23 Sep 2024 11:55:47 -0400 Subject: [PATCH 17/18] lint --- sdks/python/apache_beam/ml/inference/vllm_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index f9c246ab47d3..929b6c945b4d 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -100,7 +100,7 @@ def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]): self._vllm_server_kwargs = vllm_server_kwargs self._server_started = False self._server_process = None - self._server_port = None + self._server_port: int = -1 self.start_server() @@ -218,7 +218,7 @@ def __init__( self, model_name: str, chat_template_path: Optional[str] = None, - vllm_server_kwargs: Dict[str, str] = None): + vllm_server_kwargs: Optional[Dict[str, str]] = None): """ Implementation of the ModelHandler interface for vLLM using previous messages as input. 
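To make the shape of a chat batch concrete, here is a standalone illustration (not part of the patch) of the formatting loop restored above in `VLLMChatModelHandler.run_inference`: each batch element is one conversation, flattened into the role/content dicts the OpenAI client expects.

```python
# Standalone illustration of the message formatting done in
# VLLMChatModelHandler.run_inference before calling chat.completions.create.
from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage

# One batch element is one full conversation.
conversation = [
    OpenAIChatMessage(role='user', content='What state is Fargo in?'),
    OpenAIChatMessage(role='assistant', content='It is in North Dakota.'),
    OpenAIChatMessage(role='user', content='How many people live there?'),
]

formatted = [{'role': m.role, 'content': m.content} for m in conversation]
# formatted == [{'role': 'user', 'content': 'What state is Fargo in?'}, ...]
# run_inference issues one chat.completions.create call per conversation.
```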
From 98791c73ad4c1d98453e30c0f36060a91ab6545f Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Tue, 24 Sep 2024 10:48:45 -0400 Subject: [PATCH 18/18] Feedback + CHANGES.md --- CHANGES.md | 15 +------ .../ml/inference/vllm_inference.py | 39 +++++++++++++------ 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index d58ceffeb411..c123a8e1a4dc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -57,18 +57,13 @@ ## Highlights -* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). -* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). - -## I/Os - -* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) ## New Features / Improvements * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). * Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) ## Breaking Changes @@ -77,11 +72,9 @@ as strings rather than silently coerced (and possibly truncated) to numeric values. To retain the old behavior, pass `dtype=True` (or any other value accepted by `pandas.read_json`). -* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). ## Deprecations -* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). * Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) @@ -92,10 +85,6 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). -## Known Issues - -* ([#X](https://github.com/apache/beam/issues/X)). - # [2.59.0] - 2024-09-11 ## Highlights diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 929b6c945b4d..28890083d93e 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -120,6 +120,14 @@ def start_server(self, retries=3): server_cmd.append(v) self._server_process, self._server_port = start_process(server_cmd) + self.check_connectivity() + + def get_server_port(self) -> int: + if not self._server_started: + self.start_server() + return self._server_port + + def check_connectivity(self, retries=3): client = getVLLMClient(self._server_port) while self._server_process.poll() is None: try: @@ -134,17 +142,14 @@ def start_server(self, retries=3): time.sleep(5) if retries == 0: + self._server_started = False raise Exception( - "Failed to start vLLM server, process exited with code %s" % + "Failed to start vLLM server, polling process exited with code " + + "%s. 
Next time a request is tried, the server will be restarted" % self._server_process.poll()) else: self.start_server(retries - 1) - def get_server_port(self) -> int: - if not self._server_started: - self.start_server() - return self._server_port - class VLLMCompletionsModelHandler(ModelHandler[str, PredictionResult, @@ -202,9 +207,14 @@ def run_inference( # for taking in batches and doing a bunch of async calls. That will end up # being more efficient when we can do in bundle batching. for prompt in batch: - completion = client.completions.create( - model=self._model_name, prompt=prompt, **inference_args) - predictions.append(completion) + try: + completion = client.completions.create( + model=self._model_name, prompt=prompt, **inference_args) + predictions.append(completion) + except Exception as e: + model.check_connectivity() + raise e + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def share_model_across_processes(self) -> bool: @@ -288,9 +298,14 @@ def run_inference( formatted = [] for message in messages: formatted.append({"role": message.role, "content": message.content}) - completion = client.chat.completions.create( - model=self._model_name, messages=formatted, **inference_args) - predictions.append(completion) + try: + completion = client.chat.completions.create( + model=self._model_name, messages=formatted, **inference_args) + predictions.append(completion) + except Exception as e: + model.check_connectivity() + raise e + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def share_model_across_processes(self) -> bool:
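Pulling the pieces of this patch series together, here is a minimal end-to-end sketch of the chat path, assuming workers with vLLM installed and a GPU available; the `gs://` template path below is a placeholder, not a real artifact.

```python
# Minimal end-to-end sketch (assumptions: vLLM-capable workers, placeholder
# template path). Mirrors the chat flow exercised by vllm_text_completion.py.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage
from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler

handler = VLLMChatModelHandler(
    model_name='facebook/opt-125m',
    chat_template_path='gs://your-bucket/sample_chat_template.jinja')  # placeholder

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([[
          OpenAIChatMessage(role='user', content='What colors are in the rainbow?'),
      ]])
      | RunInference(handler)
      # Same output formatting as the example pipeline's PostProcessor.
      | beam.Map(lambda result: str(result.example) + ': ' + str(result.inference))
      | beam.Map(print))
```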