diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py
index 32d895a4..26b5ac30 100644
--- a/libs/vertexai/langchain_google_vertexai/_base.py
+++ b/libs/vertexai/langchain_google_vertexai/_base.py
@@ -215,6 +215,11 @@ class _VertexAICommon(_VertexAIBase):
     model_name will be used to determine the model family
     """
 
+    cached_content: Optional[str] = None
+    """ Optional. Use the model in cache mode. Only supported in Gemini 1.5 and later
+    models. Must be a string containing the cache name (A sequence of numbers)
+    """
+
     @property
     def _is_gemini_model(self) -> bool:
         return is_gemini_model(self.model_family)  # type: ignore[arg-type]
diff --git a/libs/vertexai/langchain_google_vertexai/_utils.py b/libs/vertexai/langchain_google_vertexai/_utils.py
index a31c0d22..a7fa0987 100644
--- a/libs/vertexai/langchain_google_vertexai/_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/_utils.py
@@ -153,6 +153,10 @@ def is_gemini_model(model_family: GoogleModelFamily) -> bool:
     return model_family in [GoogleModelFamily.GEMINI, GoogleModelFamily.GEMINI_ADVANCED]
 
 
+def is_gemini_advanced(model_family: GoogleModelFamily) -> bool:
+    return model_family == GoogleModelFamily.GEMINI_ADVANCED
+
+
 def get_generation_info(
     candidate: Union[TextGenerationResponse, Candidate],
     is_gemini: bool,
diff --git a/libs/vertexai/langchain_google_vertexai/chains.py b/libs/vertexai/langchain_google_vertexai/chains.py
index 6f3c2680..f0fd3cfa 100644
--- a/libs/vertexai/langchain_google_vertexai/chains.py
+++ b/libs/vertexai/langchain_google_vertexai/chains.py
@@ -16,6 +16,9 @@
 from langchain_core.runnables import Runnable
 from pydantic import BaseModel
 
+from langchain_google_vertexai._utils import (
+    is_gemini_advanced,
+)
 from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
 
 
@@ -57,7 +60,7 @@ def _create_structured_runnable_extra_step(
         else schema.schema()["title"]
         for schema in functions
     ]
-    if hasattr(llm, "is_gemini_advanced") and llm._is_gemini_advanced:  # type: ignore
+    if is_gemini_advanced(llm.model_family):  # type: ignore
         llm_with_functions = llm.bind(
             functions=functions,
             tool_config={
diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index ebb9464a..33159fb2 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -114,6 +114,7 @@
     create_retry_decorator,
     get_generation_info,
     _format_model_name,
+    is_gemini_advanced,
     is_gemini_model,
     replace_defs_in_schema,
 )
@@ -1119,7 +1120,7 @@ def validate_environment(self) -> Self:
         if not is_gemini_model(self.model_family):
             logger.warning(
                 "Non-Gemini models are deprecated. "
-                "They will be remoced starting from Dec-01-2024. "
+                "They will be removed starting from Dec-01-2024. "
             )
         values = {
             "project": self.project,
@@ -1142,10 +1143,6 @@ def validate_environment(self) -> Self:
         )
         return self
 
-    @property
-    def _is_gemini_advanced(self) -> bool:
-        return self.model_family == GoogleModelFamily.GEMINI_ADVANCED
-
     @property
     def _default_params(self) -> Dict[str, Any]:
         updated_params = super()._default_params
@@ -1811,7 +1808,7 @@ class AnswerWithJustification(BaseModel):
             parser = JsonOutputKeyToolsParser(
                 key_name=tool_name, first_tool_only=True
             )
-            tool_choice = tool_name if self._is_gemini_advanced else None
+            tool_choice = tool_name if is_gemini_advanced(self.model_family) else None  # type: ignore[arg-type]
             llm = self.bind_tools([schema], tool_choice=tool_choice)
" ) values = { "project": self.project, @@ -1142,10 +1143,6 @@ def validate_environment(self) -> Self: ) return self - @property - def _is_gemini_advanced(self) -> bool: - return self.model_family == GoogleModelFamily.GEMINI_ADVANCED - @property def _default_params(self) -> Dict[str, Any]: updated_params = super()._default_params @@ -1811,7 +1808,7 @@ class AnswerWithJustification(BaseModel): parser = JsonOutputKeyToolsParser( key_name=tool_name, first_tool_only=True ) - tool_choice = tool_name if self._is_gemini_advanced else None + tool_choice = tool_name if is_gemini_advanced(self.model_family) else None # type: ignore[arg-type] llm = self.bind_tools([schema], tool_choice=tool_choice) diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index 303169fb..fb1b6523 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -23,6 +23,9 @@ from vertexai.language_models._language_models import ( # type: ignore[import-untyped] TextGenerationResponse, ) +from vertexai.preview.generative_models import ( # type: ignore[import-untyped] + GenerativeModel as PreviewGenerativeModel, +) from vertexai.preview.language_models import ( # type: ignore[import-untyped] CodeGenerationModel as PreviewCodeGenerationModel, ) @@ -55,8 +58,14 @@ def _completion_with_retry( def _completion_with_retry_inner( prompt: List[Union[str, Image]], is_gemini: bool = False, **kwargs: Any ) -> Any: + if llm.cached_content is not None: + selected_cached_content = llm.cached_content + model = llm.client_preview.from_cached_content(selected_cached_content) + else: + model = llm.client + if is_gemini: - return llm.client.generate_content( + return model.generate_content( prompt, stream=stream, safety_settings=kwargs.pop("safety_settings", None), @@ -114,6 +123,10 @@ class VertexAI(_VertexAICommon, BaseLLM): """The name of a tuned model. If tuned_model_name is passed model_name will be used to determine the model family """ + cached_content: Optional[str] = None + """ Optional. Use the model in cache mode. Only supported in Gemini 1.5 and later + models. 
diff --git a/libs/vertexai/langchain_google_vertexai/utils.py b/libs/vertexai/langchain_google_vertexai/utils.py
index c9286a04..fcde4f87 100644
--- a/libs/vertexai/langchain_google_vertexai/utils.py
+++ b/libs/vertexai/langchain_google_vertexai/utils.py
@@ -1,9 +1,13 @@
 from datetime import datetime, timedelta
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from langchain_core.messages import BaseMessage
 from vertexai.preview import caching  # type: ignore
 
+from langchain_google_vertexai._utils import (
+    _format_model_name,
+    is_gemini_advanced,
+)
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history_gemini,
@@ -14,10 +18,11 @@
     _ToolConfigDict,
     _ToolsType,
 )
+from langchain_google_vertexai.llms import VertexAI
 
 
 def create_context_cache(
-    model: ChatVertexAI,
+    model: Union[ChatVertexAI, VertexAI],
     messages: List[BaseMessage],
     expire_time: Optional[datetime] = None,
     time_to_live: Optional[timedelta] = None,
@@ -49,9 +54,13 @@
     Returns:
         String with the identificator of the created cache.
     """
-
-    if not model._is_gemini_advanced:
-        error_msg = f"Model {model.full_model_name} doesn't support context catching"
+    model_name = _format_model_name(
+        model=model.model_name,
+        project=model.project,  # type: ignore[arg-type]
+        location=model.location,
+    )
+    if not is_gemini_advanced(model.model_family):  # type: ignore[arg-type]
+        error_msg = f"Model {model_name} doesn't support context caching"
         raise ValueError(error_msg)
 
     system_instruction, contents = _parse_chat_history_gemini(messages, model.project)
@@ -63,7 +72,7 @@
         tools = [_format_to_gapic_tool(tools)]
 
     cached_content = caching.CachedContent.create(
-        model_name=model.full_model_name,
+        model_name=model_name,
         system_instruction=system_instruction,
         contents=contents,
         ttl=time_to_live,
diff --git a/libs/vertexai/tests/integration_tests/test_llms.py b/libs/vertexai/tests/integration_tests/test_llms.py
index 13b606e7..b31547bc 100644
--- a/libs/vertexai/tests/integration_tests/test_llms.py
+++ b/libs/vertexai/tests/integration_tests/test_llms.py
@@ -5,11 +5,19 @@
 """
 
 import pytest
+from langchain_core.messages import (
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.outputs import LLMResult
+from langchain_core.rate_limiters import InMemoryRateLimiter
 
+from langchain_google_vertexai import create_context_cache
 from langchain_google_vertexai.llms import VertexAI
 from tests.integration_tests.conftest import _DEFAULT_MODEL_NAME
 
+rate_limiter = InMemoryRateLimiter(requests_per_second=1.0)
+
 
 @pytest.mark.release
 def test_vertex_initialization() -> None:
@@ -101,3 +109,61 @@ def test_vertex_call_count_tokens() -> None:
     llm = VertexAI(model_name=_DEFAULT_MODEL_NAME)
     output = llm.get_num_tokens("How are you?")
     assert output == 4
+
+
+@pytest.mark.extended
+@pytest.mark.first
+def test_context_caching() -> None:
+    system_instruction = """
+
+    You are an expert researcher. You always stick to the facts in the sources provided,
+    and never make up new facts.
+
+    If asked about it, the secret number is 747.
+
+    Now look at these research papers, and answer the following questions.
+
+    """
+
+    cached_content = create_context_cache(
+        VertexAI(
+            model_name="gemini-1.5-pro-001",
+            rate_limiter=rate_limiter,
+        ),
+        messages=[
+            SystemMessage(content=system_instruction),
+            HumanMessage(
+                content=[
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+                        },
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
+                        },
+                    },
+                ]
+            ),
+        ],
+    )
+
+    # Using cached_content in constructor
+    llm = VertexAI(
+        model_name="gemini-1.5-pro-001",
+        cached_content=cached_content,
+        rate_limiter=rate_limiter,
+    )
+
+    response = llm.invoke("What is the secret number?")
+
+    assert isinstance(response, str)
+
+    # Using cached content in request
+    llm = VertexAI(model_name="gemini-1.5-pro-001", rate_limiter=rate_limiter)
+    response = llm.invoke("What is the secret number?", cached_content=cached_content)
+
+    assert isinstance(response, str)
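The integration test above doubles as usage documentation for the new cached_content support. Below is a condensed sketch distilled from it; the model name, TTL, and prompts are illustrative, and the caching service expects a large cached context, which is why the test caches whole PDFs rather than a short prompt:

from datetime import timedelta

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import VertexAI, create_context_cache

# Create the cache once; create_context_cache returns the cache name as a string.
cache = create_context_cache(
    VertexAI(model_name="gemini-1.5-pro-001"),
    messages=[
        SystemMessage(content="You are an expert researcher."),
        HumanMessage(
            content=[
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
                    },
                },
            ]
        ),
    ],
    time_to_live=timedelta(hours=1),  # optional; expire_time is also supported
)

# Option 1: bind the cache when constructing the model.
llm = VertexAI(model_name="gemini-1.5-pro-001", cached_content=cache)
print(llm.invoke("Summarize the cached paper in one sentence."))

# Option 2: pass the cache per request.
llm = VertexAI(model_name="gemini-1.5-pro-001")
print(llm.invoke("Summarize the cached paper in one sentence.", cached_content=cache))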