Skip to content

Commit

Permalink
Add throttling metrics and retries to vertex embeddings (apache#33311)
Browse files Browse the repository at this point in the history
* Add throttling metrics and retries to vertex embeddings

* Format + run postcommits

* fix + lint
  • Loading branch information
damccorm authored Dec 6, 2024
1 parent 0bf6f69 commit d138b75
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run.",
"modification": 6
"modification": 7
}

103 changes: 100 additions & 3 deletions sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,27 @@
# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long
# to install Vertex AI Python SDK.

import logging
import time
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import Optional

from google.api_core.exceptions import ServerError
from google.api_core.exceptions import TooManyRequests
from google.auth.credentials import Credentials

import apache_beam as beam
import vertexai
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics.metric import Metrics
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.utils import retry
from vertexai.language_models import TextEmbeddingInput
from vertexai.language_models import TextEmbeddingModel
from vertexai.vision_models import Image
Expand All @@ -51,6 +58,26 @@
"CLUSTERING"
]
_BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time.
_MSEC_TO_SEC = 1000

LOGGER = logging.getLogger("VertexAIEmbeddings")


def _retry_on_appropriate_gcp_error(exception):
"""
Retry filter that returns True if a returned HTTP error code is 5xx or 429.
This is used to retry remote requests that fail, most notably 429
(TooManyRequests.)
Args:
exception: the returned exception encountered during the request/response
loop.
Returns:
boolean indication whether or not the exception is a Server Error (5xx) or
a TooManyRequests (429) error.
"""
return isinstance(exception, (TooManyRequests, ServerError))


class _VertexAITextEmbeddingHandler(ModelHandler):
Expand All @@ -74,6 +101,41 @@ def __init__(
self.task_type = task_type
self.title = title

# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
# See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
# for more details.
self.throttled_secs = Metrics.counter(
VertexAIImageEmbeddings, "cumulativeThrottlingSeconds")
self.throttler = AdaptiveThrottler(
window_ms=1, bucket_ms=1, overload_ratio=2)

@retry.with_exponential_backoff(
num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def get_request(
self,
text_batch: Sequence[TextEmbeddingInput],
model: MultiModalEmbeddingModel,
throttle_delay_secs: int):
while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC):
LOGGER.info(
"Delaying request for %d seconds due to previous failures",
throttle_delay_secs)
time.sleep(throttle_delay_secs)
self.throttled_secs.inc(throttle_delay_secs)

try:
req_time = time.time()
prediction = model.get_embeddings(text_batch)
self.throttler.successful_request(req_time * _MSEC_TO_SEC)
return prediction
except TooManyRequests as e:
LOGGER.warning("request was limited by the service with code %i", e.code)
raise
except Exception as e:
LOGGER.error("unexpected exception raised as part of request, got %s", e)
raise

def run_inference(
self,
batch: Sequence[str],
Expand All @@ -89,7 +151,8 @@ def run_inference(
text=text, title=self.title, task_type=self.task_type)
for text in text_batch
]
embeddings_batch = model.get_embeddings(text_batch)
embeddings_batch = self.get_request(
text_batch=text_batch, model=model, throttle_delay_secs=5)
embeddings.extend([el.values for el in embeddings_batch])
return embeddings

Expand Down Expand Up @@ -173,6 +236,41 @@ def __init__(
self.model_name = model_name
self.dimension = dimension

# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
# See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
# for more details.
self.throttled_secs = Metrics.counter(
VertexAIImageEmbeddings, "cumulativeThrottlingSeconds")
self.throttler = AdaptiveThrottler(
window_ms=1, bucket_ms=1, overload_ratio=2)

@retry.with_exponential_backoff(
num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def get_request(
self,
img: Image,
model: MultiModalEmbeddingModel,
throttle_delay_secs: int):
while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC):
LOGGER.info(
"Delaying request for %d seconds due to previous failures",
throttle_delay_secs)
time.sleep(throttle_delay_secs)
self.throttled_secs.inc(throttle_delay_secs)

try:
req_time = time.time()
prediction = model.get_embeddings(image=img, dimension=self.dimension)
self.throttler.successful_request(req_time * _MSEC_TO_SEC)
return prediction
except TooManyRequests as e:
LOGGER.warning("request was limited by the service with code %i", e.code)
raise
except Exception as e:
LOGGER.error("unexpected exception raised as part of request, got %s", e)
raise

def run_inference(
self,
batch: Sequence[Image],
Expand All @@ -182,8 +280,7 @@ def run_inference(
embeddings = []
# Maximum request size for muli-model embedding models is 1.
for img in batch:
embedding_response = model.get_embeddings(
image=img, dimension=self.dimension)
embedding_response = self.get_request(img, model, throttle_delay_secs=5)
embeddings.append(embedding_response.image_embedding)
return embeddings

Expand Down

0 comments on commit d138b75

Please sign in to comment.