Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin authored Jul 18, 2024
2 parents becb2f1 + 0231d7b commit f01ce0c
Show file tree
Hide file tree
Showing 17 changed files with 160 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def initialize_bq_vector_index(cls, values: dict) -> dict:
values["_logger"].debug("Not enough rows to create a vector index.")
return values

if "_last_index_check" not in values:
values["_last_index_check"] = datetime.min

if datetime.utcnow() - values["_last_index_check"] < INDEX_CHECK_INTERVAL:
return values

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ class CreateDraftSchema(BaseModel):
description="The subject of the message.",
)
cc: Optional[List[str]] = Field(
None,
default=None,
description="The list of CC recipients.",
)
bcc: Optional[List[str]] = Field(
None,
default=None,
description="The list of BCC recipients.",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ class SendMessageSchema(BaseModel):
description="The subject of the message.",
)
cc: Optional[Union[str, List[str]]] = Field(
None,
default=None,
description="The list of CC recipients.",
)
bcc: Optional[Union[str, List[str]]] = Field(
None,
default=None,
description="The list of BCC recipients.",
)

Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_google_community/places_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class GooglePlacesAPIWrapper(BaseModel):
"""

gplaces_api_key: Optional[str] = None
google_map_client: Any #: :meta private:
google_map_client: Any = None #: :meta private:
top_k_results: Optional[int] = None

class Config:
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_google_community/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class GoogleSearchAPIWrapper(BaseModel):
"""

search_engine: Any #: :meta private:
search_engine: Any = None #: :meta private:
google_api_key: Optional[str] = None
google_cse_id: Optional[str] = None
k: int = 10
Expand Down
77 changes: 51 additions & 26 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,9 @@ class Joke(BaseModel):
""" # noqa: E501

client: Any #: :meta private:
async_client: Any #: :meta private:
google_api_key: Optional[SecretStr] = Field(None, alias="api_key")
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
google_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Google AI API key.
If not specified will be read from env var ``GOOGLE_API_KEY``."""
Expand Down Expand Up @@ -965,9 +965,18 @@ async def _agenerate(
**kwargs: Any,
) -> ChatResult:
if not self.async_client:
raise RuntimeError(
"Initialize ChatGoogleGenerativeAI with a running event loop "
"to use async methods."
updated_kwargs = {
**kwargs,
**{
"tools": tools,
"functions": functions,
"safety_settings": safety_settings,
"tool_config": tool_config,
"generation_config": generation_config,
},
}
return await super()._agenerate(
messages, stop, run_manager, **updated_kwargs
)

request = self._prepare_request(
Expand Down Expand Up @@ -1036,27 +1045,43 @@ async def _astream(
generation_config: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = self._prepare_request(
messages,
stop=stop,
tools=tools,
functions=functions,
safety_settings=safety_settings,
tool_config=tool_config,
generation_config=generation_config,
)
async for chunk in await _achat_with_retry(
request=request,
generation_method=self.async_client.stream_generate_content,
**kwargs,
metadata=self.default_metadata,
):
_chat_result = _response_to_result(chunk, stream=True)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
if not self.async_client:
updated_kwargs = {
**kwargs,
**{
"tools": tools,
"functions": functions,
"safety_settings": safety_settings,
"tool_config": tool_config,
"generation_config": generation_config,
},
}
async for value in super()._astream(
messages, stop, run_manager, **updated_kwargs
):
yield value
else:
request = self._prepare_request(
messages,
stop=stop,
tools=tools,
functions=functions,
safety_settings=safety_settings,
tool_config=tool_config,
generation_config=generation_config,
)
async for chunk in await _achat_with_retry(
request=request,
generation_method=self.async_client.stream_generate_content,
**kwargs,
metadata=self.default_metadata,
):
_chat_result = _response_to_result(chunk, stream=True)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])

if run_manager:
await run_manager.on_llm_new_token(gen.text)
yield gen
if run_manager:
await run_manager.on_llm_new_token(gen.text)
yield gen

def _prepare_request(
self,
Expand Down
12 changes: 6 additions & 6 deletions libs/genai/langchain_google_genai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
embeddings.embed_query("What's our Q1 revenue?")
"""

client: Any #: :meta private:
client: Any = None #: :meta private:
model: str = Field(
...,
description="The name of the embedding model to use. "
"Example: models/embedding-001",
)
task_type: Optional[str] = Field(
None,
default=None,
description="The task type. Valid options include: "
"task_type_unspecified, retrieval_query, retrieval_document, "
"semantic_similarity, classification, and clustering",
)
google_api_key: Optional[SecretStr] = Field(
None,
default=None,
description="The Google API key to use. If not provided, "
"the GOOGLE_API_KEY environment variable will be used.",
)
Expand All @@ -64,18 +64,18 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
"provided, credentials will be ascertained from the GOOGLE_API_KEY envvar",
)
client_options: Optional[Dict] = Field(
None,
default=None,
description=(
"A dictionary of client options to pass to the Google API client, "
"such as `api_endpoint`."
),
)
transport: Optional[str] = Field(
None,
default=None,
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)
request_options: Optional[Dict] = Field(
None,
default=None,
description="A dictionary of request options to pass to the Google API client."
"Example: `{'timeout': 10}`",
)
Expand Down
8 changes: 4 additions & 4 deletions libs/genai/langchain_google_genai/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,18 @@ class _BaseGoogleGenerativeAI(BaseModel):
"""The maximum number of seconds to wait for a response."""

client_options: Optional[Dict] = Field(
None,
default=None,
description=(
"A dictionary of client options to pass to the Google API client, "
"such as `api_endpoint`."
),
)
transport: Optional[str] = Field(
None,
default=None,
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)
additional_headers: Optional[Dict[str, str]] = Field(
None,
default=None,
description=(
"A key-value dictionary representing additional headers for the model call"
),
Expand Down Expand Up @@ -212,7 +212,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
llm = GoogleGenerativeAI(model="gemini-pro")
"""

client: Any #: :meta private:
client: Any = None #: :meta private:

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand Down
29 changes: 27 additions & 2 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test ChatGoogleGenerativeAI chat model."""

import asyncio
import json
from typing import Generator, Optional, Type
from typing import Generator, List, Optional, Type

import pytest
from langchain_core.language_models import BaseChatModel
Expand Down Expand Up @@ -465,3 +465,28 @@ class MyModel(BaseModel):
}
]
assert response == expected


def test_ainvoke_without_eventloop() -> None:
model = ChatGoogleGenerativeAI(model="gemini-1.5-flash-001")

async def model_ainvoke(context: str) -> BaseMessage:
result = await model.ainvoke(context)
return result

result = asyncio.run(model_ainvoke("How can you help me?"))
assert isinstance(result, AIMessage)


def test_astream_without_eventloop() -> None:
model = ChatGoogleGenerativeAI(model="gemini-1.5-flash-001")

async def model_astream(context: str) -> List[BaseMessageChunk]:
result = []
async for chunk in model.astream(context):
result.append(chunk)
return result

result = asyncio.run(model_astream("How can you help me?"))
assert len(result) > 0
assert isinstance(result[0], AIMessageChunk)
2 changes: 1 addition & 1 deletion libs/vertexai/langchain_google_vertexai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class _VertexAIBase(BaseModel):
client_options: Optional["ClientOptions"] = Field(
default=None, exclude=True
) #: :meta private:
api_endpoint: Optional[str] = Field(None, alias="base_url")
api_endpoint: Optional[str] = Field(default=None, alias="base_url")
"Desired API endpoint, e.g., us-central1-aiplatform.googleapis.com"
api_transport: Optional[str] = None
"""The desired API transport method, can be either 'grpc' or 'rest'.
Expand Down
4 changes: 4 additions & 0 deletions libs/vertexai/langchain_google_vertexai/model_garden.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Union,
)

from google.auth.credentials import Credentials
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand Down Expand Up @@ -134,6 +135,7 @@ class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
max_output_tokens: int = Field(default=1024, alias="max_tokens")
access_token: Optional[str] = None
stream_usage: bool = True # Whether to include usage metadata in streaming output
credentials: Optional[Credentials] = None

class Config:
"""Configuration for this pydantic object."""
Expand All @@ -156,12 +158,14 @@ def validate_environment(cls, values: Dict) -> Dict:
region=values["location"],
max_retries=values["max_retries"],
access_token=values["access_token"],
credentials=values["credentials"],
)
values["async_client"] = AsyncAnthropicVertex(
project_id=values["project"],
region=values["location"],
max_retries=values["max_retries"],
access_token=values["access_token"],
credentials=values["credentials"],
)
return values

