From a52ee73b27f3f1a82cadee11ad0ea20c313159bb Mon Sep 17 00:00:00 2001
From: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>
Date: Mon, 22 Jul 2024 11:03:22 -0700
Subject: [PATCH] Add OpenAI format response to r2.0.0rc1 (#9796)

* Add OpenAI format response to r2.0.0rc1

Signed-off-by: Abhishree

* Apply isort and black reformatting

Signed-off-by: athitten

---------

Signed-off-by: Abhishree
Signed-off-by: athitten
Co-authored-by: athitten
Co-authored-by: Eric Harper
---
 nemo/deploy/nlp/query_llm.py          | 14 +++++++++-
 nemo/deploy/service/config.json       |  5 ----
 nemo/deploy/service/rest_model_api.py | 39 ++++++++++++++++++++++++---
 scripts/deploy/nlp/deploy_triton.py   | 33 +++++++++++++++++++++++
 4 files changed, 81 insertions(+), 10 deletions(-)
 delete mode 100644 nemo/deploy/service/config.json

diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py
index 71492520bf0a..7e873db6b5b1 100644
--- a/nemo/deploy/nlp/query_llm.py
+++ b/nemo/deploy/nlp/query_llm.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 from abc import ABC, abstractmethod
 
 import numpy as np
@@ -172,6 +173,7 @@ def query_llm(
         compute_logprob: bool = None,
         end_strings=None,
         init_timeout=60.0,
+        openai_format_response: bool = False,
     ):
         """
         Query the Triton server synchronously and return a list of responses.
@@ -259,7 +261,17 @@ def query_llm(
                     return "Unknown output keyword."
 
                 sentences = np.char.decode(output.astype("bytes"), "utf-8")
-                return sentences
+                if openai_format_response:
+                    openai_response = {
+                        "id": f"cmpl-{int(time.time())}",
+                        "object": "text_completion",
+                        "created": int(time.time()),
+                        "model": self.model_name,
+                        "choices": [{"text": str(sentences)}],
+                    }
+                    return openai_response
+                else:
+                    return sentences
             else:
                 return result_dict["outputs"]
 
diff --git a/nemo/deploy/service/config.json b/nemo/deploy/service/config.json
deleted file mode 100644
index d3b3440dd97b..000000000000
--- a/nemo/deploy/service/config.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-    "triton_service_port": 8000,
-    "triton_service_ip": "0.0.0.0",
-    "triton_request_timeout": 60
-  }
\ No newline at end of file
diff --git a/nemo/deploy/service/rest_model_api.py b/nemo/deploy/service/rest_model_api.py
index 5c49370fd45f..fbc774883faa 100644
--- a/nemo/deploy/service/rest_model_api.py
+++ b/nemo/deploy/service/rest_model_api.py
@@ -12,8 +12,9 @@
 import json
 import os
 from pathlib import Path
+import requests
 
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from pydantic_settings import BaseSettings
 
@@ -33,6 +34,7 @@ def __init__(self):
             self._triton_service_port = config_json["triton_service_port"]
             self._triton_service_ip = config_json["triton_service_ip"]
             self._triton_request_timeout = config_json["triton_request_timeout"]
+            self._openai_format_response = config_json["openai_format_response"]
         except Exception as error:
             print("An exception occurred:", error)
             return
@@ -49,6 +51,14 @@ def triton_service_ip(self):
     def triton_request_timeout(self):
         return self._triton_request_timeout
 
+    @property
+    def openai_format_response(self):
+        """
+        Returns the response from the Triton server in an OpenAI-compatible format if set to True;
+        the default set in config.json is False.
+        """
+        return self._openai_format_response
+
 
 app = FastAPI()
 triton_settings = TritonSettings()
@@ -66,6 +76,23 @@ class CompletionRequest(BaseModel):
     frequency_penalty: float = 1.0
 
 
+@app.get("/triton_health")
+async def check_triton_health():
+    """
+    Exposes the "/triton_health" endpoint, which verifies that the Triton server is accessible while the REST (FastAPI) application is running.
+    Check it with: curl http://service_http_address:service_port/triton_health; the returned status reports whether the server is reachable.
+    """
+    triton_url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}/v2/health/ready"
+    try:
+        response = requests.get(triton_url, timeout=5)
+        if response.status_code == 200:
+            return {"status": "Triton server is reachable and ready"}
+        else:
+            raise HTTPException(status_code=503, detail="Triton server is not ready")
+    except requests.RequestException as e:
+        raise HTTPException(status_code=503, detail=f"Cannot reach Triton server: {str(e)}")
+
+
 @app.post("/v1/completions/")
 def completions_v1(request: CompletionRequest):
     try:
@@ -78,10 +105,14 @@ def completions_v1(request: CompletionRequest):
             top_p=request.top_p,
             temperature=request.temperature,
             init_timeout=triton_settings.triton_request_timeout,
+            openai_format_response=triton_settings.openai_format_response,
         )
-        return {
-            "output": output[0][0],
-        }
+        if triton_settings.openai_format_response:
+            return output
+        else:
+            return {
+                "output": output[0][0],
+            }
     except Exception as error:
         print("An exception occurred:", error)
         return {"error": "An exception occurred"}
diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py
index a306231bcd61..59bfd02fcc33 100755
--- a/scripts/deploy/nlp/deploy_triton.py
+++ b/scripts/deploy/nlp/deploy_triton.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import argparse
+import json
 import logging
 import os
 import sys
@@ -73,6 +74,9 @@ def get_args(argv):
     parser.add_argument(
         "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server"
     )
+    parser.add_argument(
+        "-trt", "--triton_request_timeout", default=60, type=int, help="Timeout in seconds for the Triton server"
+    )
     parser.add_argument(
         "-tmr", "--triton_model_repository", default=None, type=str, help="Folder for the trt-llm conversion"
     )
@@ -183,11 +187,33 @@ def get_args(argv):
         "-sha", "--service_http_address", default="0.0.0.0", type=str, help="HTTP address for the REST Service"
     )
     parser.add_argument("-sp", "--service_port", default=8080, type=int, help="Port for the REST Service")
+    parser.add_argument(
+        "-ofr",
+        "--openai_format_response",
+        default=False,
+        action='store_true',
+        help="Return the response from the PyTriton server in an OpenAI-compatible format",
+    )
     parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
     args = parser.parse_args(argv)
     return args
 
 
+def store_args_to_json(args):
+    """
+    Stores the user-defined arg values relevant to the REST API in config.json.
+    Called only when args.start_rest_service is True.
+    """
+    args_dict = {
+        "triton_service_ip": args.triton_http_address,
+        "triton_service_port": args.triton_port,
+        "triton_request_timeout": args.triton_request_timeout,
+        "openai_format_response": args.openai_format_response,
+    }
+    with open("nemo/deploy/service/config.json", "w") as f:
+        json.dump(args_dict, f)
+
+
 def get_trtllm_deployable(args):
     if args.triton_model_repository is None:
         trt_llm_path = "/tmp/trt_llm_model_dir/"
@@ -318,6 +344,13 @@ def nemo_deploy(argv):
     LOGGER.info("Logging level set to {}".format(loglevel))
     LOGGER.info(args)
 
+    if args.start_rest_service:
+        if args.service_port == args.triton_port:
+            logging.error("REST service port and Triton server port cannot use the same port.")
+            return
+        # Store Triton ip, port and other args relevant for the REST API in config.json, to be accessible by rest_model_api.py
+        store_args_to_json(args)
+
     backend = args.backend.lower()
     if backend == 'tensorrt-llm':
         if not trt_llm_supported:
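
Reviewer note (not part of the patch): a minimal sketch of how the new openai_format_response flag changes what query_llm() returns. NemoQueryLLM lives in nemo/deploy/nlp/query_llm.py, the file patched above; the server URL, model name, prompt, and the prompts/max_output_len parameter names are illustrative assumptions taken from the surrounding class, not values fixed by this diff.

from nemo.deploy.nlp import NemoQueryLLM

# Assumes Triton is serving a model named "llama" on localhost:8000; both
# values are hypothetical and depend on how deploy_triton.py was invoked.
nq = NemoQueryLLM(url="localhost:8000", model_name="llama")

out = nq.query_llm(
    prompts=["What is Triton Inference Server?"],
    max_output_len=32,
    openai_format_response=True,  # new keyword introduced by this patch
)

# With the flag set, query_llm() returns the dict built in the patch instead
# of the raw numpy string array, e.g.:
# {"id": "cmpl-<unix-ts>", "object": "text_completion", "created": <unix-ts>,
#  "model": "llama", "choices": [{"text": "..."}]}
print(out["choices"][0]["text"])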
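The REST side can be exercised in the same spirit once deploy_triton.py has been started with --start_rest_service and --openai_format_response. The sketch below assumes the REST service listens on localhost:8080; apart from frequency_penalty, the CompletionRequest field names are educated guesses and should be checked against the full model in rest_model_api.py.

import requests

BASE = "http://localhost:8080"  # assumed --service_http_address / --service_port

# New health probe from this patch: proxies Triton's /v2/health/ready check.
print(requests.get(f"{BASE}/triton_health", timeout=10).json())
# -> {"status": "Triton server is reachable and ready"} when Triton is up;
#    otherwise the endpoint raises HTTP 503 with a detail message.

# Completions request; "llama" must match the model name used at deploy time,
# and every field name except frequency_penalty is an assumption here.
payload = {
    "model": "llama",
    "prompt": "What is Triton Inference Server?",
    "max_tokens": 32,
    "temperature": 1.0,
    "top_p": 0.9,
    "frequency_penalty": 1.0,
}
data = requests.post(f"{BASE}/v1/completions/", json=payload, timeout=60).json()

# With openai_format_response=true in config.json the service forwards the
# OpenAI-style dict unchanged; otherwise it returns {"output": <text>}.
print(data["choices"][0]["text"] if "choices" in data else data["output"])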