Skip to content

Commit

Permalink
Add Parallel Tool mode for Vertex AI
Browse files Browse the repository at this point in the history
  • Loading branch information
fayzfi committed Nov 24, 2024
1 parent 58eef74 commit c582618
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 15 deletions.
42 changes: 30 additions & 12 deletions instructor/client_vertexai.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations

from typing import Any
from typing import Any, Type, Union, get_origin

from vertexai.preview.generative_models import ToolConfig # type: ignore
import vertexai.generative_models as gm # type: ignore
from pydantic import BaseModel
import instructor
from instructor.dsl.parallel import get_types_array
import jsonref


def _create_gemini_json_schema(model: BaseModel):
# Add type check to ensure we have a concrete model class
if get_origin(model) is not None:
raise TypeError(f"Expected concrete model class, got type hint {model}")

schema = model.model_json_schema()
schema_without_refs: dict[str, Any] = jsonref.replace_refs(schema) # type: ignore
gemini_schema: dict[Any, Any] = {
Expand All @@ -22,16 +27,28 @@ def _create_gemini_json_schema(model: BaseModel):
return gemini_schema


def _create_vertexai_tool(models: Union[BaseModel, list[BaseModel], Type]) -> gm.Tool:
    """Create a single gm.Tool with one FunctionDeclaration per model.

    Args:
        models: A concrete pydantic model class, a list of model classes, or a
            parameterized type hint (the parallel-tools case) from which the
            member model classes are extracted via ``get_types_array``.

    Returns:
        A ``gm.Tool`` whose function declarations mirror each model's name,
        docstring, and JSON schema.
    """
    # A parameterized type hint (get_origin(...) is not None) marks the
    # parallel-tools case: expand it into its member model classes.
    if get_origin(models) is not None:
        model_list = list(get_types_array(models))
    else:
        # Normalize the single-model case so both paths share one code path.
        model_list = models if isinstance(models, list) else [models]

    declarations = [
        gm.FunctionDeclaration(
            name=model.__name__,
            description=model.__doc__,
            parameters=_create_gemini_json_schema(model),
        )
        for model in model_list
    ]

    return gm.Tool(function_declarations=declarations)


def vertexai_message_parser(
Expand Down Expand Up @@ -84,11 +101,11 @@ def vertexai_function_response_parser(
)


def vertexai_process_response(_kwargs: dict[str, Any], model: BaseModel):
def vertexai_process_response(_kwargs: dict[str, Any], model: Union[BaseModel, list[BaseModel], Type]):
messages: list[dict[str, str]] = _kwargs.pop("messages")
contents = _vertexai_message_list_parser(messages) # type: ignore

tool = _create_vertexai_tool(model=model)
tool = _create_vertexai_tool(models=model)

tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
Expand Down Expand Up @@ -122,6 +139,7 @@ def from_vertexai(
**kwargs: Any,
) -> instructor.Instructor:
assert mode in {
instructor.Mode.VERTEXAI_PARALLEL_TOOLS,
instructor.Mode.VERTEXAI_TOOLS,
instructor.Mode.VERTEXAI_JSON,
}, "Mode must be instructor.Mode.VERTEXAI_TOOLS"
Expand Down
37 changes: 37 additions & 0 deletions instructor/dsl/parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import json
from typing import (
Any,
Optional,
Expand Down Expand Up @@ -45,6 +46,38 @@ def from_response(
)


class VertexAIParallelBase(ParallelBase):
    def from_response(
        self,
        response: Any,
        mode: Mode,
        validation_context: Optional[Any] = None,
        strict: Optional[bool] = None,
    ) -> Generator[BaseModel, None, None]:
        """Yield one validated model instance per function call in *response*.

        Walks every candidate/part of a Vertex AI response, and for each
        function call whose name is registered on this dispatcher, validates
        its arguments against the matching pydantic model.
        """
        assert mode == Mode.VERTEXAI_PARALLEL_TOOLS, "Mode must be VERTEXAI_PARALLEL_TOOLS"

        candidates = response.candidates if response else None
        if not candidates:
            return

        for candidate in candidates:
            content = candidate.content
            if not content or not content.parts:
                continue

            for part in content.parts:
                call = getattr(part, "function_call", None)
                if call is None:
                    continue

                model_cls = self.registry.get(call.name)
                if model_cls is None:
                    continue

                # Serialize the args dict so pydantic's JSON validator (which
                # honors context/strict) can consume it.
                yield model_cls.model_validate_json(
                    json.dumps(call.args),
                    context=validation_context,
                    strict=strict,
                )


if sys.version_info >= (3, 10):
from types import UnionType

Expand Down Expand Up @@ -82,3 +115,7 @@ def handle_parallel_model(typehint: type[Iterable[T]]) -> list[dict[str, Any]]:
def ParallelModel(typehint: type[Iterable[T]]) -> ParallelBase:
    """Build a ParallelBase dispatcher from a parallel-tools type hint."""
    return ParallelBase(*get_types_array(typehint))

def VertexAIParallelModel(typehint: type[Iterable[T]]) -> VertexAIParallelBase:
    """Build a VertexAIParallelBase dispatcher from a parallel-tools type hint."""
    return VertexAIParallelBase(*get_types_array(typehint))
1 change: 1 addition & 0 deletions instructor/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Mode(enum.Enum):
COHERE_TOOLS = "cohere_tools"
VERTEXAI_TOOLS = "vertexai_tools"
VERTEXAI_JSON = "vertexai_json"
VERTEXAI_PARALLEL_TOOLS = "vertexai_parallel_tools"
GEMINI_JSON = "gemini_json"
GEMINI_TOOLS = "gemini_tools"
COHERE_JSON_SCHEMA = "json_object"
Expand Down
36 changes: 33 additions & 3 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@

from instructor.mode import Mode
from instructor.dsl.iterable import IterableBase, IterableModel
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
from instructor.dsl.parallel import (
ParallelBase,
ParallelModel,
handle_parallel_model,
get_types_array,
VertexAIParallelBase,
VertexAIParallelModel
)
from instructor.dsl.partial import PartialBase
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
from instructor.function_calls import OpenAISchema, openai_schema
Expand Down Expand Up @@ -112,7 +119,7 @@ def process_response(
validation_context: dict[str, Any] | None = None,
strict=None,
mode: Mode = Mode.TOOLS,
):
) -> T_Model | list[T_Model] | VertexAIParallelBase | None:
"""
Process the response from the API call and convert it to the specified response model.
Expand Down Expand Up @@ -485,6 +492,27 @@ def handle_gemini_tools(
return response_model, new_kwargs


def handle_vertexai_parallel_tools(
    response_model: type[Iterable[T]], new_kwargs: dict[str, Any]
) -> tuple[VertexAIParallelBase, dict[str, Any]]:
    """Prepare kwargs for a Vertex AI parallel-tools API call.

    Args:
        response_model: An ``Iterable[...]`` type hint naming the candidate
            response models.
        new_kwargs: API-call kwargs; ``"messages"`` is consumed (by
            ``vertexai_process_response``) and replaced with ``"contents"``,
            ``"tools"``, and ``"tool_config"``.

    Returns:
        A ``(VertexAIParallelBase, kwargs)`` pair ready for the Vertex AI SDK.

    Raises:
        AssertionError: If ``stream=True`` was requested; streaming is not
            supported in parallel-tools mode.
    """
    assert (
        new_kwargs.get("stream", False) is False
    ), "stream=True is not supported when using VERTEXAI_PARALLEL_TOOLS mode"

    # Imported here to avoid a circular import with instructor.client_vertexai.
    from instructor.client_vertexai import vertexai_process_response

    # Extract the concrete model classes before handing off to the Vertex AI
    # response processor.
    model_types = list(get_types_array(response_model))
    contents, tools, tool_config = vertexai_process_response(new_kwargs, model_types)

    new_kwargs["contents"] = contents
    new_kwargs["tools"] = tools
    new_kwargs["tool_config"] = tool_config

    return VertexAIParallelModel(typehint=response_model), new_kwargs


def handle_vertexai_tools(
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
Expand Down Expand Up @@ -646,7 +674,7 @@ def prepare_response_model(response_model: type[T] | None) -> type[T] | None:

def handle_response_model(
response_model: type[T] | None, mode: Mode = Mode.TOOLS, **kwargs: Any
) -> tuple[type[T] | None, dict[str, Any]]:
) -> tuple[type[T] | VertexAIParallelBase | None, dict[str, Any]]:
"""
Handles the response model based on the specified mode and prepares the kwargs for the API call.
Expand Down Expand Up @@ -690,6 +718,8 @@ def handle_response_model(

if mode in {Mode.PARALLEL_TOOLS}:
return handle_parallel_tools(response_model, new_kwargs)
elif mode in {Mode.VERTEXAI_PARALLEL_TOOLS}:
return handle_vertexai_parallel_tools(response_model, new_kwargs)

response_model = prepare_response_model(response_model)

Expand Down

0 comments on commit c582618

Please sign in to comment.