From 8c548e4c56ec4a14dc7b16571bec89ac8ef71f35 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 18 Apr 2024 15:02:28 -0600 Subject: [PATCH] TGIS metrics (#18) This PR implements a subset of the metrics from the TGIS image. I tried to make sure that everything from our current ops dashboard is supported. These are: - tgi_tokenize_request_tokens - tgi_tokenize_request_input_count - tgi_request_input_count - tgi_request_failure - tgi_request_queue_duration - tgi_queue_size - tgi_batch_current_size - tgi_batch_inference_duration - tgi_request_input_length - tgi_request_generated_tokens --------- Signed-off-by: Joe Runde --- vllm/entrypoints/grpc/grpc_server.py | 114 +++++++++++----------- vllm/tgis_utils/logs.py | 19 ++-- vllm/tgis_utils/metrics.py | 136 +++++++++++++++++++++++++++ 3 files changed, 209 insertions(+), 60 deletions(-) create mode 100644 vllm/tgis_utils/metrics.py diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py index 3d678de982cbf..15885fca466e5 100644 --- a/vllm/entrypoints/grpc/grpc_server.py +++ b/vllm/entrypoints/grpc/grpc_server.py @@ -1,5 +1,4 @@ import argparse -import dataclasses import inspect import time import uuid @@ -37,19 +36,12 @@ from vllm.tgis_utils import logs from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper, TypicalLogitsWarperWrapper) +from vllm.tgis_utils.metrics import (FailureReasonLabel, ServiceMetrics, + TGISStatLogger) from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup logger = init_logger(__name__) - -@dataclasses.dataclass -class Times: - """Container tracking times (in seconds) when requests start and finish """ - # When control enters Generate or GenerateStream - request_start: float - # When the request is sent to the vLLM engine - engine_start: float = 0 - # When the stream from the vLLM engine closes - end: float = 0 +service_metrics = ServiceMetrics() def with_default(value: Any, default: Any) -> Any: @@ -63,7 +55,13 @@ async def _handle_exception(e: Exception, func, *args, **kwargs): if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check context = kwargs.get("context", None) or args[-1] logger.exception(f"{func.__name__} caused GPU OOM error") + service_metrics.count_request_failure(FailureReasonLabel.OOM) await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e)) + else: + if "generate" in func.__name__.lower(): + service_metrics.count_request_failure(FailureReasonLabel.GENERATE) + else: + service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN) logger.exception(f"{func.__name__} failed") raise e @@ -108,10 +106,20 @@ async def _post_init(self): self.tokenizer_group = await self.engine.get_tokenizer_group() self.tokenizer = await self.engine.get_tokenizer() + # Swap in the special TGIS stats logger + vllm_stat_logger = self.engine.engine.stat_logger + tgis_stats_logger = TGISStatLogger( + vllm_stat_logger=vllm_stat_logger, + max_sequence_len=self.config.max_model_len) + # 🌶️🌶️🌶️ sneaky sneak + self.engine.engine.stat_logger = tgis_stats_logger + + @log_rpc_handler_errors async def Generate(self, request: BatchedGenerationRequest, context: ServicerContext) -> BatchedGenerationResponse: start_time = time.time() + service_metrics.count_generate_request(len(request.requests)) request_id = self.request_id(context) sampling_params, deadline = await self._validate_and_convert_params( request.params, context) @@ -120,23 +128,19 @@ async def Generate(self, request: BatchedGenerationRequest, request_count = len(request.requests) generators = [] - timing_infos = [] max_is_token_limit = [False] * request_count for i, req in enumerate(request.requests): input_ids, max_is_token_limit[i]\ = await self._validate_prompt_and_tokenize( sampling_params, truncate_input_tokens, req.text, context) - timing_info = Times(request_start=start_time) - timing_infos.append(timing_info) generators.append( - self.timed_generator( - # prompt is supplied for observability, the text is not - # re-tokenized when `prompt_token_ids` is supplied - self.engine.generate(prompt=req.text, - sampling_params=sampling_params, - request_id=f"{request_id}-{i}", - prompt_token_ids=input_ids), - timing_info)) + # prompt is supplied for observability, the text is not + # re-tokenized when `prompt_token_ids` is supplied + self.engine.generate(prompt=req.text, + sampling_params=sampling_params, + request_id=f"{request_id}-{i}", + prompt_token_ids=input_ids), + ) # TODO handle cancellation result_generator: AsyncIterator[Tuple[ @@ -151,6 +155,7 @@ async def Generate(self, request: BatchedGenerationRequest, # await self.engine.abort(f"{request_id}-{i}") # return self.create_error_response("Client disconnected") responses[i] = res + service_metrics.observe_queue_time(res) if deadline is not None and time.time( ) >= deadline and None not in responses: @@ -173,7 +178,8 @@ async def Generate(self, request: BatchedGenerationRequest, kind_log = f"Sub-request {i} from batch of {request_count}" self._log_unary_response(request=request, response=response, - times=timing_infos[i], kind_log=kind_log) + start_time=start_time, engine_response=res, + kind_log=kind_log) responses[i] = response return BatchedGenerationResponse(responses=responses) @@ -182,7 +188,8 @@ async def Generate(self, request: BatchedGenerationRequest, async def GenerateStream( self, request: SingleGenerationRequest, context: ServicerContext) -> AsyncIterator[GenerationResponse]: - timing_info = Times(request_start=time.time()) + start_time = time.time() + service_metrics.count_generate_request() request_id = self.request_id(context) sampling_params, deadline = await self._validate_and_convert_params( request.params, context) @@ -193,16 +200,13 @@ async def GenerateStream( sampling_params, truncate_input_tokens, request.request.text, context) - result_generator = self.timed_generator( - self.engine.generate( - # prompt is supplied for observability, the text is not - # re-tokenized when `prompt_token_ids` is supplied - prompt=request.request.text, - sampling_params=sampling_params, - request_id=request_id, - prompt_token_ids=input_ids, - ), - timing_info + result_generator = self.engine.generate( + # prompt is supplied for observability, the text is not + # re-tokenized when `prompt_token_ids` is supplied + prompt=request.request.text, + sampling_params=sampling_params, + request_id=request_id, + prompt_token_ids=input_ids, ) resp_options = request.params.response @@ -213,9 +217,12 @@ async def GenerateStream( last_token_count = 0 time_limit_reached = False full_output = "" + last_engine_response = None #TODO handle cancellation async for result in result_generator: + last_engine_response = result if first: + service_metrics.observe_queue_time(result) first_response = self._convert_input_details( result, resp_options, sampling_params, GenerationResponse()) @@ -247,7 +254,8 @@ async def GenerateStream( first_response.text = full_output first_response.generated_token_count = last_token_count self._log_streaming_response(request=request, response=first_response, - times=timing_info) + start_time=start_time, + engine_response=last_engine_response) def _convert_input_details( self, result: RequestOutput, resp_options: ResponseOptions, @@ -314,6 +322,7 @@ async def _validate_and_convert_params( try: validate_params(params, self.max_max_new_tokens) except ValueError as tgis_validation_error: + service_metrics.count_request_failure(FailureReasonLabel.VALIDATION) await context.abort(StatusCode.INVALID_ARGUMENT, str(tgis_validation_error)) @@ -396,6 +405,7 @@ async def _validate_and_convert_params( except ValueError as vllm_validation_error: # There may be validation cases caught by vLLM that are not covered # by the TGIS api validation + service_metrics.count_request_failure(FailureReasonLabel.VALIDATION) await context.abort(StatusCode.INVALID_ARGUMENT, str(vllm_validation_error)) @@ -528,36 +538,32 @@ async def _validate_prompt_and_tokenize( @staticmethod def _log_unary_response(request: BatchedGenerationRequest, - response: GenerationResponse, times: Times, - kind_log: str): + response: GenerationResponse, + engine_response: RequestOutput, + start_time: float, kind_log: str): logs.log_response(inputs=[r.text for r in request.requests], response=response, params=request.params, - prefix_id=request.prefix_id, times=times, - kind_log=kind_log, method_str="generate", - logger=logger) + prefix_id=request.prefix_id, + engine_response=engine_response, + start_time=start_time, kind_log=kind_log, + method_str="generate", logger=logger) @staticmethod def _log_streaming_response(request: SingleGenerationRequest, - response: GenerationResponse, times: Times): + response: GenerationResponse, + engine_response: RequestOutput, + start_time: float): logs.log_response(inputs=[request.request.text], response=response, params=request.params, prefix_id=request.prefix_id, - times=times, kind_log="Streaming response", + engine_response=engine_response, + start_time=start_time, kind_log="Streaming response", method_str="generate_stream", logger=logger) - @staticmethod - async def timed_generator(generator: AsyncIterator[RequestOutput], - times: Times) -> AsyncIterator[RequestOutput]: - """Injects some timing data around each result generator from the - LLMEngine""" - times.engine_start = time.time() - async for val in generator: - yield val - times.end = time.time() - @log_rpc_handler_errors async def Tokenize(self, request: BatchedTokenizeRequest, context: ServicerContext) -> BatchedTokenizeResponse: + service_metrics.observe_tokenization_request(request) #TODO implement these if request.return_offsets: await context.abort(StatusCode.INVALID_ARGUMENT, @@ -578,7 +584,9 @@ async def Tokenize(self, request: BatchedTokenizeRequest, tokens=None if not request.return_tokens else self.tokenizer.convert_ids_to_tokens(token_ids))) - return BatchedTokenizeResponse(responses=responses) + response = BatchedTokenizeResponse(responses=responses) + service_metrics.observe_tokenization_response(response) + return response @log_rpc_handler_errors async def ModelInfo(self, request: ModelInfoRequest, diff --git a/vllm/tgis_utils/logs.py b/vllm/tgis_utils/logs.py index 89df2b10eac17..9b3f41bef77aa 100644 --- a/vllm/tgis_utils/logs.py +++ b/vllm/tgis_utils/logs.py @@ -4,19 +4,23 @@ from google.protobuf import text_format +from vllm import RequestOutput from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse, Parameters, StopReason) def log_response(inputs: List[str], params: Parameters, prefix_id: str, - response: GenerationResponse, times, kind_log: str, - method_str: str, logger: logging.Logger): + response: GenerationResponse, engine_response: RequestOutput, + start_time: float, kind_log: str, method_str: str, + logger: logging.Logger): """Logs responses similar to how the TGIS server does""" # This time contains both request validation and tokenization - tokenization_time = times.engine_start - times.request_start - llm_engine_time = times.end - times.engine_start - time_per_token = _safe_div(llm_engine_time, response.generated_token_count) - total_time = times.end - times.request_start + tokenization_time = engine_response.metrics.arrival_time - start_time + inference_time = (engine_response.metrics.last_token_time - + engine_response.metrics.first_scheduled_time) + queue_time = engine_response.metrics.time_in_queue + time_per_token = _safe_div(inference_time, response.generated_token_count) + total_time = engine_response.metrics.last_token_time - start_time output_len = len(response.text) short_output = _truncate(response.text, 32) short_input = [_truncate(input_, 32) for input_ in inputs] @@ -26,7 +30,8 @@ def log_response(inputs: List[str], params: Parameters, prefix_id: str, span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} " f"input_chars=[{input_chars}] params={paramstr} " f"tokenization_time={tokenization_time * 1e3:.2f}ms " - f"queue_and_inference_time={llm_engine_time * 1e3:.2f}ms " + f"queue_time={queue_time * 1e3:.2f}ms " + f"inference_time={inference_time * 1e3:.2f}ms " f"time_per_token={time_per_token * 1e3:.2f}ms " f"total_time={total_time * 1e3:.2f}ms " f"input_toks={response.input_token_count}}}") diff --git a/vllm/tgis_utils/metrics.py b/vllm/tgis_utils/metrics.py new file mode 100644 index 0000000000000..c03c8948fa953 --- /dev/null +++ b/vllm/tgis_utils/metrics.py @@ -0,0 +1,136 @@ +"""Implements the logging for all tgi_* metrics for compatibility + with TGIS opsviz""" +from enum import StrEnum, auto + +from prometheus_client import Counter, Gauge, Histogram + +from vllm import RequestOutput +from vllm.engine.metrics import StatLogger, Stats +from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedTokenizeRequest, + BatchedTokenizeResponse) + +_duration_buckets = [ + 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50 +] + + +class FailureReasonLabel(StrEnum): + VALIDATION = auto() # request validation failed + CANCELLED = auto() # TODO: cancellation handling not implemented + CONC_LIMIT = auto() # TODO: is this applicable for vLLM? + OOM = auto() # gpu OOM error + GENERATE = auto( + ) # some error happened while running a text generation request + TIMEOUT = auto() # grpc deadline exceeded + UNKNOWN = auto() + + +class ServiceMetrics: + + def __init__(self): + # Tokenization API metrics + self.tgi_tokenize_request_tokens = Histogram( + "tgi_tokenize_request_tokens", + documentation="Histogram of tokenized tokens per tokenize request", + buckets=[1 << x for x in range(6, 20)]) + self.tgi_tokenize_request_input_count = Counter( + "tgi_tokenize_request_input_count", + documentation= + "Count of tokenize request inputs (batch of n counts as n)") + + # Generate API metrics + self.tgi_request_input_count = Counter( + "tgi_request_input_count", + documentation= + "Count of generate request inputs (batch of n counts as n)") + # err = validation|cancelled|conc_limit + self.tgi_request_failure = Counter( + "tgi_request_failure", + labelnames=["err"], + documentation="Count of failed requests, segmented by error type") + # The queue duration info from the vllm engine is only known at + # response time + self.tgi_request_queue_duration = Histogram( + "tgi_request_queue_duration", + documentation="Request time spent in queue (in seconds)", + buckets=_duration_buckets) + + def observe_tokenization_request(self, request: BatchedTokenizeRequest): + self.tgi_tokenize_request_input_count.inc(len(request.requests)) + + def observe_tokenization_response(self, response: BatchedTokenizeResponse): + for tokenize_response in response.responses: + self.tgi_tokenize_request_tokens.observe( + tokenize_response.token_count) + + def count_generate_request(self, num_requests: int = 1): + self.tgi_request_input_count.inc(num_requests) + + def observe_queue_time(self, engine_output: RequestOutput): + self.tgi_request_queue_duration.observe( + engine_output.metrics.time_in_queue) + + def count_request_failure(self, reason: FailureReasonLabel): + self.tgi_request_failure.labels({"err": reason}).inc(1) + + +class TGISStatLogger(StatLogger): + """Instance wraps the vLLM StatLogger to report TGIS metric names + for compatibility""" + + def __init__(self, vllm_stat_logger: StatLogger, max_sequence_len: int): + # Not calling super-init because we're wrapping and delegating to + # vllm_stat_logger + self._vllm_stat_logger = vllm_stat_logger + + self.tgi_queue_size = Gauge( + "tgi_queue_size", + documentation="Current number of queued requests") + self.tgi_batch_current_size = Gauge("tgi_batch_current_size", + documentation="Current batch size") + # method = prefill|next_token + self.tgi_batch_inference_duration = Histogram( + "tgi_batch_inference_duration", + labelnames=["method"], + documentation= + "Time taken for each forward-pass iteration (in seconds)", + buckets=_duration_buckets) + + sequence_len_buckets = [ + max_sequence_len / 64.0 * (x + 1) for x in range(64) + ] + self.tgi_request_input_length = Histogram( + "tgi_request_input_length", + documentation="Request input length in tokens", + buckets=sequence_len_buckets) + self.tgi_request_generated_tokens = Histogram( + "tgi_request_generated_tokens", + documentation="Number of tokens generated for request", + buckets=sequence_len_buckets) + + def info(self, type: str, obj: object) -> None: + self._vllm_stat_logger.info(type, object) + + def log(self, stats: Stats) -> None: + # First, log the vLLM stats + self._vllm_stat_logger.log(stats) + + # Then log TGIS specific ones + self.tgi_queue_size.set(stats.num_waiting + stats.num_swapped) + self.tgi_batch_current_size.set(stats.num_running) + + for ttft in stats.time_to_first_tokens: + self.tgi_batch_inference_duration.labels({ + "method": "prefill" + }).observe(ttft) + for tpot in stats.time_per_output_tokens: + self.tgi_batch_inference_duration.labels({ + "method": "next_token" + }).observe(tpot) + + # These metrics depend on open PR: https://github.com/vllm-project/vllm/pull/2764 + if hasattr(stats, "num_prompt_tokens_lst"): + for input_len in stats.num_prompt_tokens_lst: + self.tgi_request_input_length.observe(input_len) + for output_len in stats.num_generation_tokens_lst: + self.tgi_request_generated_tokens.observe(output_len)