vertexai: Add context caching to VertexAI class #645

Open · wants to merge 3 commits into base: main
5 changes: 5 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_base.py
@@ -215,6 +215,11 @@ class _VertexAICommon(_VertexAIBase):
model_name will be used to determine the model family
"""

cached_content: Optional[str] = None

Collaborator comment on the line above:
> I'm sorry, but how would it be different from `cached_content: Optional[str] = None`?

""" Optional. Use the model in cache mode. Only supported in Gemini 1.5 and later
models. Must be a string containing the cache name (A sequence of numbers)
"""

@property
def _is_gemini_model(self) -> bool:
return is_gemini_model(self.model_family) # type: ignore[arg-type]
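
For orientation, here is a minimal usage sketch of the new field (not part of the diff). The model name is taken from the tests below, and the cache name is a placeholder; per the docstring it should be the numeric name of an existing context cache:

```python
from langchain_google_vertexai import VertexAI

# Hypothetical cache name: "1234567890" stands in for the numeric identifier
# of a previously created context cache (e.g. the value returned by
# create_context_cache()).
llm = VertexAI(
    model_name="gemini-1.5-pro-001",
    cached_content="1234567890",
)
print(llm.invoke("Summarize the cached documents."))
```
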
4 changes: 4 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_utils.py
@@ -153,6 +153,10 @@ def is_gemini_model(model_family: GoogleModelFamily) -> bool:
return model_family in [GoogleModelFamily.GEMINI, GoogleModelFamily.GEMINI_ADVANCED]


def is_gemini_advanced(model_family: GoogleModelFamily) -> bool:
return model_family == GoogleModelFamily.GEMINI_ADVANCED


def get_generation_info(
candidate: Union[TextGenerationResponse, Candidate],
is_gemini: bool,
5 changes: 4 additions & 1 deletion libs/vertexai/langchain_google_vertexai/chains.py
@@ -16,6 +16,9 @@
from langchain_core.runnables import Runnable
from pydantic import BaseModel

from langchain_google_vertexai._utils import (
is_gemini_advanced,
)
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser


@@ -57,7 +60,7 @@ def _create_structured_runnable_extra_step(
else schema.schema()["title"]
for schema in functions
]
if hasattr(llm, "is_gemini_advanced") and llm._is_gemini_advanced: # type: ignore
if is_gemini_advanced(llm.model_family): # type: ignore
llm_with_functions = llm.bind(
functions=functions,
tool_config={
9 changes: 3 additions & 6 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -114,6 +114,7 @@
create_retry_decorator,
get_generation_info,
_format_model_name,
is_gemini_advanced,
is_gemini_model,
replace_defs_in_schema,
)
@@ -1119,7 +1120,7 @@ def validate_environment(self) -> Self:
if not is_gemini_model(self.model_family):
logger.warning(
"Non-Gemini models are deprecated. "
"They will be remoced starting from Dec-01-2024. "
"They will be removed starting from Dec-01-2024. "
)
values = {
"project": self.project,
@@ -1142,10 +1143,6 @@ def validate_environment(self) -> Self:
)
return self

@property
def _is_gemini_advanced(self) -> bool:
return self.model_family == GoogleModelFamily.GEMINI_ADVANCED

@property
def _default_params(self) -> Dict[str, Any]:
updated_params = super()._default_params
@@ -1811,7 +1808,7 @@ class AnswerWithJustification(BaseModel):
parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
tool_choice = tool_name if self._is_gemini_advanced else None
tool_choice = tool_name if is_gemini_advanced(self.model_family) else None # type: ignore[arg-type]

llm = self.bind_tools([schema], tool_choice=tool_choice)

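
For context, a rough sketch (not from this diff) of the code path the chat_models.py change affects: with_structured_output now derives tool_choice from is_gemini_advanced(self.model_family) instead of the removed _is_gemini_advanced property, so Gemini 1.5 models still receive a forced tool call. The model name below is only an example:

```python
from pydantic import BaseModel

from langchain_google_vertexai import ChatVertexAI


class AnswerWithJustification(BaseModel):
    answer: str
    justification: str


# gemini-1.5-pro falls into the GEMINI_ADVANCED family, so tool_choice is set
# to the tool name; for non-advanced Gemini models it stays None.
llm = ChatVertexAI(model_name="gemini-1.5-pro-001")
structured_llm = llm.with_structured_output(AnswerWithJustification)
result = structured_llm.invoke("Why is the sky blue?")
```
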
17 changes: 15 additions & 2 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -23,6 +23,9 @@
from vertexai.language_models._language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import ( # type: ignore[import-untyped]
GenerativeModel as PreviewGenerativeModel,
)
from vertexai.preview.language_models import ( # type: ignore[import-untyped]
CodeGenerationModel as PreviewCodeGenerationModel,
)
@@ -55,8 +58,14 @@ def _completion_with_retry(
def _completion_with_retry_inner(
prompt: List[Union[str, Image]], is_gemini: bool = False, **kwargs: Any
) -> Any:
if llm.cached_content is not None:
selected_cached_content = llm.cached_content
model = llm.client_preview.from_cached_content(selected_cached_content)
else:
model = llm.client

if is_gemini:
return llm.client.generate_content(
return model.generate_content(
prompt,
stream=stream,
safety_settings=kwargs.pop("safety_settings", None),
@@ -114,6 +123,10 @@ class VertexAI(_VertexAICommon, BaseLLM):
"""The name of a tuned model. If tuned_model_name is passed
model_name will be used to determine the model family
"""
cached_content: Optional[str] = None
""" Optional. Use the model in cache mode. Only supported in Gemini 1.5 and later
models. Must be a string containing the cache name (A sequence of numbers)
"""

def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
"""Needed for mypy typing to recognize model_name as a valid arg."""
@@ -159,7 +172,7 @@ def validate_environment(self) -> Self:
preview_model_cls = PreviewCodeGenerationModel
elif is_gemini:
model_cls = GenerativeModel
preview_model_cls = GenerativeModel
preview_model_cls = PreviewGenerativeModel
else:
model_cls = TextGenerationModel
preview_model_cls = PreviewTextGenerationModel
21 changes: 15 additions & 6 deletions libs/vertexai/langchain_google_vertexai/utils.py
@@ -1,9 +1,13 @@
from datetime import datetime, timedelta
from typing import List, Optional
from typing import List, Optional, Union

from langchain_core.messages import BaseMessage
from vertexai.preview import caching # type: ignore

from langchain_google_vertexai._utils import (
_format_model_name,
is_gemini_advanced,
)
from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history_gemini,
@@ -14,10 +18,11 @@
_ToolConfigDict,
_ToolsType,
)
from langchain_google_vertexai.llms import VertexAI


def create_context_cache(
model: ChatVertexAI,
model: Union[ChatVertexAI, VertexAI],
messages: List[BaseMessage],
expire_time: Optional[datetime] = None,
time_to_live: Optional[timedelta] = None,
@@ -49,9 +54,13 @@ def create_context_cache(
Returns:
String with the identifier of the created cache.
"""

if not model._is_gemini_advanced:
error_msg = f"Model {model.full_model_name} doesn't support context catching"
model_name = _format_model_name(
model=model.model_name,
project=model.project, # type: ignore[arg-type]
location=model.location,
)
if not is_gemini_advanced(model.model_family): # type: ignore[arg-type]
error_msg = f"Model {model_name} doesn't support context caching"
raise ValueError(error_msg)

system_instruction, contents = _parse_chat_history_gemini(messages, model.project)
@@ -63,7 +72,7 @@
tools = [_format_to_gapic_tool(tools)]

cached_content = caching.CachedContent.create(
model_name=model.full_model_name,
model_name=model_name,
system_instruction=system_instruction,
contents=contents,
ttl=time_to_live,
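
Taken together with the llms.py changes above, a hedged end-to-end sketch of how the updated create_context_cache could be combined with VertexAI. The messages and TTL here are illustrative placeholders, not values from this PR; real caches generally need enough content to satisfy Vertex AI's minimum token requirement for context caching:

```python
from datetime import timedelta

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import VertexAI, create_context_cache

# Illustrative placeholders: in practice the cached messages would reference
# large documents (e.g. PDFs in GCS), as in the integration test below.
cache_name = create_context_cache(
    VertexAI(model_name="gemini-1.5-pro-001"),
    messages=[
        SystemMessage(content="You are an expert researcher."),
        HumanMessage(content="Background material to reuse across requests."),
    ],
    time_to_live=timedelta(hours=1),
)

# The returned cache name is then passed back through the new cached_content field.
llm = VertexAI(model_name="gemini-1.5-pro-001", cached_content=cache_name)
print(llm.invoke("Answer using the cached context."))
```
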
66 changes: 66 additions & 0 deletions libs/vertexai/tests/integration_tests/test_llms.py
@@ -5,11 +5,19 @@
"""

import pytest
from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import LLMResult
from langchain_core.rate_limiters import InMemoryRateLimiter

from langchain_google_vertexai import create_context_cache
from langchain_google_vertexai.llms import VertexAI
from tests.integration_tests.conftest import _DEFAULT_MODEL_NAME

rate_limiter = InMemoryRateLimiter(requests_per_second=1.0)


@pytest.mark.release
def test_vertex_initialization() -> None:
@@ -101,3 +109,61 @@ def test_vertex_call_count_tokens() -> None:
llm = VertexAI(model_name=_DEFAULT_MODEL_NAME)
output = llm.get_num_tokens("How are you?")
assert output == 4


@pytest.mark.extended
@pytest.mark.first
def test_context_caching():
system_instruction = """

You are an expert researcher. You always stick to the facts in the sources provided,
and never make up new facts.

If asked about it, the secret number is 747.

Now look at these research papers, and answer the following questions.

"""

cached_content = create_context_cache(
VertexAI(
model_name="gemini-1.5-pro-001",
rate_limiter=rate_limiter,
),
messages=[
SystemMessage(content=system_instruction),
HumanMessage(
content=[
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
},
},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
},
},
]
),
],
)

# Using cached_content in constructor
llm = VertexAI(
model_name="gemini-1.5-pro-001",
cached_content=cached_content,
rate_limiter=rate_limiter,
)

response = llm.invoke("What is the secret number?")

assert isinstance(response, str)

# Using cached content in request
llm = VertexAI(model_name="gemini-1.5-pro-001", rate_limiter=rate_limiter)
response = llm.invoke("What is the secret number?", cached_content=cached_content)

assert isinstance(response, str)