Merge branch 'main' into image_utils
Jorge committed Feb 21, 2024 · 2 parents 6955a60 + 2689bd3 · commit 6d65b2e
Showing 12 changed files with 349 additions and 72 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/_release.yml
@@ -165,6 +165,8 @@ jobs:
      - name: Run integration tests
        env:
          GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+          GOOGLE_SEARCH_API_KEY: ${{ secrets.GOOGLE_SEARCH_API_KEY }}
+          GOOGLE_CSE_ID: ${{ secrets.GOOGLE_CSE_ID }}
        run: make integration_tests
        working-directory: ${{ inputs.working-directory }}

8 changes: 2 additions & 6 deletions libs/genai/Makefile
@@ -6,18 +6,14 @@ all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

-test:
-	poetry run pytest $(TEST_FILE)
+integration_test integration_tests: TEST_FILE = tests/integration_tests/

-tests:
+test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)

check_imports: $(shell find langchain_google_genai -name '*.py')
	poetry run python ./scripts/check_imports.py $^

-integration_tests:
-	poetry run pytest tests/integration_tests
-
######################
# LINTING AND FORMATTING
######################
7 changes: 2 additions & 5 deletions libs/vertexai/Makefile
@@ -6,12 +6,9 @@ all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

-integration_tests: TEST_FILE = tests/integration_tests/
+integration_test integration_tests: TEST_FILE = tests/integration_tests/

-test integration_tests:
-	poetry run pytest $(TEST_FILE)
-
-tests:
+test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)

27 changes: 25 additions & 2 deletions libs/vertexai/langchain_google_vertexai/_utils.py
@@ -13,13 +13,13 @@
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
-from vertexai.generative_models._generative_models import (  # type: ignore[import-untyped]
+from vertexai.generative_models import (  # type: ignore[import-untyped]
    Candidate,
+    Image,
)
from vertexai.language_models import (  # type: ignore[import-untyped]
    TextGenerationResponse,
)
-from vertexai.preview.generative_models import Image  # type: ignore[import-untyped]

def create_retry_decorator(
@@ -102,6 +102,7 @@ def get_generation_info(
    is_gemini: bool,
    *,
    stream: bool = False,
+    usage_metadata: Optional[Dict] = None,
) -> Dict[str, Any]:
    if is_gemini:
        # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
@@ -121,11 +122,33 @@
                else None
            ),
        }
+        if usage_metadata:
+            info["usage_metadata"] = usage_metadata
    # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
    else:
        info = dataclasses.asdict(candidate)
        info.pop("text")
        info = {k: v for k, v in info.items() if not k.startswith("_")}
+        if usage_metadata:
+            info_usage_metadata = {}
+            output_usage = usage_metadata.get("tokenMetadata", {}).get(
+                "outputTokenCount", {}
+            )
+            info_usage_metadata["candidates_billable_characters"] = output_usage.get(
+                "totalBillableCharacters"
+            )
+            info_usage_metadata["candidates_token_count"] = output_usage.get(
+                "totalTokens"
+            )
+            input_usage = usage_metadata.get("tokenMetadata", {}).get(
+                "inputTokenCount", {}
+            )
+            info_usage_metadata["prompt_billable_characters"] = input_usage.get(
+                "totalBillableCharacters"
+            )
+            info_usage_metadata["prompt_token_count"] = input_usage.get("totalTokens")
+            info["usage_metadata"] = {k: v for k, v in info_usage_metadata.items() if v}

    if stream:
        # Remove non-streamable types, like bools.
        info.pop("is_blocked")
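For reference, a sketch of the PaLM-style usage metadata the non-Gemini branch above consumes; the dict and the counts are invented, but the key names follow the .get() calls in the code:

# Invented example of the `usage_metadata` argument for a non-Gemini model:
usage_metadata = {
    "tokenMetadata": {
        "inputTokenCount": {"totalTokens": 11, "totalBillableCharacters": 42},
        "outputTokenCount": {"totalTokens": 7, "totalBillableCharacters": 30},
    }
}
# The branch above would fold this into:
# info["usage_metadata"] == {
#     "candidates_billable_characters": 30,
#     "candidates_token_count": 7,
#     "prompt_billable_characters": 42,
#     "prompt_token_count": 11,
# }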
66 changes: 66 additions & 0 deletions libs/vertexai/langchain_google_vertexai/callbacks.py
@@ -0,0 +1,66 @@
+import threading
+from typing import Any, Dict, List
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+
+class VertexAICallbackHandler(BaseCallbackHandler):
+    """Callback Handler that tracks VertexAI info."""
+
+    prompt_tokens: int = 0
+    prompt_characters: int = 0
+    completion_tokens: int = 0
+    completion_characters: int = 0
+    successful_requests: int = 0
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+
+    def __repr__(self) -> str:
+        return (
+            f"\tPrompt tokens: {self.prompt_tokens}\n"
+            f"\tPrompt characters: {self.prompt_characters}\n"
+            f"\tCompletion tokens: {self.completion_tokens}\n"
+            f"\tCompletion characters: {self.completion_characters}\n"
+            f"Successful requests: {self.successful_requests}\n"
+        )
+
+    @property
+    def always_verbose(self) -> bool:
+        """Whether to call verbose callbacks even if verbose is False."""
+        return True
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Runs when LLM starts running."""
+        pass
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Runs on new LLM token. Only available when streaming is enabled."""
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Collects token usage."""
+        completion_tokens, prompt_tokens = 0, 0
+        completion_characters, prompt_characters = 0, 0
+        for generations in response.generations:
+            if len(generations) > 0 and generations[0].generation_info:
+                usage_metadata = generations[0].generation_info.get(
+                    "usage_metadata", {}
+                )
+                completion_tokens += usage_metadata.get("candidates_token_count", 0)
+                prompt_tokens += usage_metadata.get("prompt_token_count", 0)
+                completion_characters += usage_metadata.get(
+                    "candidates_billable_characters", 0
+                )
+                prompt_characters += usage_metadata.get("prompt_billable_characters", 0)
+
+        with self._lock:
+            self.prompt_characters += prompt_characters
+            self.prompt_tokens += prompt_tokens
+            self.completion_characters += completion_characters
+            self.completion_tokens += completion_tokens
+            self.successful_requests += 1
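A minimal usage sketch for the new handler, assuming a PaLM-based model so that billable characters are populated via the _utils.py mapping above; the model name and prompts are illustrative only:

from langchain_google_vertexai import VertexAI
from langchain_google_vertexai.callbacks import VertexAICallbackHandler

handler = VertexAICallbackHandler()
llm = VertexAI(model_name="text-bison")  # illustrative model choice

# Callbacks are passed per invocation; counters accumulate across requests.
llm.invoke("Tell me a joke.", config={"callbacks": [handler]})
llm.invoke("Tell me another.", config={"callbacks": [handler]})
print(handler)  # repr() summarizes tokens, characters, and request count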
57 changes: 41 additions & 16 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -367,10 +367,14 @@ def _generate(
            )
            generations = [
                ChatGeneration(
-                    message=_parse_response_candidate(c),
-                    generation_info=get_generation_info(c, self._is_gemini_model),
+                    message=_parse_response_candidate(candidate),
+                    generation_info=get_generation_info(
+                        candidate,
+                        self._is_gemini_model,
+                        usage_metadata=response.to_dict().get("usage_metadata"),
+                    ),
                )
-                for c in response.candidates
+                for candidate in response.candidates
            ]
        else:
            question = _get_question(messages)
@@ -382,10 +386,14 @@
            response = chat.send_message(question.content, **msg_params)
            generations = [
                ChatGeneration(
-                    message=AIMessage(content=r.text),
-                    generation_info=get_generation_info(r, self._is_gemini_model),
+                    message=AIMessage(content=candidate.text),
+                    generation_info=get_generation_info(
+                        candidate,
+                        self._is_gemini_model,
+                        usage_metadata=response.raw_prediction_response.metadata,
+                    ),
                )
-                for r in response.candidates
+                for candidate in response.candidates
            ]
        return ChatResult(generations=generations)
@@ -440,7 +448,11 @@ def _agenerate(
            generations = [
                ChatGeneration(
                    message=_parse_response_candidate(c),
-                    generation_info=get_generation_info(c, self._is_gemini_model),
+                    generation_info=get_generation_info(
+                        c,
+                        self._is_gemini_model,
+                        usage_metadata=response.to_dict().get("usage_metadata"),
+                    ),
                )
                for c in response.candidates
            ]
@@ -455,7 +467,11 @@
            generations = [
                ChatGeneration(
                    message=AIMessage(content=r.text),
-                    generation_info=get_generation_info(r, self._is_gemini_model),
+                    generation_info=get_generation_info(
+                        r,
+                        self._is_gemini_model,
+                        usage_metadata=response.raw_prediction_response.metadata,
+                    ),
                )
                for r in response.candidates
            ]
@@ -496,7 +512,12 @@ def _stream(
                    message=AIMessageChunk(
                        content=message.content,
                        additional_kwargs=message.additional_kwargs,
-                    )
+                    ),
+                    generation_info=get_generation_info(
+                        response.candidates[0],
+                        self._is_gemini_model,
+                        usage_metadata=response.to_dict().get("usage_metadata"),
+                    ),
                )
        else:
            question = _get_question(messages)
@@ -506,13 +527,17 @@
            params["examples"] = _parse_examples(examples)
            chat = self._start_chat(history, **params)
            responses = chat.send_message_streaming(question.content, **params)
-            for response in responses:
-                if run_manager:
-                    run_manager.on_llm_new_token(response.text)
-                yield ChatGenerationChunk(
-                    message=AIMessageChunk(content=response.text),
-                    generation_info=get_generation_info(response, self._is_gemini_model),
-                )
+            for response in responses:
+                if run_manager:
+                    run_manager.on_llm_new_token(response.text)
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(content=response.text),
+                    generation_info=get_generation_info(
+                        response,
+                        self._is_gemini_model,
+                        usage_metadata=response.raw_prediction_response.metadata,
+                    ),
+                )

    def _start_chat(
        self, history: _ChatHistory, **kwargs: Any
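To observe the effect of these changes, a brief sketch that inspects the usage metadata now attached to each generation (model name illustrative; the exact keys differ between Gemini and PaLM-based models):

from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

chat = ChatVertexAI(model_name="gemini-pro")  # illustrative
result = chat.generate([[HumanMessage(content="Hello!")]])

# generation_info now carries the usage metadata threaded through above.
gen_info = result.generations[0][0].generation_info or {}
print(gen_info.get("usage_metadata"))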
57 changes: 34 additions & 23 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
@@ -1,5 +1,5 @@
import json
-from typing import Dict, List, Type, Union
+from typing import Any, Dict, List, Type, Union
from typing import Any, Dict, List, Type, Union

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
@@ -25,17 +25,7 @@ def _format_pydantic_to_vertex_function(
    return {
        "name": schema["title"],
        "description": schema.get("description", ""),
-        "parameters": {
-            "properties": {
-                k: {
-                    "type": v["type"],
-                    "description": v.get("description"),
-                }
-                for k, v in schema["properties"].items()
-            },
-            "required": schema["required"],
-            "type": schema["type"],
-        },
+        "parameters": _get_parameters_from_schema(schema=schema),
    }


@@ -48,17 +38,7 @@ def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:
        return {
            "name": tool.name or schema["title"],
            "description": tool.description or schema["description"],
-            "parameters": {
-                "properties": {
-                    k: {
-                        "type": v["type"],
-                        "description": v.get("description"),
-                    }
-                    for k, v in schema["properties"].items()
-                },
-                "required": schema["required"],
-                "type": schema["type"],
-            },
+            "parameters": _get_parameters_from_schema(schema=schema),
        }
    else:
        return {
@@ -89,6 +69,37 @@ def _format_tools_to_vertex_tool(
    return [VertexTool(function_declarations=function_declarations)]


+def _get_parameters_from_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
+    """Given a schema, format the parameters key to match VertexAI
+    expected input.
+
+    Args:
+        schema: Dictionary that must have the following keys.
+
+    Returns:
+        Dictionary with the formatted parameters.
+    """
+
+    parameters = {}
+
+    parameters["type"] = schema["type"]
+
+    if "required" in schema:
+        parameters["required"] = schema["required"]
+
+    schema_properties: Dict[str, Any] = schema.get("properties", {})
+
+    parameters["properties"] = {
+        parameter_name: {
+            "type": parameter_dict["type"],
+            "description": parameter_dict.get("description"),
+        }
+        for parameter_name, parameter_dict in schema_properties.items()
+    }
+
+    return parameters
+
+
class PydanticFunctionsOutputParser(BaseOutputParser):
    """Parse an output as a pydantic object.
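For clarity, a small sketch of what _get_parameters_from_schema returns for a typical Pydantic-derived JSON schema; the schema itself is invented:

# Invented example of a Pydantic-style JSON schema:
schema = {
    "title": "GetWeather",
    "description": "Get the current weather for a city.",
    "type": "object",
    "properties": {
        "city": {"type": "string", "description": "Name of the city"},
    },
    "required": ["city"],
}

parameters = _get_parameters_from_schema(schema=schema)
# {"type": "object",
#  "required": ["city"],
#  "properties": {"city": {"type": "string", "description": "Name of the city"}}}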