Skip to content

Commit

Permalink
Streaming (#3)
Browse files Browse the repository at this point in the history
* inprogress streaming but mostly working

* streaming works. Need to polish and also handle a special case

---------

Co-authored-by: Sanjay Nadhavajhala <[email protected]>
  • Loading branch information
tybalex and sanjay920 authored Jul 15, 2024
1 parent e80374c commit b2039be
Showing 1 changed file with 97 additions and 76 deletions.
173 changes: 97 additions & 76 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.utils import random_uuid
from llama_tools import preprocess_input, postprocess_output
from rubra_tools import preprocess_input, postprocess_output

logger = init_logger(__name__)

Expand Down Expand Up @@ -211,21 +211,18 @@ async def create_chat_completion(
try:
conversation: List[ConversationMessage] = []
image_futures: List[Awaitable[ImagePixelData]] = []
print("==================create chat completion====================")

for msg in request.messages:
chat_parsed_result = self._parse_chat_message_content(msg)

conversation.extend(chat_parsed_result.messages)
image_futures.extend(chat_parsed_result.image_futures)



if request.tools:
raw_msgs = request.messages
tools = [t.model_dump() for t in request.tools]
raw_msgs = preprocess_input(msgs=raw_msgs, tools=tools)
conversation = raw_msgs

prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
Expand Down Expand Up @@ -385,82 +382,106 @@ async def chat_completion_stream_generator(
yield f"data: {data}\n\n"
first_iteration = False

is_function_call = False
checked_function_call = False
for output in res.outputs:
i = output.index

if finish_reason_sent[i]:
continue

delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)

if request.tool_choice and type(
request.tool_choice
) is ChatCompletionNamedToolChoiceParam:
delta_message = DeltaMessage(tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text))
])
else:
delta_message = DeltaMessage(content=delta_text)

if output.finish_reason is None:
# Send token-by-token response for each request.n

choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
if (not checked_function_call and len(output.text)>= 15):
if "starttoolcall" in output.text:
is_function_call = True
checked_function_call = True

if (checked_function_call and not is_function_call) or output.finish_reason is not None:

delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)

if request.tool_choice and type(
request.tool_choice
) is ChatCompletionNamedToolChoiceParam:
delta_message = DeltaMessage(tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text))
])
else:
content = delta_text
function_output = postprocess_output(output_str=content)
tool_calls = []
if function_output:
try:
for fc in function_output:
function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"])
call = ToolCall(function=function)
tool_calls.append(call)
content = ""
except Exception as e:
content = str(function_output)
print(f"Error extract functions from output: {e}")
delta_message = DeltaMessage(content=content, tool_calls=tool_calls)

if output.finish_reason is None:
# Send token-by-token response for each request.n

choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=None)

chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)

if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=False)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True

if (request.stream_options
and request.stream_options.include_usage):
Expand Down Expand Up @@ -542,7 +563,7 @@ async def chat_completion_full_generator(
function_output = postprocess_output(output_str=content)
tool_calls = []
if function_output:
print(f"Parsed function output: {function_output}\n\n")

try:
for fc in function_output:
function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"])
Expand Down

0 comments on commit b2039be

Please sign in to comment.