Skip to content

Commit

Permalink
Reafactored partner version of init_vertexai() back into llms.py
Browse files Browse the repository at this point in the history
- Reverted minimum version
  • Loading branch information
holtskinner committed Feb 5, 2024
1 parent a6a738a commit 5fa033f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 41 deletions.
2 changes: 1 addition & 1 deletion libs/community/langchain_community/utilities/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_retry_decorator(
return decorator


def raise_vertex_import_error(minimum_expected_version: str = "1.40.0") -> None:
def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
"""Raise ImportError related to Vertex SDK being not available.
Args:
Expand Down
50 changes: 13 additions & 37 deletions libs/partners/google-vertexai/langchain_google_vertexai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union

import google.api_core
import vertexai
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.credentials import Credentials
from google.cloud import storage
from google.cloud.aiplatform import initializer
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand Down Expand Up @@ -46,7 +43,7 @@ def create_retry_decorator(
return decorator


def raise_vertex_import_error(minimum_expected_version: str = "1.40.0") -> None:
def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
"""Raise ImportError related to Vertex SDK being not available.
Args:
Expand All @@ -60,53 +57,32 @@ def raise_vertex_import_error(minimum_expected_version: str = "1.40.0") -> None:
)


def _get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
r"""Returns a custom user agent header.
Args:
module (Optional[str]):
Optional. The module for a custom user agent header.
Returns:
Tuple[str, str]: The client library version and user agent.
"""
langchain_version = metadata.version("langchain")
client_library_version = (
f"{langchain_version}-{module}" if module else langchain_version
)
return client_library_version, f"langchain/{client_library_version}"


def init_vertexai(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional["Credentials"] = None,
module: Optional[str] = None,
) -> None:
"""Init vertexai.
Args:
project: The default GCP project to use when making Vertex API calls.
location: The default location to use when making API calls.
credentials: The default custom
credentials to use when making API calls. If not provided credentials
will be ascertained from the environment.
module: The module for a custom user agent header.
Raises:
ImportError: If importing vertexai SDK did not succeed.
"""
vertexai.init(
project=project,
location=location,
credentials=credentials,
)

_, user_agent = _get_user_agent(module)
initializer.global_config.append_user_agent(user_agent)


def get_client_info(module: Optional[str] = None) -> "ClientInfo":
r"""Returns a custom user agent header.
def get_client_info(module: Optional[str] = None) -> ClientInfo:
r"""Returns ClientInfo with a custom user agent header.
Args:
module (Optional[str]):
Optional. The module for a custom user agent header.
Returns:
google.api_core.gapic_v1.client_info.ClientInfo
"""
client_library_version, user_agent = _get_user_agent(module)
client_library_version, user_agent = get_user_agent(module)
return ClientInfo(
client_library_version=client_library_version,
user_agent=user_agent,
Expand Down
10 changes: 7 additions & 3 deletions libs/partners/google-vertexai/langchain_google_vertexai/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from concurrent.futures import Executor
from typing import Any, ClassVar, Dict, Iterator, List, Optional, Union

import vertexai # type: ignore[import-untyped]
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
Expand Down Expand Up @@ -41,7 +43,7 @@
create_retry_decorator,
get_client_info,
get_generation_info,
init_vertexai,
get_user_agent,
is_codey_model,
is_gemini_model,
)
Expand Down Expand Up @@ -218,12 +220,14 @@ def _default_params(self) -> Dict[str, Any]:

@classmethod
def _init_vertexai(cls, values: Dict) -> None:
init_vertexai(
vertexai.init(
project=values.get("project"),
location=values.get("location"),
credentials=values.get("credentials"),
module=values.get("module"),
)
_, user_agent = get_user_agent(values.get("module"))
initializer.global_config.append_user_agent(user_agent)
return None

def _prepare_params(
self,
Expand Down

0 comments on commit 5fa033f

Please sign in to comment.