From b2039bedf3fa6ba8a27109769934ce6142b0f9c0 Mon Sep 17 00:00:00 2001
From: Yingbei Tong
Date: Mon, 15 Jul 2024 21:12:20 +0000
Subject: [PATCH] Streaming (#3)

* inprogress streaming but mostly working

* streaming works. Need to polish and also handle a special case

---------

Co-authored-by: Sanjay Nadhavajhala
---
 vllm/entrypoints/openai/serving_chat.py | 173 +++++++++++++-----------
 1 file changed, 97 insertions(+), 76 deletions(-)

diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 9d07495dc7d54..1f0a0879b6d19 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -34,7 +34,7 @@
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.utils import random_uuid
-from llama_tools import preprocess_input, postprocess_output
+from rubra_tools import preprocess_input, postprocess_output
 
 logger = init_logger(__name__)
 
@@ -211,21 +211,18 @@ async def create_chat_completion(
         try:
             conversation: List[ConversationMessage] = []
             image_futures: List[Awaitable[ImagePixelData]] = []
-            print("==================create chat completion====================")
 
             for msg in request.messages:
                 chat_parsed_result = self._parse_chat_message_content(msg)
 
                 conversation.extend(chat_parsed_result.messages)
                 image_futures.extend(chat_parsed_result.image_futures)
-            
-  
+
             if request.tools:
                 raw_msgs = request.messages
                 tools = [t.model_dump() for t in request.tools]
                 raw_msgs = preprocess_input(msgs=raw_msgs, tools=tools)
                 conversation = raw_msgs
-
             prompt = self.tokenizer.apply_chat_template(
                 conversation=conversation,
                 tokenize=False,
@@ -385,82 +382,106 @@ async def chat_completion_stream_generator(
                     yield f"data: {data}\n\n"
                     first_iteration = False
 
+                is_function_call = False
+                checked_function_call = False
                 for output in res.outputs:
                     i = output.index
 
                     if finish_reason_sent[i]:
                         continue
 
-                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
-                    out_logprobs = output.logprobs[
-                        previous_num_tokens[i]:] if output.logprobs else None
-
-                    if request.logprobs and request.top_logprobs is not None:
-                        assert out_logprobs is not None, (
-                            "Did not output logprobs")
-                        logprobs = self._create_chat_logprobs(
-                            token_ids=delta_token_ids,
-                            top_logprobs=out_logprobs,
-                            num_output_top_logprobs=request.top_logprobs,
-                        )
-                    else:
-                        logprobs = None
-
-                    delta_text = output.text[len(previous_texts[i]):]
-                    previous_texts[i] = output.text
-                    previous_num_tokens[i] = len(output.token_ids)
-
-                    if request.tool_choice and type(
-                            request.tool_choice
-                    ) is ChatCompletionNamedToolChoiceParam:
-                        delta_message = DeltaMessage(tool_calls=[
-                            ToolCall(function=FunctionCall(
-                                name=request.tool_choice.function.name,
-                                arguments=delta_text))
-                        ])
-                    else:
-                        delta_message = DeltaMessage(content=delta_text)
-
-                    if output.finish_reason is None:
-                        # Send token-by-token response for each request.n
-
-                        choice_data = ChatCompletionResponseStreamChoice(
-                            index=i,
-                            delta=delta_message,
-                            logprobs=logprobs,
-                            finish_reason=None)
-                        chunk = ChatCompletionStreamResponse(
-                            id=request_id,
-                            object=chunk_object_type,
-                            created=created_time,
-                            choices=[choice_data],
-                            model=model_name)
-                        if (request.stream_options
-                                and request.stream_options.include_usage):
-                            chunk.usage = None
-                        data = chunk.model_dump_json(exclude_unset=True)
-                        yield f"data: {data}\n\n"
-                    else:
-                        # Send the finish response for each request.n only once
-                        prompt_tokens = len(res.prompt_token_ids)
-                        choice_data = ChatCompletionResponseStreamChoice(
-                            index=i,
-                            delta=delta_message,
-                            logprobs=logprobs,
-                            finish_reason=output.finish_reason,
-                            stop_reason=output.stop_reason)
-                        chunk = ChatCompletionStreamResponse(
-                            id=request_id,
-                            object=chunk_object_type,
-                            created=created_time,
-                            choices=[choice_data],
-                            model=model_name)
-                        if (request.stream_options
-                                and request.stream_options.include_usage):
-                            chunk.usage = None
-                        data = chunk.model_dump_json(exclude_unset=True)
-                        yield f"data: {data}\n\n"
-                        finish_reason_sent[i] = True
+                    if (not checked_function_call and len(output.text) >= 15):
+                        if "starttoolcall" in output.text:
+                            is_function_call = True
+                        checked_function_call = True
+
+                    if (checked_function_call and not is_function_call) or output.finish_reason is not None:
+
+                        delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                        out_logprobs = output.logprobs[
+                            previous_num_tokens[i]:] if output.logprobs else None
+
+                        if request.logprobs and request.top_logprobs is not None:
+                            assert out_logprobs is not None, (
+                                "Did not output logprobs")
+                            logprobs = self._create_chat_logprobs(
+                                token_ids=delta_token_ids,
+                                top_logprobs=out_logprobs,
+                                num_output_top_logprobs=request.top_logprobs,
+                            )
+                        else:
+                            logprobs = None
+
+                        delta_text = output.text[len(previous_texts[i]):]
+                        previous_texts[i] = output.text
+                        previous_num_tokens[i] = len(output.token_ids)
+
+                        if request.tool_choice and type(
+                                request.tool_choice
+                        ) is ChatCompletionNamedToolChoiceParam:
+                            delta_message = DeltaMessage(tool_calls=[
+                                ToolCall(function=FunctionCall(
+                                    name=request.tool_choice.function.name,
+                                    arguments=delta_text))
+                            ])
+                        else:
+                            content = delta_text
+                            function_output = postprocess_output(output_str=content)
+                            tool_calls = []
+                            if function_output:
+                                try:
+                                    for fc in function_output:
+                                        function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"])
+                                        call = ToolCall(function=function)
+                                        tool_calls.append(call)
+                                    content = ""
+                                except Exception as e:
+                                    content = str(function_output)
+                                    logger.error("Error extracting functions from output: %s", e)
+                            delta_message = DeltaMessage(content=content, tool_calls=tool_calls)
+
+                        if output.finish_reason is None:
+                            # Send token-by-token response for each request.n
+
+                            choice_data = ChatCompletionResponseStreamChoice(
+                                index=i,
+                                delta=delta_message,
+                                logprobs=logprobs,
+                                finish_reason=None)
+
+                            chunk = ChatCompletionStreamResponse(
+                                id=request_id,
+                                object=chunk_object_type,
+                                created=created_time,
+                                choices=[choice_data],
+                                model=model_name)
+
+                            if (request.stream_options
+                                    and request.stream_options.include_usage):
+                                chunk.usage = None
+                            data = chunk.model_dump_json(exclude_unset=True)
+                            yield f"data: {data}\n\n"
+                        else:
+                            # Send the finish response for each request.n only once
+                            prompt_tokens = len(res.prompt_token_ids)
+                            choice_data = ChatCompletionResponseStreamChoice(
+                                index=i,
+                                delta=delta_message,
+                                logprobs=logprobs,
+                                finish_reason=output.finish_reason,
+                                stop_reason=output.stop_reason)
+                            chunk = ChatCompletionStreamResponse(
+                                id=request_id,
+                                object=chunk_object_type,
+                                created=created_time,
+                                choices=[choice_data],
+                                model=model_name)
+                            if (request.stream_options
+                                    and request.stream_options.include_usage):
+                                chunk.usage = None
+                            data = chunk.model_dump_json(exclude_unset=False)
+                            yield f"data: {data}\n\n"
+                            finish_reason_sent[i] = True
 
             if (request.stream_options
                     and request.stream_options.include_usage):
@@ -542,7 +563,7 @@ async def chat_completion_full_generator(
             function_output = postprocess_output(output_str=content)
             tool_calls = []
             if function_output:
-                print(f"Parsed function output: {function_output}\n\n")
+
                 try:
                     for fc in function_output:
                         function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"])
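
Note for reviewers: the gating logic this patch adds to chat_completion_stream_generator is easier to follow in isolation. The sketch below reproduces it with a plain list of text chunks standing in for vLLM's RequestOutput stream, and a parse_tool_calls placeholder standing in for rubra_tools.postprocess_output, whose return shape (a list of {"function": {"name": ..., "arguments": ...}} dicts) is inferred from the diff. The names stream_deltas and parse_tool_calls are hypothetical; this is illustrative only, not the shipped implementation.

    from typing import Dict, Iterator, List, Optional


    def parse_tool_calls(text: str) -> Optional[List[dict]]:
        # Stand-in for rubra_tools.postprocess_output: return parsed
        # tool calls when the marker is present, else None.
        if "starttoolcall" in text:
            return [{"function": {"name": "get_weather",
                                  "arguments": '{"city": "SF"}'}}]
        return None


    def stream_deltas(chunks: List[str]) -> Iterator[Dict]:
        text = ""    # accumulated model output (output.text)
        emitted = 0  # chars already streamed (previous_texts[i])
        is_function_call = False
        checked_function_call = False
        for n, chunk in enumerate(chunks):
            text += chunk
            finished = n == len(chunks) - 1
            # Same gate as the patch: once at least 15 characters have
            # arrived, decide whether the model is emitting a tool call.
            if not checked_function_call and len(text) >= 15:
                if "starttoolcall" in text:
                    is_function_call = True
                checked_function_call = True
            # Plain content streams through, including anything buffered
            # while undecided; a tool call is held back until the final
            # chunk and then parsed in one shot.
            if (checked_function_call and not is_function_call) or finished:
                if is_function_call:
                    yield {"tool_calls": parse_tool_calls(text)}
                else:
                    yield {"content": text[emitted:]}
                    emitted = len(text)


    print(list(stream_deltas(["Hello there", ", how can", " I help?"])))
    print(list(stream_deltas(["starttoolcall", '{"city": "SF"}', "end"])))

The trade-off mirrored here is that tool-call outputs lose token-by-token streaming: the generator stays silent until finish_reason is set, then emits the parsed calls in a single delta, while ordinary content still streams incrementally once the first ~15 characters rule out a tool call.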