
Commit

Merge branch 'langchain-ai:main' into main
lspataroG authored Jun 28, 2024
2 parents bf23101 + e7b4572 · commit 9a036a8
Showing 8 changed files with 113 additions and 34 deletions.
4 changes: 2 additions & 2 deletions libs/vertexai/Makefile
@@ -35,9 +35,9 @@ lint_tests: MYPY_CACHE=.mypy_cache_test
 lint lint_diff lint_package lint_tests:
     ./scripts/check_pydantic.sh .
     ./scripts/lint_imports.sh
-    poetry run ruff .
+    poetry run ruff check .
     poetry run ruff format $(PYTHON_FILES) --diff
-    poetry run ruff --select I $(PYTHON_FILES)
+    poetry run ruff check --select I $(PYTHON_FILES)
     mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
 
 format format_diff:
25 changes: 10 additions & 15 deletions libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py
@@ -1,6 +1,6 @@
 from typing import Any, List, Optional, Type
 
-from langchain_core.messages import ToolCall
+from langchain_core.messages import AIMessage, ToolCall
 from langchain_core.output_parsers import BaseGenerationOutputParser
 from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.pydantic_v1 import BaseModel
@@ -26,25 +26,20 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
         """
         if not result or not isinstance(result[0], ChatGeneration):
             return None if self.first_tool_only else []
 
         message = result[0].message
-        if isinstance(message.content, str):
-            tool_calls: List = []
-        else:
-            content: List = message.content
-            _tool_calls = [dict(tc) for tc in _extract_tool_calls(content)]
-            # Map tool call id to index
-            id_to_index = {
-                block["id"]: i
-                for i, block in enumerate(content)
-                if block["type"] == "tool_use"
-            }
-            tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
+        tool_calls: List[Any] = []
+
+        if isinstance(message, AIMessage) and message.tool_calls:
+            tool_calls = message.tool_calls
+        elif isinstance(message.content, list):
+            content: Any = message.content
+            tool_calls = _extract_tool_calls(content)
 
         if self.pydantic_schemas:
             tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
         elif self.args_only:
             tool_calls = [tc["args"] for tc in tool_calls]
         else:
             pass
 
         if self.first_tool_only:
             return tool_calls[0] if tool_calls else None
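In short, the parser now prefers the structured tool_calls already attached to an AIMessage and only falls back to extracting "tool_use" blocks from raw content. A minimal sketch of the new behavior, assuming the class in this module is ToolsOutputParser (the name is inferred from the sibling langchain-anthropic package) and using made-up tool-call data:

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

# Import path assumed from the file path in this diff.
from langchain_google_vertexai._anthropic_parsers import ToolsOutputParser

# An AIMessage that already carries parsed tool calls (hypothetical data).
message = AIMessage(
    content="",
    tool_calls=[{"name": "Information", "args": {"name": "Rob"}, "id": "call_1"}],
)

parser = ToolsOutputParser(first_tool_only=True, args_only=True)
# parse_result reads message.tool_calls directly instead of re-deriving them.
result = parser.parse_result([ChatGeneration(message=message)])
assert result == {"name": "Rob"}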
39 changes: 24 additions & 15 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -296,6 +296,14 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                 )
                 parts.append(Part(function_call=function_call))
 
+            prev_content = vertex_messages[-1]
+            prev_content_is_model = prev_content and prev_content.role == "model"
+            if prev_content_is_model:
+                prev_parts = list(prev_content.parts)
+                prev_parts.extend(parts)
+                vertex_messages[-1] = Content(role=role, parts=prev_parts)
+                continue
+
             vertex_messages.append(Content(role=role, parts=parts))
         elif isinstance(message, FunctionMessage):
             prev_ai_message = None
@@ -306,18 +314,18 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                     name=message.name, response={"content": message.content}
                 )
             )
+            parts = [part]
 
             prev_content = vertex_messages[-1]
             prev_content_is_function = prev_content and prev_content.role == "function"
 
             if prev_content_is_function:
-                parts = list(prev_content.parts)
-                parts.append(part)
+                prev_parts = list(prev_content.parts)
+                prev_parts.extend(parts)
                 # replacing last message
-                vertex_messages[-1] = Content(role=role, parts=parts)
+                vertex_messages[-1] = Content(role=role, parts=prev_parts)
                 continue
 
-            parts = [part]
-
             vertex_messages.append(Content(role=role, parts=parts))
         elif isinstance(message, ToolMessage):
             role = "function"
@@ -383,18 +391,19 @@ def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]:
                         response=content,
                     )
                 )
+            parts = [part]
 
             prev_content = vertex_messages[-1]
             prev_content_is_function = prev_content and prev_content.role == "function"
 
             if prev_content_is_function:
-                parts = list(prev_content.parts)
-                parts.append(part)
+                prev_parts = list(prev_content.parts)
+                prev_parts.extend(parts)
                 # replacing last message
-                vertex_messages[-1] = Content(role=role, parts=parts)
+                vertex_messages[-1] = Content(role=role, parts=prev_parts)
                 continue
-            else:
-                parts = [part]
-                vertex_messages.append(Content(role=role, parts=parts))
 
+            vertex_messages.append(Content(role=role, parts=parts))
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
@@ -966,11 +975,11 @@ class Joke(BaseModel):
         setting this parameter to True is discouraged.
     """
     response_mime_type: Optional[str] = None
-    """Optional. Output response mimetype of the generated candidate text. Only
-    supported in Gemini 1.5 and later models. Supported mimetype:
-        * "text/plain": (default) Text output.
+    """Optional. Output response mimetype of the generated candidate text. Only
+    supported in Gemini 1.5 and later models. Supported mimetype:
+        * "text/plain": (default) Text output.
         * "application/json": JSON response in the candidates.
-    The model also needs to be prompted to output the appropriate response
+    The model also needs to be prompted to output the appropriate response
     type, otherwise the behavior is undefined. This is a preview feature.
     """
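All three code hunks above apply one rule: when a converted message maps to the same Gemini role as the previous history entry ("model" or "function"), its parts are appended to that entry instead of starting a new Content, since the Gemini API expects roles to alternate. A self-contained sketch of the rule, with a stand-in Content type rather than the Vertex AI class:

from dataclasses import dataclass
from typing import List

@dataclass
class Content:  # stand-in for the Vertex AI Content type used in the diff
    role: str
    parts: List[str]

def append_or_merge(history: List[Content], role: str, parts: List[str]) -> None:
    """Append a new Content, or merge parts into the last one if roles match."""
    prev = history[-1] if history else None
    if prev is not None and prev.role == role:
        merged = list(prev.parts)
        merged.extend(parts)
        history[-1] = Content(role=role, parts=merged)  # replace the last message
    else:
        history.append(Content(role=role, parts=parts))

history: List[Content] = []
append_or_merge(history, "model", ["Mike age is 30"])
append_or_merge(history, "model", ["Information(name='Rob')"])
assert len(history) == 1 and len(history[0].parts) == 2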
3 changes: 2 additions & 1 deletion libs/vertexai/langchain_google_vertexai/embeddings.py
@@ -320,7 +320,6 @@ def _prepare_and_validate_batches(
                 first_result = self._get_embeddings_with_retry(
                     first_batch, embeddings_type
                 )
-                batches = batches[1:]
                 break
             except InvalidArgument:
                 had_failure = True
@@ -347,6 +346,8 @@ def _prepare_and_validate_batches(
                 batches = VertexAIEmbeddings._prepare_batches(
                     texts[first_batch_len:], self.instance["batch_size"]
                 )
+            else:
+                batches = batches[1:]
         else:
             # Still figuring out max batch size.
             batches = batches[1:]
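The intended invariant (exercised by the unit-test change at the bottom of this commit): once the first batch has been embedded to validate the batch size, it is removed from the remaining work exactly once, in the branch that actually consumed it. A rough illustration with hypothetical numbers:

# Hypothetical numbers; mirrors the "no dups" unit test below.
texts = [f"text {i}" for i in range(10)]
batch_size = 5
batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

# The first batch is embedded eagerly to validate the batch size; once it
# succeeds it must be dropped from the remaining batches exactly once.
first_batch, batches = batches[0], batches[1:]
assert len(batches) == 1  # one batch (out of two) still to process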
6 changes: 6 additions & 0 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
@@ -167,6 +167,12 @@ def _format_to_gapic_function_declaration(
     elif isinstance(tool, vertexai.FunctionDeclaration):
         return _format_vertex_to_function_declaration(tool)
     elif isinstance(tool, dict):
+        # this could come from
+        # 'langchain_core.utils.function_calling.convert_to_openai_tool'
+        if tool.get("type") == "function" and tool.get("function"):
+            return _format_dict_to_function_declaration(
+                cast(FunctionDescription, tool.get("function"))
+            )
         return _format_dict_to_function_declaration(tool)
     else:
         raise ValueError(f"Unsupported tool call type {tool}")
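The new branch unwraps the OpenAI-style wrapper dict produced by convert_to_openai_tool before converting it to a gapic FunctionDeclaration. For reference, the shape it now accepts:

from langchain_core.utils.function_calling import convert_to_openai_tool

def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

tool = convert_to_openai_tool(multiply)
# -> {"type": "function",
#     "function": {"name": "multiply", "description": "...", "parameters": {...}}}
assert tool["type"] == "function" and "function" in tool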
23 changes: 23 additions & 0 deletions libs/vertexai/tests/integration_tests/test_model_garden.py
@@ -230,3 +230,26 @@ def my_tool(name: str, age: int) -> None:
     assert tool_call_chunk["args"]
     if tool_call_chunk["args"]:
         assert json.loads(tool_call_chunk["args"]) == {"age": 27.0, "name": "Erick"}
+
+
+@pytest.mark.extended
+def test_anthropic_with_structured_output() -> None:
+    project = os.environ["PROJECT_ID"]
+    location = "us-east5"
+    model = ChatAnthropicVertex(
+        project=project,
+        location=location,
+        model="claude-3-opus@20240229",
+    )
+
+    class MyModel(BaseModel):
+        name: str
+        age: int
+
+    message = HumanMessage(content="My name is Erick and I am 27 years old")
+    model_with_structured_output = model.with_structured_output(MyModel)
+    response = model_with_structured_output.invoke([message])
+
+    assert isinstance(response, MyModel)
+    assert response.name == "Erick"
+    assert response.age == 27
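This test exercises the parser change above through the public API: with_structured_output on ChatAnthropicVertex now round-trips a Pydantic schema. A rough sketch of what it presumably composes under the hood (an assumed simplification reusing the test's model, message, and MyModel; the parser import path is likewise assumed):

from langchain_google_vertexai._anthropic_parsers import ToolsOutputParser

structured = model.bind_tools([MyModel]) | ToolsOutputParser(
    first_tool_only=True, pydantic_schemas=[MyModel]
)
response = structured.invoke([message])  # -> MyModel(name="Erick", age=27)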
45 changes: 45 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -464,6 +464,51 @@ def test_parse_history_gemini_function() -> None:
                 )
             ],
         ),
+        (
+            [
+                AIMessage(
+                    content=["Mike age is 30"],
+                    tool_calls=[
+                        ToolCall(
+                            name="Information",
+                            args={"name": "Rob"},
+                            id="00000000-0000-0000-0000-00000000000",
+                        ),
+                    ],
+                ),
+                AIMessage(
+                    content=["Arthur age is 30"],
+                    tool_calls=[
+                        ToolCall(
+                            name="Information",
+                            args={"name": "Ben"},
+                            id="00000000-0000-0000-0000-00000000000",
+                        ),
+                    ],
+                ),
+            ],
+            [
+                Content(
+                    role="model",
+                    parts=[
+                        Part(text="Mike age is 30"),
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"name": "Rob"},
+                            )
+                        ),
+                        Part(text="Arthur age is 30"),
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"name": "Ben"},
+                            )
+                        ),
+                    ],
+                )
+            ],
+        ),
     ],
 )
 def test_parse_history_gemini_multi(source_history, expected_history) -> None:
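The new parametrize case pins down the merging behavior from chat_models.py: two consecutive AIMessages collapse into a single "model" Content whose four parts interleave text and function calls. In miniature (parse_history is a hypothetical stand-in for the function under test):

parsed = parse_history(source_history)
assert len(parsed) == 1            # both AIMessages merged into one Content
assert parsed[0].role == "model"
assert len(parsed[0].parts) == 4   # text, function_call, text, function_call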
2 changes: 1 addition & 1 deletion libs/vertexai/tests/unit_tests/test_embeddings.py
@@ -19,7 +19,7 @@ def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
 def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None:
     mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001")
     default_batch_size = mock_embeddings.instance["batch_size"]
-    texts = ["text_{i}" for i in range(default_batch_size * 2)]
+    texts = ["text {i}" for i in range(default_batch_size * 2)]
     # It should only return one batch (out of two) still to process
     _, batches = mock_embeddings._prepare_and_validate_batches(texts=texts)
     assert len(batches) == 1
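Worth noting when reading this hunk: neither the old nor the new literal carries an f prefix, so every element of texts is the identical string "text {i}". With interpolation the inputs would be distinct:

["text {i}" for i in range(3)]   # -> ['text {i}', 'text {i}', 'text {i}']
[f"text {i}" for i in range(3)]  # -> ['text 0', 'text 1', 'text 2']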
