Support token usage in streaming (#30)
* Support token usage in streaming

Signed-off-by: B-Step62 <[email protected]>

* docstring

Signed-off-by: B-Step62 <[email protected]>

* comment

Signed-off-by: B-Step62 <[email protected]>

---------

Signed-off-by: B-Step62 <[email protected]>
B-Step62 authored Oct 21, 2024
1 parent 91e9810 commit 43404a1
Showing 3 changed files with 119 additions and 2 deletions.
52 changes: 50 additions & 2 deletions libs/databricks/langchain_databricks/chat_models.py
@@ -35,6 +35,7 @@
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
@@ -157,6 +158,30 @@ class ChatDatabricks(BaseChatModel):
id='run-4cef851f-6223-424f-ad26-4a54e5852aa5'
)

To get token usage returned when streaming, pass the ``stream_usage`` kwarg:

.. code-block:: python

    stream = llm.stream(messages, stream_usage=True)
    next(stream).usage_metadata

.. code-block:: python

    {"input_tokens": 28, "output_tokens": 5, "total_tokens": 33}

Alternatively, setting ``stream_usage`` when instantiating the model can be
useful when incorporating ``ChatDatabricks`` into LCEL chains, or when using
methods like ``.with_structured_output``, which generate chains under the
hood.

.. code-block:: python

    llm = ChatDatabricks(
        endpoint="databricks-meta-llama-3-1-405b-instruct",
        stream_usage=True
    )
    structured_llm = llm.with_structured_output(...)

Async:
    .. code-block:: python
@@ -229,6 +254,10 @@ class GetPopulation(BaseModel):
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
extra_params: Optional[Dict[str, Any]] = None
"""Any extra parameters to pass to the endpoint."""
stream_usage: bool = False
"""Whether to include usage metadata in streaming output. If True, additional
message chunks will be generated during the stream including usage metadata.
"""
client: Optional[BaseDeploymentClient] = Field(
default=None, exclude=True
@@ -301,8 +330,12 @@ def _stream(
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
*,
stream_usage: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
if stream_usage is None:
stream_usage = self.stream_usage
data = self._prepare_inputs(messages, stop, **kwargs)
first_chunk_role = None
for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data): # type: ignore
@@ -313,8 +346,19 @@
if first_chunk_role is None:
first_chunk_role = chunk_delta.get("role")

if stream_usage and (usage := chunk.get("usage")):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
usage = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
else:
usage = None

chunk_message = _convert_dict_to_message_chunk(
chunk_delta, first_chunk_role
chunk_delta, first_chunk_role, usage=usage
)

generation_info = {}
@@ -759,7 +803,9 @@ def _convert_dict_to_message(_dict: Dict) -> BaseMessage:


def _convert_dict_to_message_chunk(
_dict: Mapping[str, Any], default_role: str
_dict: Mapping[str, Any],
default_role: str,
usage: Optional[Dict[str, Any]] = None,
) -> BaseMessageChunk:
role = _dict.get("role", default_role)
content = _dict.get("content")
Expand Down Expand Up @@ -790,11 +836,13 @@ def _convert_dict_to_message_chunk(
]
except KeyError:
pass
usage_metadata = UsageMetadata(**usage) if usage else None # type: ignore
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=_dict.get("id"),
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata,
)
else:
return ChatMessageChunk(content=content, role=role)
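
Taken together, the chat_models.py change reads the OpenAI-style ``usage`` block from a streamed chunk and attaches it to the emitted AIMessageChunk as UsageMetadata. For reference, a minimal standalone sketch of that mapping, under the assumption of an OpenAI-compatible streaming payload (illustrative only, not code from this commit):

    from typing import Any, Dict, Optional

    from langchain_core.messages.ai import UsageMetadata


    def usage_from_stream_chunk(chunk: Dict[str, Any]) -> Optional[UsageMetadata]:
        # Map an OpenAI-style "usage" payload onto LangChain's UsageMetadata,
        # mirroring the logic added to ChatDatabricks._stream above.
        usage = chunk.get("usage")
        if not usage:
            return None
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        return UsageMetadata(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        )


    # Assumed shape of a final chunk from an OpenAI-compatible streaming endpoint.
    final_chunk = {
        "choices": [{"delta": {}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 28, "completion_tokens": 5, "total_tokens": 33},
    }
    print(usage_from_stream_chunk(final_chunk))
    # {'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
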
32 changes: 32 additions & 0 deletions libs/databricks/tests/integration_tests/test_chat_models.py
@@ -116,6 +116,38 @@ def on_llm_new_token(self, *args, **kwargs):
assert last_chunk.response_metadata["finish_reason"] == "stop"


def test_chat_databricks_stream_with_usage():
class FakeCallbackHandler(BaseCallbackHandler):
def __init__(self):
self.chunk_counts = 0

def on_llm_new_token(self, *args, **kwargs):
self.chunk_counts += 1

callback = FakeCallbackHandler()

chat = ChatDatabricks(
endpoint=_TEST_ENDPOINT,
temperature=0,
stop=["Python"],
max_tokens=100,
stream_usage=True,
)

chunks = list(chat.stream("How to learn Python?", config={"callbacks": [callback]}))
assert len(chunks) > 0
assert all(isinstance(chunk, AIMessageChunk) for chunk in chunks)
assert all("Python" not in chunk.content for chunk in chunks)
assert callback.chunk_counts == len(chunks)

last_chunk = chunks[-1]
assert last_chunk.response_metadata["finish_reason"] == "stop"
assert last_chunk.usage_metadata is not None
assert last_chunk.usage_metadata["input_tokens"] > 0
assert last_chunk.usage_metadata["output_tokens"] > 0
assert last_chunk.usage_metadata["total_tokens"] > 0


@pytest.mark.asyncio
async def test_chat_databricks_ainvoke():
chat = ChatDatabricks(
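
Building on test_chat_databricks_stream_with_usage above, a consumer-side illustration (not part of this commit): langchain-core merges usage_metadata when AIMessageChunk objects are added together, so summing the streamed chunks recovers the totals. A sketch, assuming the package's top-level ChatDatabricks export and a placeholder endpoint name:

    from langchain_databricks import ChatDatabricks

    llm = ChatDatabricks(
        endpoint="databricks-meta-llama-3-1-405b-instruct",  # placeholder endpoint
        stream_usage=True,
    )

    aggregate = None
    for chunk in llm.stream("How do I learn Python?"):
        # AIMessageChunk.__add__ concatenates content and merges usage_metadata.
        aggregate = chunk if aggregate is None else aggregate + chunk

    print(aggregate.usage_metadata)
    # e.g. {'input_tokens': 28, 'output_tokens': 120, 'total_tokens': 148}
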
37 changes: 37 additions & 0 deletions libs/databricks/tests/unit_tests/test_chat_models.py
@@ -188,6 +188,43 @@ def test_chat_model_stream(llm: ChatDatabricks) -> None:
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]


def test_chat_model_stream_with_usage(llm: ChatDatabricks) -> None:
def _assert_usage(chunk, expected):
usage = chunk.usage_metadata
assert usage is not None
assert usage["input_tokens"] == expected["usage"]["prompt_tokens"]
assert usage["output_tokens"] == expected["usage"]["completion_tokens"]
assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]

# Method 1: Pass stream_usage=True to the stream() call
res = llm.stream(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "36939 * 8922.4"},
],
stream_usage=True,
)
for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]
_assert_usage(chunk, expected)

# Method 2: Pass stream_usage=True to the constructor
llm_with_usage = ChatDatabricks(
endpoint="databricks-meta-llama-3-70b-instruct",
target_uri="databricks",
stream_usage=True,
)
res = llm_with_usage.stream(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "36939 * 8922.4"},
],
)
for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]
_assert_usage(chunk, expected)


class GetWeather(BaseModel):
"""Get the current weather in a given location"""

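
For orientation: test_chat_model_stream_with_usage above compares each chunk against a module-level _MOCK_STREAM_RESPONSE fixture that is defined elsewhere in the test file and not shown in this diff. A hypothetical chunk in the shape those assertions rely on could look like the following (field names and values are illustrative assumptions, not the real fixture):

    # Hypothetical streamed chunk; the actual _MOCK_STREAM_RESPONSE may differ.
    example_chunk = {
        "id": "chatcmpl-00000000",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "The answer is"},
                "finish_reason": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
    }
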