Expand Down
11 changes: 6 additions & 5 deletions libs/vertexai/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/vertexai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ langchain-core = ">=0.2.9,<0.3"
google-cloud-aiplatform = "^1.56.0"
google-cloud-storage = "^2.17.0"
# optional dependencies
anthropic = { extras = ["vertexai"], version = ">=0.29.0,<1", optional = true }
anthropic = { extras = ["vertexai"], version = ">=0.30.0,<1", optional = true }

[tool.poetry.group.test]
optional = true
Expand Down
7 changes: 6 additions & 1 deletion libs/vertexai/tests/integration_tests/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@ PROJECT_ID=projecy_id
FALCON_ENDPOINT_ID=falcon_endpoint_id
GEMMA_ENDPOINT_ID=gemma_endpoint_id
LLAMA_ENDPOINT_ID=llama_endpoint_id
IMAGE_GCS_PATH=image_gcs_path
IMAGE_GCS_PATH=image_gcs_path
VECTOR_SEARCH_STAGING_BUCKET=VECTOR_SEARCH_STAGING_BUCKET
VECTOR_SEARCH_STREAM_INDEX_ID=VECTOR_SEARCH_STREAM_INDEX_ID
VECTOR_SEARCH_STREAM_ENDPOINT_ID=VECTOR_SEARCH_STREAM_ENDPOINT_ID
VECTOR_SEARCH_BATCH_INDEX_ID=VECTOR_SEARCH_BATCH_INDEX_ID
VECTOR_SEARCH_BATCH_ENDPOINT_ID=VECTOR_SEARCH_BATCH_ENDPOINT_ID
5 changes: 5 additions & 0 deletions libs/vertexai/tests/integration_tests/terraform/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ module "cloudbuild" {
GEMMA_ENDPOINT_ID = "",
LLAMA_ENDPOINT_ID = "",
IMAGE_GCS_PATH = "",
VECTOR_SEARCH_STAGING_BUCKET="",
VECTOR_SEARCH_STREAM_INDEX_ID="",
VECTOR_SEARCH_STREAM_ENDPOINT_ID="",
VECTOR_SEARCH_BATCH_INDEX_ID="",
VECTOR_SEARCH_BATCH_ENDPOINT_ID="",
}
cloudbuild_secret_vars = {
GOOGLE_API_KEY = ""
Expand Down
Loading

0 comments on commit f01ce0c

Please sign in to comment.