Vllm model handler (apache#32410)
* Vllm first pass [wip]

* Example for integration tests wip

* Still wip

* Test changes

* Dockerfile improvements

* Remove bad change

* Clean up test args

* clean up invocation

* string fix

* string fix

* clean up

* lint

* Get tests working with 5xx driver

* cleanup

* Fixes, everything is now working

* Batching

* lint

* Feedback + CHANGES.md
damccorm authored and reeba212 committed Dec 4, 2024
1 parent ad5460a commit fb784ac
Showing 9 changed files with 646 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
-  "comment": "modify this file in a trivial way to cause this test suite to run.",
-  "modification": 1
+  "comment": "Modify this file in a trivial way to cause this test suite to run.",
+  "modification": 2
 }

15 changes: 2 additions & 13 deletions CHANGES.md
@@ -57,18 +57,13 @@

## Highlights

* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)).
* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)).

## I/Os

* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528))

## New Features / Improvements

* Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)).
* Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349))
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528))

## Breaking Changes

@@ -77,11 +72,9 @@
as strings rather than silently coerced (and possibly truncated) to numeric
values. To retain the old behavior, pass `dtype=True` (or any other value
accepted by `pandas.read_json`).
* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).

## Deprecations

* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)).
* Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users
when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192))

@@ -92,10 +85,6 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192))
## Security Fixes
* Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).

## Known Issues

* ([#X](https://github.com/apache/beam/issues/X)).

# [2.59.0] - 2024-09-11

## Highlights
1 change: 1 addition & 0 deletions build.gradle.kts
@@ -542,6 +542,7 @@ tasks.register("python312PostCommit") {
dependsOn(":sdks:python:test-suites:direct:py312:postCommitIT")
dependsOn(":sdks:python:test-suites:direct:py312:hdfsIntegrationTest")
dependsOn(":sdks:python:test-suites:portable:py312:postCommitPy312")
dependsOn(":sdks:python:test-suites:dataflow:py312:inferencePostCommitITPy312")
}

tasks.register("portablePythonPreCommit") {
80 changes: 80 additions & 0 deletions sdks/python/apache_beam/examples/inference/README.md
@@ -853,6 +853,7 @@ path/to/my/image2: dandelions (78)
Each line represents a prediction of the flower type along with the confidence in that prediction.

---

## Text classification with a Vertex AI LLM

[`vertex_ai_llm_text_classification.py`](./vertex_ai_llm_text_classification.py) contains an implementation for a RunInference pipeline that performs image classification using a model hosted on Vertex AI (based on https://cloud.google.com/vertex-ai/docs/tutorials/image-recognition-custom).
@@ -882,4 +883,83 @@ This writes the output to the output file with contents like:
```
Each line represents a tuple containing the example, a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult)
object with the response from the model in the inference field, and the endpoint id representing the model id.

---

## Text completion with vLLM

[`vllm_text_completion.py`](./vllm_text_completion.py) contains an implementation for a RunInference pipeline that performs text completion using a local [vLLM](https://docs.vllm.ai/en/latest/) server.

The pipeline reads in a set of text prompts or past messages, uses RunInference to spin up a local inference server and perform inference, and then writes the predictions to a text file.

### Model for text completion

To use this transform, you can use any [LLM supported by vLLM](https://docs.vllm.ai/en/latest/models/supported_models.html).
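
For orientation, here is a minimal sketch of how the completions model handler plugs into RunInference. It mirrors the example script below; the prompt and model name are just placeholders, and it assumes a GPU environment with vLLM installed:

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

# Any model supported by vLLM can be passed here; this one is just an example.
model_handler = VLLMCompletionsModelHandler(model_name='facebook/opt-125m')

with beam.Pipeline() as p:
  _ = (
      p
      | 'Create prompts' >> beam.Create(['Hello, my name is'])
      | 'RunInference' >> RunInference(model_handler)
      | 'Print' >> beam.Map(print))
```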

### Running `vllm_text_completion.py`

To run the text completion pipeline locally using the Facebook opt 125M model, use the following command.
```sh
python -m apache_beam.examples.inference.vllm_text_completion \
--model "facebook/opt-125m" \
--output 'path/to/output/file.txt' \
<... additional pipeline arguments to configure the runner if not running in a GPU environment ...>
```
You will need to run this either locally with a GPU accelerator or remotely on a runner that supports acceleration.
For example, you could run it on Dataflow with a GPU using the following command:
```sh
python -m apache_beam.examples.inference.vllm_text_completion \
--model "facebook/opt-125m" \
--output 'gs://path/to/output/file.txt' \
--runner dataflow \
--project <gcp project> \
--region us-central1 \
--temp_location <temp gcs location> \
--worker_harness_container_image "gcr.io/apache-beam-testing/beam-ml/vllm:latest" \
--machine_type "n1-standard-4" \
--dataflow_service_options "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx" \
--staging_location <temp gcs location>
```
Make sure to enable the 5xx driver, since vLLM only works with 5xx drivers, not 4xx.
This writes the output to the specified file location with contents like:
```
'Hello, my name is', PredictionResult(example={'prompt': 'Hello, my name is'}, inference=Completion(id='cmpl-5f5113a317c949309582b1966511ffc4', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text=' Joel, my dad is Anton Harriman and my wife is Lydia. ', stop_reason=None)], created=1714064548, model='facebook/opt-125m', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=6, total_tokens=22))})
```
Each line represents a tuple containing the example, a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field.
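If you only want the generated text rather than the full stringified result, a small post-processing `DoFn` along these lines can pull it out of the `Completion` object shown in the sample output above (a sketch; the class name is illustrative):

```python
import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult


class ExtractCompletionText(beam.DoFn):
  def process(self, element: PredictionResult):
    # The inference field holds an OpenAI Completion object; each choice
    # carries a piece of generated text.
    for choice in element.inference.choices:
      yield choice.text
```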
You can also choose to run with chat examples. Doing this requires two steps:
1) Upload a [chat_template](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template) to a file store that is accessible from your job's environment (e.g. a public Google Cloud Storage bucket). You can copy [this sample template](https://storage.googleapis.com/apache-beam-ml/additional_files/sample_chat_template.jinja) to get started. You can skip this step if you are using a model other than `facebook/opt-125m` and you know your model provides a chat template.
2) Add the `--chat true` and `--chat_template <gs://path/to/your/file>` parameters:
```sh
python -m apache_beam.examples.inference.vllm_text_completion \
--model "facebook/opt-125m" \
--output 'gs://path/to/output/file.txt' \
--chat true \
--chat_template gs://path/to/your/file \
<... additional pipeline arguments to configure the runner if not running in a GPU environment ...>
```
This will configure the pipeline to run against a sequence of previous messages instead of a single text completion prompt.
For example, it might run against:
```
[
  OpenAIChatMessage(role='user', content='What is an example of a type of penguin?'),
  OpenAIChatMessage(role='assistant', content='An emperor penguin is a type of penguin.'),
  OpenAIChatMessage(role='user', content='Tell me about them')
],
```
and produce the following result in your output file location:
```
An emperor penguin is an adorable creature that lives in Antarctica.
```
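In code, the chat path amounts to constructing the chat model handler instead of the completions handler, roughly as follows (a sketch mirroring the example script; the template path is a placeholder for the file you uploaded in step 1):

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage
from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler

# chat_template_path should point at the chat template uploaded in step 1.
model_handler = VLLMChatModelHandler(
    model_name='facebook/opt-125m',
    chat_template_path='gs://path/to/your/file')

chat_history = [
    OpenAIChatMessage(role='user', content='What is an example of a type of penguin?'),
    OpenAIChatMessage(role='assistant', content='An emperor penguin is a type of penguin.'),
    OpenAIChatMessage(role='user', content='Tell me about them'),
]

with beam.Pipeline() as p:
  _ = (
      p
      | 'Create chats' >> beam.Create([chat_history])
      | 'RunInference' >> RunInference(model_handler)
      | 'Print' >> beam.Map(print))
```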
---
162 changes: 162 additions & 0 deletions sdks/python/apache_beam/examples/inference/vllm_text_completion.py
@@ -0,0 +1,162 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

""" A sample pipeline using the RunInference API to interface with an LLM using
vLLM. Takes in a set of prompts or lists of previous messages and produces
responses using a model of choice.
Requires a GPU runtime with vllm, openai, and apache-beam installed to run
correctly.
"""

import argparse
import logging
from typing import Iterable

import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage
from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult

COMPLETION_EXAMPLES = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "John cena is",
]

CHAT_EXAMPLES = [
    [
        OpenAIChatMessage(
            role='user', content='What is an example of a type of penguin?'),
        OpenAIChatMessage(
            role='assistant', content='Emperor penguin is a type of penguin.'),
        OpenAIChatMessage(role='user', content='Tell me about them')
    ],
    [
        OpenAIChatMessage(
            role='user', content='What colors are in the rainbow?'),
        OpenAIChatMessage(
            role='assistant',
            content='Red, orange, yellow, green, blue, indigo, and violet.'),
        OpenAIChatMessage(role='user', content='Do other colors ever appear?')
    ],
    [
        OpenAIChatMessage(
            role='user', content='Who is the president of the United States?')
    ],
    [
        OpenAIChatMessage(role='user', content='What state is Fargo in?'),
        OpenAIChatMessage(role='assistant', content='It is in North Dakota.'),
        OpenAIChatMessage(role='user', content='How many people live there?'),
        OpenAIChatMessage(
            role='assistant',
            content='Approximately 130,000 people live in Fargo, North Dakota.'
        ),
        OpenAIChatMessage(role='user', content='What is Fargo known for?'),
    ],
    [
        OpenAIChatMessage(
            role='user', content='How many fish are in the ocean?'),
    ],
]


def parse_known_args(argv):
  """Parses args for the workflow."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model',
      dest='model',
      type=str,
      required=False,
      default='facebook/opt-125m',
      help='LLM to use for task')
  parser.add_argument(
      '--output',
      dest='output',
      type=str,
      required=True,
      help='Path to save output predictions.')
  parser.add_argument(
      '--chat',
      dest='chat',
      type=bool,
      required=False,
      default=False,
      help='Whether to use chat model handler and examples')
  parser.add_argument(
      '--chat_template',
      dest='chat_template',
      type=str,
      required=False,
      default=None,
      help='Chat template to use for chat example.')
  return parser.parse_known_args(argv)


class PostProcessor(beam.DoFn):
  def process(self, element: PredictionResult) -> Iterable[str]:
    yield str(element.example) + ": " + str(element.inference)


def run(
    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
  """
  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  model_handler = VLLMCompletionsModelHandler(model_name=known_args.model)
  input_examples = COMPLETION_EXAMPLES

  if known_args.chat:
    model_handler = VLLMChatModelHandler(
        model_name=known_args.model,
        chat_template_path=known_args.chat_template)
    input_examples = CHAT_EXAMPLES

  pipeline = test_pipeline
  if not test_pipeline:
    pipeline = beam.Pipeline(options=pipeline_options)

  examples = pipeline | "Create examples" >> beam.Create(input_examples)
  predictions = examples | "RunInference" >> RunInference(model_handler)
  process_output = predictions | "Process Predictions" >> beam.ParDo(
      PostProcessor())
  _ = process_output | "WriteOutput" >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Used for any vLLM integration test

FROM nvidia/cuda:12.4.1-devel-ubuntu22.04

RUN apt update
RUN apt install software-properties-common -y
RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt update

ARG DEBIAN_FRONTEND=noninteractive

RUN apt install python3.12 -y
RUN apt install python3.12-venv -y
RUN apt install python3.12-dev -y
RUN rm /usr/bin/python3
RUN ln -s python3.12 /usr/bin/python3
RUN python3 --version
RUN apt-get install -y curl
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 && pip install --upgrade pip

RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.58.1
RUN pip install openai vllm

RUN apt install libcairo2-dev pkg-config python3-dev -y
RUN pip install pycairo

# Copy the Apache Beam worker dependencies from the Beam Python 3.12 SDK image.
COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam

# Set the entrypoint to Apache Beam SDK worker launcher.
ENTRYPOINT [ "/opt/apache/beam/boot" ]