added cost callback #18

Merged
merged 2 commits on Feb 21, 2024

27 changes: 25 additions & 2 deletions libs/vertexai/langchain_google_vertexai/_utils.py
@@ -13,13 +13,13 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped]
from vertexai.generative_models import ( # type: ignore[import-untyped]
Candidate,
Image,
)
from vertexai.language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import Image # type: ignore[import-untyped]


def create_retry_decorator(
@@ -102,6 +102,7 @@ def get_generation_info(
is_gemini: bool,
*,
stream: bool = False,
usage_metadata: Optional[Dict] = None,
) -> Dict[str, Any]:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
@@ -121,11 +122,33 @@
else None
),
}
if usage_metadata:
info["usage_metadata"] = usage_metadata
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
else:
info = dataclasses.asdict(candidate)
info.pop("text")
info = {k: v for k, v in info.items() if not k.startswith("_")}
if usage_metadata:
info_usage_metadata = {}
output_usage = usage_metadata.get("tokenMetadata", {}).get(
"outputTokenCount", {}
)
info_usage_metadata["candidates_billable_characters"] = output_usage.get(
"totalBillableCharacters"
)
info_usage_metadata["candidates_token_count"] = output_usage.get(
"totalTokens"
)
input_usage = usage_metadata.get("tokenMetadata", {}).get(
"inputTokenCount", {}
)
info_usage_metadata["prompt_billable_characters"] = input_usage.get(
"totalBillableCharacters"
)
info_usage_metadata["prompt_token_count"] = input_usage.get("totalTokens")
info["usage_metadata"] = {k: v for k, v in info_usage_metadata.items() if v}

if stream:
# Remove non-streamable types, like bools.
info.pop("is_blocked")
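For the non-Gemini (PaLM text/chat) branch above, the raw usage metadata is nested under tokenMetadata, and get_generation_info flattens it into the flat keys the new callback reads. Below is a minimal sketch of that mapping; the payload shape is an assumption inferred only from the keys the diff reads, and the numbers are made up.

# Sketch: flattening a PaLM-style usage payload the way the diff above does.
# The nested shape and the numbers here are illustrative assumptions.
raw_metadata = {
    "tokenMetadata": {
        "inputTokenCount": {"totalTokens": 7, "totalBillableCharacters": 23},
        "outputTokenCount": {"totalTokens": 42, "totalBillableCharacters": 168},
    }
}

output_usage = raw_metadata.get("tokenMetadata", {}).get("outputTokenCount", {})
input_usage = raw_metadata.get("tokenMetadata", {}).get("inputTokenCount", {})
info_usage_metadata = {
    "candidates_billable_characters": output_usage.get("totalBillableCharacters"),
    "candidates_token_count": output_usage.get("totalTokens"),
    "prompt_billable_characters": input_usage.get("totalBillableCharacters"),
    "prompt_token_count": input_usage.get("totalTokens"),
}
# Falsy entries are dropped, matching the dict comprehension in the diff.
print({k: v for k, v in info_usage_metadata.items() if v})
# {'candidates_billable_characters': 168, 'candidates_token_count': 42,
#  'prompt_billable_characters': 23, 'prompt_token_count': 7}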
66 changes: 66 additions & 0 deletions libs/vertexai/langchain_google_vertexai/callbacks.py
@@ -0,0 +1,66 @@
import threading
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class VertexAICallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks VertexAI info."""

prompt_tokens: int = 0
prompt_characters: int = 0
completion_tokens: int = 0
completion_characters: int = 0
successful_requests: int = 0

def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()

def __repr__(self) -> str:
return (
f"\tPrompt tokens: {self.prompt_tokens}\n"
f"\tPrompt characters: {self.prompt_characters}\n"
f"\tCompletion tokens: {self.completion_tokens}\n"
f"\tCompletion characters: {self.completion_characters}\n"
f"Successful requests: {self.successful_requests}\n"
)

@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Runs when LLM starts running."""
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Runs on new LLM token. Only available when streaming is enabled."""
pass

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collects token usage."""
completion_tokens, prompt_tokens = 0, 0
completion_characters, prompt_characters = 0, 0
for generations in response.generations:
if len(generations) > 0 and generations[0].generation_info:
usage_metadata = generations[0].generation_info.get(
"usage_metadata", {}
)
completion_tokens += usage_metadata.get("candidates_token_count", 0)
prompt_tokens += usage_metadata.get("prompt_token_count", 0)
completion_characters += usage_metadata.get(
"candidates_billable_characters", 0
)
prompt_characters += usage_metadata.get("prompt_billable_characters", 0)

with self._lock:
self.prompt_characters += prompt_characters
self.prompt_tokens += prompt_tokens
self.completion_characters += completion_characters
self.completion_tokens += completion_tokens
self.successful_requests += 1
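A minimal usage sketch of the new handler, mirroring the integration tests added later in this PR; it assumes Vertex AI credentials are already configured and that the gemini-pro model is available.

# Sketch: attach the cost callback to a model and read the counters it tracks.
from langchain_google_vertexai.callbacks import VertexAICallbackHandler
from langchain_google_vertexai.llms import VertexAI

handler = VertexAICallbackHandler()
llm = VertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[handler])
llm.invoke("2+2")

# Counters accumulate across every call the handler observes.
print(handler)  # repr lists prompt/completion tokens, characters, requests
print(handler.prompt_tokens, handler.completion_tokens)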
57 changes: 41 additions & 16 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -397,10 +397,14 @@ def _generate(
)
generations = [
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
message=_parse_response_candidate(candidate),
generation_info=get_generation_info(
candidate,
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
for c in response.candidates
for candidate in response.candidates
]
@jzaldi (Contributor) commented on Feb 20, 2024:
ultra nit: c, r -> candidate to make it clearer? (In general)

Collaborator (Author) replied:
done

else:
question = _get_question(messages)
@@ -412,10 +416,14 @@
response = chat.send_message(question.content, **msg_params)
generations = [
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
message=AIMessage(content=candidate.text),
generation_info=get_generation_info(
candidate,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)
for r in response.candidates
for candidate in response.candidates
]
return ChatResult(generations=generations)

@@ -470,7 +478,11 @@ async def _agenerate(
generations = [
ChatGeneration(
message=_parse_response_candidate(c),
generation_info=get_generation_info(c, self._is_gemini_model),
generation_info=get_generation_info(
c,
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
for c in response.candidates
]
@@ -485,7 +497,11 @@
generations = [
ChatGeneration(
message=AIMessage(content=r.text),
generation_info=get_generation_info(r, self._is_gemini_model),
generation_info=get_generation_info(
r,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)
for r in response.candidates
]
@@ -526,7 +542,12 @@ def _stream(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
)
),
generation_info=get_generation_info(
response.candidates[0],
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
else:
question = _get_question(messages)
@@ -536,13 +557,17 @@
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(response, self._is_gemini_model),
)
for response in responses:
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(
message=AIMessageChunk(content=response.text),
generation_info=get_generation_info(
response,
self._is_gemini_model,
usage_metadata=response.raw_prediction_response.metadata,
),
)

def _start_chat(
self, history: _ChatHistory, **kwargs: Any
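With this change each ChatGeneration carries the request's usage metadata in generation_info, which is exactly what VertexAICallbackHandler.on_llm_end reads. Below is a minimal sketch of inspecting it directly, assuming credentials are configured; the key names come from get_generation_info above.

# Sketch: usage metadata surfaced on a chat result's generation_info.
from langchain_core.messages import HumanMessage
from langchain_google_vertexai.chat_models import ChatVertexAI

chat = ChatVertexAI(model_name="gemini-pro", temperature=0.0)
result = chat.generate([[HumanMessage(content="Hello")]])
info = result.generations[0][0].generation_info or {}
# For Gemini models the values originate from response.to_dict()["usage_metadata"].
print(info.get("usage_metadata"))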
45 changes: 33 additions & 12 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -20,6 +20,7 @@
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from vertexai.generative_models import ( # type: ignore[import-untyped]
Candidate,
GenerativeModel,
Image,
)
@@ -327,12 +328,19 @@ def validate_environment(cls, values: Dict) -> Dict:
raise ValueError("Only one candidate can be generated with streaming!")
return values

def _response_to_generation(
self, response: TextGenerationResponse, *, stream: bool = False
def _candidate_to_generation(
self,
response: Union[Candidate, TextGenerationResponse],
*,
stream: bool = False,
usage_metadata: Optional[Dict] = None,
) -> GenerationChunk:
"""Converts a stream response to a generation chunk."""
generation_info = get_generation_info(
response, self._is_gemini_model, stream=stream
response,
self._is_gemini_model,
stream=stream,
usage_metadata=usage_metadata,
)
try:
text = response.text
@@ -373,8 +381,15 @@ def _generate(
run_manager=run_manager,
**params,
)
if self._is_gemini_model:
usage_metadata = res.to_dict().get("usage_metadata")
else:
usage_metadata = res.raw_prediction_response.metadata
generations.append(
[self._response_to_generation(r) for r in res.candidates]
[
self._candidate_to_generation(r, usage_metadata=usage_metadata)
for r in res.candidates
]
)
return LLMResult(generations=generations)

@@ -395,8 +410,15 @@ async def _agenerate(
run_manager=run_manager,
**params,
)
if self._is_gemini_model:
usage_metadata = res.to_dict().get("usage_metadata")
else:
usage_metadata = res.raw_prediction_response.metadata
generations.append(
[self._response_to_generation(r) for r in res.candidates]
[
self._candidate_to_generation(r, usage_metadata=usage_metadata)
for r in res.candidates
]
)
return LLMResult(generations=generations)

@@ -416,14 +438,13 @@ def _stream(
run_manager=run_manager,
**params,
):
# Gemini models return GenerationResponse even when streaming, which has a
# candidates field.
stream_resp = (
stream_resp
if isinstance(stream_resp, TextGenerationResponse)
else stream_resp.candidates[0]
usage_metadata = None
if self._is_gemini_model:
usage_metadata = stream_resp.to_dict().get("usage_metadata")
stream_resp = stream_resp.candidates[0]
chunk = self._candidate_to_generation(
stream_resp, stream=True, usage_metadata=usage_metadata
)
chunk = self._response_to_generation(stream_resp, stream=True)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
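The streaming path now attaches usage metadata per chunk for Gemini models (non-Gemini streaming leaves it as None). A short sketch with the cost callback attached, following the test_llm_stream test added below; it assumes Vertex AI credentials are configured.

# Sketch: streaming with the cost callback; counters are updated in on_llm_end.
from langchain_google_vertexai.callbacks import VertexAICallbackHandler
from langchain_google_vertexai.llms import VertexAI

handler = VertexAICallbackHandler()
llm = VertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[handler])
for chunk in llm.stream("2+2"):
    # Each Gemini chunk's generation_info carries the usage_metadata
    # extracted from stream_resp.to_dict() in _stream above.
    pass

print(handler.prompt_tokens, handler.completion_tokens, handler.successful_requests)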
84 changes: 84 additions & 0 deletions libs/vertexai/tests/integration_tests/test_callbacks.py
@@ -0,0 +1,84 @@
import pytest
from langchain_core.messages import HumanMessage

from langchain_google_vertexai.callbacks import VertexAICallbackHandler
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.llms import VertexAI


@pytest.mark.parametrize(
"model_name",
["gemini-pro", "text-bison@001", "code-bison@001"],
)
def test_llm_invoke(model_name: str) -> None:
vb = VertexAICallbackHandler()
llm = VertexAI(model_name=model_name, temperature=0.0, callbacks=[vb])
_ = llm.invoke("2+2")
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0
prompt_tokens = vb.prompt_tokens
completion_tokens = vb.completion_tokens
_ = llm.invoke("2+2")
assert vb.successful_requests == 2
assert vb.prompt_tokens > prompt_tokens
assert vb.completion_tokens > completion_tokens


@pytest.mark.parametrize(
"model_name",
["gemini-pro", "chat-bison@001", "codechat-bison@001"],
)
def test_chat_call(model_name: str) -> None:
vb = VertexAICallbackHandler()
llm = ChatVertexAI(model_name=model_name, temperature=0.0, callbacks=[vb])
message = HumanMessage(content="Hello")
_ = llm([message])
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0
prompt_tokens = vb.prompt_tokens
completion_tokens = vb.completion_tokens
_ = llm([message])
assert vb.successful_requests == 2
assert vb.prompt_tokens > prompt_tokens
assert vb.completion_tokens > completion_tokens


@pytest.mark.parametrize(
"model_name",
["gemini-pro", "text-bison@001", "code-bison@001"],
)
def test_invoke_config(model_name: str) -> None:
vb = VertexAICallbackHandler()
llm = VertexAI(model_name=model_name, temperature=0.0)
llm.invoke("2+2", config={"callbacks": [vb]})
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0
prompt_tokens = vb.prompt_tokens
completion_tokens = vb.completion_tokens
llm.invoke("2+2", config={"callbacks": [vb]})
assert vb.successful_requests == 2
assert vb.prompt_tokens > prompt_tokens
assert vb.completion_tokens > completion_tokens


def test_llm_stream() -> None:
vb = VertexAICallbackHandler()
llm = VertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb])
for _ in llm.stream("2+2"):
pass
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0


def test_chat_stream() -> None:
vb = VertexAICallbackHandler()
llm = ChatVertexAI(model_name="gemini-pro", temperature=0.0, callbacks=[vb])
for _ in llm.stream("2+2"):
pass
assert vb.successful_requests == 1
assert vb.prompt_tokens > 0
assert vb.completion_tokens > 0