Skip to content

Commit

Permalink
Add credentials parameter to GoogleGenerativeAI, ChatGoogleGenerative…
Browse files Browse the repository at this point in the history
…AI, and GoogleGenerativeAIEmbedded (langchain-ai#78)

* Adding credentials to GoogleGenerativeAI and GoogleGenerativeAIEmbeddings

---------

Co-authored-by: jack <[email protected]>
  • Loading branch information
jackklika and jack authored Mar 21, 2024
1 parent fcca8db commit d7bc26a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 34 deletions.
27 changes: 17 additions & 10 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,17 +483,24 @@ def is_lc_serializable(self) -> bool:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
if values.get("credentials"):
genai.configure(
credentials=values.get("credentials"),
transport=values.get("transport"),
client_options=values.get("client_options"),
)
else:
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()

genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)
genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)
if (
values.get("temperature") is not None
and not 0 <= values["temperature"] <= 1
Expand Down
38 changes: 26 additions & 12 deletions libs/genai/langchain_google_genai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

# TODO: remove ignore once the google package is published with types
import google.generativeai as genai # type: ignore[import]
Expand Down Expand Up @@ -43,6 +43,13 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
description="The Google API key to use. If not provided, "
"the GOOGLE_API_KEY environment variable will be used.",
)
credentials: Any = Field(
default=None,
exclude=True,
description="The default custom credentials "
"(google.auth.credentials.Credentials) to use when making API calls. If not "
"provided, credentials will be ascertained from the GOOGLE_API_KEY envvar",
)
client_options: Optional[Dict] = Field(
None,
description=(
Expand All @@ -58,17 +65,24 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()

genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)
if values.get("credentials"):
genai.configure(
credentials=values.get("credentials"),
transport=values.get("transport"),
client_options=values.get("client_options"),
)
else:
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()

genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)
return values

def _embed(
Expand Down
34 changes: 22 additions & 12 deletions libs/genai/langchain_google_genai/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class _BaseGoogleGenerativeAI(BaseModel):
)
"""Model name to use."""
google_api_key: Optional[SecretStr] = None
credentials: Any = None
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the GOOGLE_API_KEY envvar"
temperature: float = 0.7
"""Run inference with this temperature. Must by in the closed interval
[0.0, 1.0]."""
Expand Down Expand Up @@ -203,22 +207,28 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if values.get("credentials"):
genai.configure(
credentials=values.get("credentials"),
transport=values.get("transport"),
client_options=values.get("client_options"),
)
else:
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)

model_name = values["model"]

safety_settings = values["safety_settings"]

if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()

genai.configure(
api_key=google_api_key,
transport=values.get("transport"),
client_options=values.get("client_options"),
)

if safety_settings and (
not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
):
Expand Down

0 comments on commit d7bc26a

Please sign in to comment.