From 98791c73ad4c1d98453e30c0f36060a91ab6545f Mon Sep 17 00:00:00 2001
From: Danny McCormick
Date: Tue, 24 Sep 2024 10:48:45 -0400
Subject: [PATCH] Feedback + CHANGES.md

---
 CHANGES.md                                    | 15 +------
 .../ml/inference/vllm_inference.py            | 39 +++++++++++++------
 2 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index d58ceffeb411..c123a8e1a4dc 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -57,18 +57,13 @@
 
 ## Highlights
 
-* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)).
-* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)).
-
-## I/Os
-
-* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528))
 
 ## New Features / Improvements
 
 * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)).
 * Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349))
-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528))
 
 ## Breaking Changes
 
@@ -77,11 +72,9 @@
   as strings rather than silently coerced (and possibly truncated) to numeric
   values. To retain the old behavior, pass `dtype=True` (or any other value
   accepted by `pandas.read_json`).
-* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).
 
 ## Deprecations
 
-* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)).
 * Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users
 when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192))
 
@@ -92,10 +85,6 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192))
 ## Security Fixes
 * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
 
-## Known Issues
-
-* ([#X](https://github.com/apache/beam/issues/X)).
-
 # [2.59.0] - 2024-09-11
 
 ## Highlights
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py
index 929b6c945b4d..28890083d93e 100644
--- a/sdks/python/apache_beam/ml/inference/vllm_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -120,6 +120,14 @@ def start_server(self, retries=3):
           server_cmd.append(v)
       self._server_process, self._server_port = start_process(server_cmd)
 
+    self.check_connectivity()
+
+  def get_server_port(self) -> int:
+    if not self._server_started:
+      self.start_server()
+    return self._server_port
+
+  def check_connectivity(self, retries=3):
     client = getVLLMClient(self._server_port)
     while self._server_process.poll() is None:
       try:
@@ -134,17 +142,14 @@ def start_server(self, retries=3):
         time.sleep(5)
 
     if retries == 0:
+      self._server_started = False
       raise Exception(
-          "Failed to start vLLM server, process exited with code %s" %
+          "Failed to start vLLM server, polling process exited with code " +
+          "%s. Next time a request is tried, the server will be restarted" %
           self._server_process.poll())
     else:
       self.start_server(retries - 1)
 
-  def get_server_port(self) -> int:
-    if not self._server_started:
-      self.start_server()
-    return self._server_port
-
 
 class VLLMCompletionsModelHandler(ModelHandler[str,
                                                PredictionResult,
@@ -202,9 +207,14 @@ def run_inference(
     # for taking in batches and doing a bunch of async calls. That will end up
    # being more efficient when we can do in bundle batching.
     for prompt in batch:
-      completion = client.completions.create(
-          model=self._model_name, prompt=prompt, **inference_args)
-      predictions.append(completion)
+      try:
+        completion = client.completions.create(
+            model=self._model_name, prompt=prompt, **inference_args)
+        predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   def share_model_across_processes(self) -> bool:
@@ -288,9 +298,14 @@ def run_inference(
       formatted = []
       for message in messages:
         formatted.append({"role": message.role, "content": message.content})
-      completion = client.chat.completions.create(
-          model=self._model_name, messages=formatted, **inference_args)
-      predictions.append(completion)
+      try:
+        completion = client.chat.completions.create(
+            model=self._model_name, messages=formatted, **inference_args)
+        predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   def share_model_across_processes(self) -> bool:
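Note for reviewers (not part of the diff): a minimal sketch of how the handler touched by this patch might be used in a pipeline. The model_name keyword, the specific model, and the prompts are illustrative assumptions; it presumes a worker environment with vllm installed and a GPU available.

    # Usage sketch, not part of this patch; names below are illustrative.
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

    # Assumed: the handler is constructed with a Hugging Face model name.
    model_handler = VLLMCompletionsModelHandler(model_name='facebook/opt-125m')

    with beam.Pipeline() as p:
      _ = (
          p
          | 'Prompts' >> beam.Create(['Hello, my name is', 'Apache Beam is'])
          | 'Generate' >> RunInference(model_handler)
          | 'Print' >> beam.Map(print))

With this patch, a transient request failure triggers check_connectivity(), which restarts the vLLM server process before the next bundle retries.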