Add OpenAI format response to r2.0.0rc1 (#9796)
* Add OpenAI format response to r2.0.0rc1

Signed-off-by: Abhishree <[email protected]>

* Apply isort and black reformatting

Signed-off-by: athitten <[email protected]>

---------

Signed-off-by: Abhishree <[email protected]>
Signed-off-by: athitten <[email protected]>
Co-authored-by: athitten <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
3 people authored Jul 22, 2024
1 parent e6941de commit a52ee73
Showing 4 changed files with 81 additions and 10 deletions.
14 changes: 13 additions & 1 deletion nemo/deploy/nlp/query_llm.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 from abc import ABC, abstractmethod
 
 import numpy as np
@@ -172,6 +173,7 @@ def query_llm(
         compute_logprob: bool = None,
         end_strings=None,
         init_timeout=60.0,
+        openai_format_response: bool = False,
     ):
         """
         Query the Triton server synchronously and return a list of responses.
@@ -259,7 +261,17 @@ def query_llm(
                     return "Unknown output keyword."
 
                 sentences = np.char.decode(output.astype("bytes"), "utf-8")
-                return sentences
+                if openai_format_response:
+                    openai_response = {
+                        "id": f"cmpl-{int(time.time())}",
+                        "object": "text_completion",
+                        "created": int(time.time()),
+                        "model": self.model_name,
+                        "choices": [{"text": str(sentences)}],
+                    }
+                    return openai_response
+                else:
+                    return sentences
             else:
                 return result_dict["outputs"]

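For context, a minimal sketch of what a caller sees with the new flag, assuming NemoQueryLLM is the client class exposed by nemo/deploy/nlp (the URL, model name, and prompt below are placeholders, and any query_llm arguments not shown in this diff are assumptions):

    from nemo.deploy.nlp import NemoQueryLLM

    # Placeholder URL and model name; assumes a Triton server is already serving this model.
    nq = NemoQueryLLM(url="localhost:8000", model_name="GPT-2B")
    output = nq.query_llm(
        prompts=["What is the color of a banana?"],
        max_output_len=10,
        openai_format_response=True,  # the flag added by this commit
    )
    # With the flag set, query_llm returns an OpenAI-style dict instead of a
    # bare array of strings, e.g.:
    # {"id": "cmpl-1721633400", "object": "text_completion", "created": 1721633400,
    #  "model": "GPT-2B", "choices": [{"text": "..."}]}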
5 changes: 0 additions & 5 deletions nemo/deploy/service/config.json

This file was deleted.

39 changes: 35 additions & 4 deletions nemo/deploy/service/rest_model_api.py
@@ -12,8 +12,9 @@
 import json
 import os
 from pathlib import Path
+import requests
 
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from pydantic_settings import BaseSettings
@@ -33,6 +34,7 @@ def __init__(self):
             self._triton_service_port = config_json["triton_service_port"]
             self._triton_service_ip = config_json["triton_service_ip"]
             self._triton_request_timeout = config_json["triton_request_timeout"]
+            self._openai_format_response = config_json["openai_format_response"]
         except Exception as error:
             print("An exception occurred:", error)
             return
@@ -49,6 +51,14 @@ def triton_service_ip(self):
     def triton_request_timeout(self):
         return self._triton_request_timeout
 
+    @property
+    def openai_format_response(self):
+        """
+        Returns the response from the Triton server in an OpenAI-compatible format if set to True;
+        the default set in config.json is false.
+        """
+        return self._openai_format_response
+
 
 app = FastAPI()
 triton_settings = TritonSettings()
@@ -66,6 +76,23 @@ class CompletionRequest(BaseModel):
     frequency_penalty: float = 1.0
 
 
+@app.get("/triton_health")
+async def check_triton_health():
+    """
+    Exposes the "/triton_health" endpoint, which verifies that the Triton server is accessible while the REST/FastAPI application is running.
+    Check it with: curl http://service_http_address:service_port/triton_health; the returned status indicates whether the server is reachable.
+    """
+    triton_url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}/v2/health/ready"
+    try:
+        response = requests.get(triton_url, timeout=5)
+        if response.status_code == 200:
+            return {"status": "Triton server is reachable and ready"}
+        else:
+            raise HTTPException(status_code=503, detail="Triton server is not ready")
+    except requests.RequestException as e:
+        raise HTTPException(status_code=503, detail=f"Cannot reach Triton server: {str(e)}")
+
+
 @app.post("/v1/completions/")
 def completions_v1(request: CompletionRequest):
     try:
@@ -78,10 +105,14 @@ def completions_v1(request: CompletionRequest):
             top_p=request.top_p,
             temperature=request.temperature,
             init_timeout=triton_settings.triton_request_timeout,
+            openai_format_response=triton_settings.openai_format_response,
         )
-        return {
-            "output": output[0][0],
-        }
+        if triton_settings.openai_format_response:
+            return output
+        else:
+            return {
+                "output": output[0][0],
+            }
     except Exception as error:
         print("An exception occurred:", error)
         return {"error": "An exception occurred"}
33 changes: 33 additions & 0 deletions scripts/deploy/nlp/deploy_triton.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import argparse
+import json
 import logging
 import os
 import sys
@@ -73,6 +74,9 @@ def get_args(argv):
     parser.add_argument(
         "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server"
     )
+    parser.add_argument(
+        "-trt", "--triton_request_timeout", default=60, type=int, help="Timeout in seconds for Triton server"
+    )
     parser.add_argument(
         "-tmr", "--triton_model_repository", default=None, type=str, help="Folder for the trt-llm conversion"
     )
@@ -183,11 +187,33 @@ def get_args(argv):
         "-sha", "--service_http_address", default="0.0.0.0", type=str, help="HTTP address for the REST Service"
     )
     parser.add_argument("-sp", "--service_port", default=8080, type=int, help="Port for the REST Service")
+    parser.add_argument(
+        "-ofr",
+        "--openai_format_response",
+        default=False,
+        action='store_true',
+        help="Return the response from PyTriton server in OpenAI compatible format",
+    )
     parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
     args = parser.parse_args(argv)
     return args
 
 
+def store_args_to_json(args):
+    """
+    Stores user-defined arg values relevant to the REST API in config.json.
+    Gets called only when args.start_rest_service is True.
+    """
+    args_dict = {
+        "triton_service_ip": args.triton_http_address,
+        "triton_service_port": args.triton_port,
+        "triton_request_timeout": args.triton_request_timeout,
+        "openai_format_response": args.openai_format_response,
+    }
+    with open("nemo/deploy/service/config.json", "w") as f:
+        json.dump(args_dict, f)
+
+
 def get_trtllm_deployable(args):
     if args.triton_model_repository is None:
         trt_llm_path = "/tmp/trt_llm_model_dir/"
@@ -318,6 +344,13 @@ def nemo_deploy(argv):
     LOGGER.info("Logging level set to {}".format(loglevel))
     LOGGER.info(args)
 
+    if args.start_rest_service:
+        if args.service_port == args.triton_port:
+            logging.error("REST service port and Triton server port cannot be the same.")
+            return
+        # Store the Triton ip, port, and other args relevant to the REST API in config.json so rest_model_api.py can read them
+        store_args_to_json(args)
+
     backend = args.backend.lower()
     if backend == 'tensorrt-llm':
         if not trt_llm_supported:
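To make the config round trip concrete, a sketch of the config.json that store_args_to_json writes and that TritonSettings in rest_model_api.py reads back at startup (the triton_service_port value is a placeholder, since the --triton_port default is folded out of this diff):

    import json

    # Hypothetical contents after a deploy_triton.py run with the REST service
    # enabled and --openai_format_response passed.
    args_dict = {
        "triton_service_ip": "0.0.0.0",  # --triton_http_address default
        "triton_service_port": 8000,     # placeholder --triton_port value
        "triton_request_timeout": 60,    # --triton_request_timeout default
        "openai_format_response": True,
    }
    with open("nemo/deploy/service/config.json", "w") as f:
        json.dump(args_dict, f)

    # TritonSettings.__init__ loads exactly these four keys.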
