diff --git a/src/langgraph_google_alloydb_pg/__init__.py b/src/langgraph_google_alloydb_pg/__init__.py
new file mode 100644
index 00000000..dd8b05c0
--- /dev/null
+++ b/src/langgraph_google_alloydb_pg/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .engine import AlloyDBEngine
+from .checkpoint import AlloyDBSaver
+from .version import __version__
+
+__all__ = [
+    "AlloyDBEngine",
+    "AlloyDBSaver",
+    "__version__",
+]
\ No newline at end of file
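Usage sketch (not part of the patch): a minimal example of the public surface this package exports. All project/instance identifiers are placeholders; the tables must be created before `create_sync`, since the saver validates their schema on construction.

```python
from langgraph_google_alloydb_pg import AlloyDBEngine, AlloyDBSaver

# Placeholder identifiers; substitute your own project, region, cluster,
# instance, and database names.
engine = AlloyDBEngine.from_instance(
    project_id="my-project",
    region="us-central1",
    cluster="my-cluster",
    instance="my-instance",
    database="my-database",
)
engine.init_checkpoint_table()  # create the checkpoints tables first
checkpointer = AlloyDBSaver.create_sync(engine)
```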
diff --git a/src/langgraph_google_alloydb_pg/async_checkpoint.py b/src/langgraph_google_alloydb_pg/async_checkpoint.py
new file mode 100644
index 00000000..54e7f84d
--- /dev/null
+++ b/src/langgraph_google_alloydb_pg/async_checkpoint.py
@@ -0,0 +1,352 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
+
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+from langchain_core.runnables import RunnableConfig
+
+from langgraph.checkpoint.base import (
+    WRITES_IDX_MAP,
+    BaseCheckpointSaver,
+    ChannelVersions,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+)
+from langgraph.checkpoint.serde.base import SerializerProtocol
+
+from .engine import AlloyDBEngine
+
+
+class AsyncAlloyDBSaver(BaseCheckpointSaver[str]):
+    """Checkpoint saver stored in an AlloyDB for PostgreSQL database."""
+
+    __create_key = object()
+
+    def __init__(
+        self,
+        key: object,
+        pool: AsyncEngine,
+        schema_name: str = "public",
+        serde: Optional[SerializerProtocol] = None,
+    ) -> None:
+        super().__init__(serde=serde)
+        if key != AsyncAlloyDBSaver.__create_key:
+            raise Exception(
+                "Only create class through 'create' or 'create_sync' methods."
+            )
+        self.pool = pool
+        self.schema_name = schema_name
+
+    @classmethod
+    async def create(
+        cls,
+        engine: AlloyDBEngine,
+        schema_name: str = "public",
+        serde: Optional[SerializerProtocol] = None,
+    ) -> "AsyncAlloyDBSaver":
+        """Create a new AsyncAlloyDBSaver instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDB engine to use.
+            schema_name (str): The schema name where the tables are located (default: "public").
+            serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).
+
+        Raises:
+            IndexError: If the tables provided do not contain the required schema.
+
+        Returns:
+            AsyncAlloyDBSaver: A newly created instance of AsyncAlloyDBSaver.
+        """
+        checkpoints_table_schema = await engine._aload_table_schema(
+            "checkpoints", schema_name
+        )
+        checkpoints_column_names = checkpoints_table_schema.columns.keys()
+
+        # Required columns match the DDL created by
+        # AlloyDBEngine._ainit_checkpoint_table and the columns written by aput.
+        checkpoints_required_columns = [
+            "thread_id",
+            "checkpoint_ns",
+            "checkpoint_id",
+            "parent_checkpoint_id",
+            "v",
+            "checkpoint",
+            "metadata",
+            "channel",
+            "version",
+            "type",
+            "blob",
+        ]
+
+        if not all(x in checkpoints_column_names for x in checkpoints_required_columns):
+            raise IndexError(
+                f"Table '{schema_name}'.checkpoints has an incorrect schema. Got "
+                f"column names '{checkpoints_column_names}' but required column names "
+                f"'{checkpoints_required_columns}'.\nPlease create the table with the following schema:"
+                f"\nCREATE TABLE {schema_name}.checkpoints ("
+                "\n    thread_id TEXT NOT NULL,"
+                "\n    checkpoint_ns TEXT NOT NULL DEFAULT '',"
+                "\n    checkpoint_id TEXT NOT NULL,"
+                "\n    parent_checkpoint_id TEXT,"
+                "\n    v INTEGER NOT NULL,"
+                "\n    checkpoint JSONB NOT NULL,"
+                "\n    metadata JSONB NOT NULL DEFAULT '{}',"
+                "\n    channel TEXT NOT NULL,"
+                "\n    version TEXT NOT NULL,"
+                "\n    type TEXT,"
+                "\n    blob BYTEA,"
+                "\n    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)"
+                "\n);"
+            )
+
+        checkpoint_writes_table_schema = await engine._aload_table_schema(
+            "checkpoint_writes", schema_name
+        )
+        checkpoint_writes_column_names = checkpoint_writes_table_schema.columns.keys()
+
+        checkpoint_writes_required_columns = [
+            "thread_id",
+            "checkpoint_ns",
+            "checkpoint_id",
+            "task_id",
+            "idx",
+            "channel",
+            "type",
+            "blob",
+        ]
+
+        if not all(
+            x in checkpoint_writes_column_names
+            for x in checkpoint_writes_required_columns
+        ):
+            raise IndexError(
+                f"Table '{schema_name}'.checkpoint_writes has an incorrect schema. Got "
+                f"column names '{checkpoint_writes_column_names}' but required column names "
+                f"'{checkpoint_writes_required_columns}'.\nPlease create the table with the following schema:"
+                f"\nCREATE TABLE {schema_name}.checkpoint_writes ("
+                "\n    thread_id TEXT NOT NULL,"
+                "\n    checkpoint_ns TEXT NOT NULL DEFAULT '',"
+                "\n    checkpoint_id TEXT NOT NULL,"
+                "\n    task_id TEXT NOT NULL,"
+                "\n    idx INTEGER NOT NULL,"
+                "\n    channel TEXT NOT NULL,"
+                "\n    type TEXT,"
+                "\n    blob BYTEA NOT NULL,"
+                "\n    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)"
+                "\n);"
+            )
+        return cls(cls.__create_key, engine._pool, schema_name, serde)
+
+    async def alist(
+        self,
+        config: Optional[RunnableConfig],
+        filter: Optional[Dict[str, Any]] = None,
+        before: Optional[RunnableConfig] = None,
+        limit: Optional[int] = None,
+    ) -> AsyncIterator[CheckpointTuple]:
+        """Asynchronously list checkpoints that match the given criteria.
+
+        Args:
+            config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
+            filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
+            before (Optional[RunnableConfig]): List checkpoints created before this configuration.
+            limit (Optional[int]): Maximum number of checkpoints to return.
+
+        Returns:
+            AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.
+
+        Raises:
+            NotImplementedError: Listing is not implemented yet.
+        """
+        raise NotImplementedError
+        # The unreachable yield makes this an async generator, matching the
+        # declared return type.
+        yield
+ """ + raise NotImplementedError + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a checkpoint with its configuration and metadata. + + Args: + config (RunnableConfig): Configuration for the checkpoint. + checkpoint (Checkpoint): The checkpoint to store. + metadata (CheckpointMetadata): Additional metadata for the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + checkpoint_id = configurable.pop( + "checkpoint_id", configurable.pop("thread_ts", None) + ) + + copy = checkpoint.copy() + next_config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + query = f"""INSERT INTO "{self.schema_name}".checkpoints(thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata, channel, version, type, blob) + VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :checkpoint, :metadata, :channel, :version, :type, :blob); + """ + + async with self.pool.connect() as conn: + await conn.execute( + text(query), + { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + "parent_checkpoint_id": config.get("checkpoint_id"), + "checkpoint": json.dumps(copy), + "metadata": json.dumps(dict(metadata)), + "channel": copy.pop("channel_values"), + "version": new_versions, + "type": next_config["configurable"]["type"], + "blob": json.dumps(next_config["configurable"]["blob"]), + }, + ) + await conn.commit() + + return next_config + + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + ) -> None: + """Asynchronously store intermediate writes linked to a checkpoint. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + + Raises: + NotImplementedError: Implement this method in your custom checkpoint saver. + """ + query = f"""INSERT INTO "{self.schema_name}".checkpoint_writes(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob) + """ + upsert = "" + async with self.pool.connect() as conn: + await conn.execute( + text(query), + { + "thread_id": config["configurable"]["thread_id"], + "checkpoint_ns": config["configurable"]["checkpoint_ns"], + "checkpoint_id": config["configurable"]["checkpoint_id"], + "task_id": task_id, + "idx": idx, + "channel": write[0], + "type": write[1], + "blob": json.dumps(write[2]), + }, + ) + await conn.commit() + + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[Dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead." + ) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Fetch a checkpoint tuple using the given configuration. 
+
+    def list(
+        self,
+        config: Optional[RunnableConfig],
+        *,
+        filter: Optional[Dict[str, Any]] = None,
+        before: Optional[RunnableConfig] = None,
+        limit: Optional[int] = None,
+    ) -> Iterator[CheckpointTuple]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
+        )
+
+    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
+        """Fetch a checkpoint tuple using the given configuration.
+
+        Args:
+            config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
+
+        Returns:
+            Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
+
+        Raises:
+            NotImplementedError: Sync methods are not supported by AsyncAlloyDBSaver.
+        """
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
+        )
+
+    def put(
+        self,
+        config: RunnableConfig,
+        checkpoint: Checkpoint,
+        metadata: CheckpointMetadata,
+        new_versions: ChannelVersions,
+    ) -> RunnableConfig:
+        """Store a checkpoint with its configuration and metadata.
+
+        Args:
+            config (RunnableConfig): Configuration for the checkpoint.
+            checkpoint (Checkpoint): The checkpoint to store.
+            metadata (CheckpointMetadata): Additional metadata for the checkpoint.
+            new_versions (ChannelVersions): New channel versions as of this write.
+
+        Returns:
+            RunnableConfig: Updated configuration after storing the checkpoint.
+
+        Raises:
+            NotImplementedError: This method is implemented as aput in AsyncAlloyDBSaver.
+        """
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
+        )
+
+    def put_writes(
+        self,
+        config: RunnableConfig,
+        writes: Sequence[Tuple[str, Any]],
+        task_id: str,
+    ) -> None:
+        """Store intermediate writes linked to a checkpoint.
+
+        Args:
+            config (RunnableConfig): Configuration of the related checkpoint.
+            writes (Sequence[Tuple[str, Any]]): List of (channel, value) writes to store.
+            task_id (str): Identifier for the task creating the writes.
+
+        Raises:
+            NotImplementedError: This method is implemented as aput_writes in AsyncAlloyDBSaver.
+        """
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
+        )
\ No newline at end of file
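Usage sketch (not part of the patch): a minimal async round trip through the saver above, assuming the tables were created with `ainit_checkpoint_table`. Instance coordinates, thread/checkpoint/task ids are placeholders.

```python
import asyncio

from langgraph_google_alloydb_pg import AlloyDBEngine
from langgraph_google_alloydb_pg.async_checkpoint import AsyncAlloyDBSaver


async def main() -> None:
    # Placeholder instance coordinates.
    engine = await AlloyDBEngine.afrom_instance(
        project_id="my-project",
        region="us-central1",
        cluster="my-cluster",
        instance="my-instance",
        database="my-database",
    )
    await engine.ainit_checkpoint_table()  # tables must exist before create()
    saver = await AsyncAlloyDBSaver.create(engine)

    # Store one intermediate write for a (thread, namespace, checkpoint) key.
    config = {
        "configurable": {
            "thread_id": "thread-1",
            "checkpoint_ns": "",
            "checkpoint_id": "checkpoint-1",
        }
    }
    await saver.aput_writes(config, [("messages", {"role": "user", "content": "hi"})], task_id="task-1")
    await engine.close()


asyncio.run(main())
```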
diff --git a/src/langgraph_google_alloydb_pg/checkpoint.py b/src/langgraph_google_alloydb_pg/checkpoint.py
new file mode 100644
index 00000000..ebe9b7f7
--- /dev/null
+++ b/src/langgraph_google_alloydb_pg/checkpoint.py
@@ -0,0 +1,175 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import AsyncIterator, Iterator, Sequence
+from typing import Any, Optional
+
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import (
+    BaseCheckpointSaver,
+    ChannelVersions,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+)
+from langgraph.checkpoint.serde.base import SerializerProtocol
+
+from .async_checkpoint import AsyncAlloyDBSaver
+from .engine import AlloyDBEngine
+
+
+class AlloyDBSaver(BaseCheckpointSaver[str]):
+    """Checkpoint saver stored in an AlloyDB for PostgreSQL database."""
+
+    __create_key = object()
+
+    def __init__(
+        self,
+        key: object,
+        engine: AlloyDBEngine,
+        checkpoint: AsyncAlloyDBSaver,
+        serde: Optional[SerializerProtocol] = None,
+    ) -> None:
+        super().__init__(serde=serde)
+        if key != AlloyDBSaver.__create_key:
+            raise Exception(
+                "Only create class through 'create' or 'create_sync' methods."
+            )
+        self._engine = engine
+        self.__checkpoint = checkpoint
+
+    @classmethod
+    async def create(
+        cls,
+        engine: AlloyDBEngine,
+        schema_name: str = "public",
+        serde: Optional[SerializerProtocol] = None,
+    ) -> "AlloyDBSaver":
+        """Create a new AlloyDBSaver instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDB engine to use.
+            schema_name (str): The schema name where the tables are located (default: "public").
+            serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).
+
+        Raises:
+            IndexError: If the tables provided do not contain the required schema.
+
+        Returns:
+            AlloyDBSaver: A newly created instance of AlloyDBSaver.
+        """
+        coro = AsyncAlloyDBSaver.create(engine, schema_name, serde)
+        checkpoint = await engine._run_as_async(coro)
+        return cls(cls.__create_key, engine, checkpoint)
+
+    @classmethod
+    def create_sync(
+        cls,
+        engine: AlloyDBEngine,
+        schema_name: str = "public",
+        serde: Optional[SerializerProtocol] = None,
+    ) -> "AlloyDBSaver":
+        """Create a new AlloyDBSaver instance.
+
+        Args:
+            engine (AlloyDBEngine): AlloyDB engine to use.
+            schema_name (str): The schema name where the tables are located (default: "public").
+            serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).
+
+        Raises:
+            IndexError: If the tables provided do not contain the required schema.
+
+        Returns:
+            AlloyDBSaver: A newly created instance of AlloyDBSaver.
+        """
+        coro = AsyncAlloyDBSaver.create(engine, schema_name, serde)
+        checkpoint = engine._run_as_sync(coro)
+        return cls(cls.__create_key, engine, checkpoint)
+ """ + coro = AsyncAlloyDBSaver.create( + engine, schema_name, serde + ) + checkpoint = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, checkpoint) + + + async def alist( + self, + config: Optional[RunnableConfig], + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None + ) -> AsyncIterator[CheckpointTuple]: + '''List checkpoints from AlloyDB ''' + yield await self._engine._run_as_async(self.__checkpoint.alist(config, filter, before, limit)) + + def list( + self, + config: Optional[RunnableConfig], + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None + ) -> Iterator[CheckpointTuple]: + '''List checkpoints from AlloyDB ''' + return self._engine._run_as_sync(self.__checkpoint.alist(config, filter, before, limit)) + + async def aget_tuple( + self, + config: RunnableConfig + ) -> Optional[CheckpointTuple]: + '''Get a checkpoint tuple from AlloyDB''' + return await self._engine._run_as_async(self.__checkpoint.aget_tuple(config)) + + def get_tuple( + self, + config: RunnableConfig + ) -> Optional[CheckpointTuple]: + '''Get a checkpoint tuple from AlloyDB''' + return self._engine._run_as_sync(self.__checkpoint.aget_tuple(config)) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions + ) -> RunnableConfig: + '''Save a checkpoint to AlloyDB''' + return await self._engine._run_as_async(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions + ) -> RunnableConfig: + '''Save a checkpoint to AlloyDB''' + return self._engine._run_as_sync(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str + ) -> None: + '''Store intermediate writes linked to a checkpoint''' + await self._engine._run_as_async(self.__checkpoint.aput_writes(config, writes, task_id)) + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str + ) -> None: + '''Store intermediate writes linked to a checkpoint''' + self._engine._run_as_sync(self.__checkpoint.aput_writes(config, writes, task_id)) + \ No newline at end of file diff --git a/src/langgraph_google_alloydb_pg/engine.py b/src/langgraph_google_alloydb_pg/engine.py new file mode 100644 index 00000000..2e93f37f --- /dev/null +++ b/src/langgraph_google_alloydb_pg/engine.py @@ -0,0 +1,530 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/langgraph_google_alloydb_pg/engine.py b/src/langgraph_google_alloydb_pg/engine.py
new file mode 100644
index 00000000..2e93f37f
--- /dev/null
+++ b/src/langgraph_google_alloydb_pg/engine.py
@@ -0,0 +1,530 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+from concurrent.futures import Future
+from dataclasses import dataclass
+from threading import Thread
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Dict,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
+
+import aiohttp
+import google.auth  # type: ignore
+import google.auth.transport.requests  # type: ignore
+from google.cloud.alloydb.connector import AsyncConnector, IPTypes, RefreshStrategy
+from sqlalchemy import MetaData, Table, text
+from sqlalchemy.engine import URL
+from sqlalchemy.exc import InvalidRequestError
+from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+
+from .version import __version__
+
+if TYPE_CHECKING:
+    import asyncpg  # type: ignore
+    import google.auth.credentials  # type: ignore
+
+T = TypeVar("T")
+
+USER_AGENT = "langgraph_google_alloydb_pg/" + __version__
+
+
+async def _get_iam_principal_email(
+    credentials: google.auth.credentials.Credentials,
+) -> str:
+    """Get the email address associated with the current authenticated IAM principal.
+
+    The email will be used for automatic IAM database authentication to AlloyDB.
+
+    Args:
+        credentials (google.auth.credentials.Credentials):
+            The credentials object to use in finding the associated IAM
+            principal email address.
+
+    Returns:
+        email (str):
+            The email address associated with the current authenticated IAM
+            principal.
+    """
+    # Refresh credentials if they are not valid.
+    if not credentials.valid:
+        request = google.auth.transport.requests.Request()
+        credentials.refresh(request)
+    if hasattr(credentials, "_service_account_email"):
+        return credentials._service_account_email.replace(".gserviceaccount.com", "")
+    # Call the OAuth2 API to get the IAM principal email associated with the
+    # OAuth2 token.
+    url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
+    async with aiohttp.ClientSession() as client:
+        response = await client.get(url, raise_for_status=True)
+        response_json: Dict = await response.json()
+        email = response_json.get("email")
+        if email is None:
+            raise ValueError(
+                "Failed to automatically obtain authenticated IAM principal's "
+                "email address using environment's ADC credentials!"
+            )
+        return email.replace(".gserviceaccount.com", "")
+
+
+@dataclass
+class Column:
+    name: str
+    data_type: str
+    nullable: bool = True
+
+    def __post_init__(self) -> None:
+        """Check if initialization parameters are valid.
+
+        Raises:
+            ValueError: If the column name is not a string.
+            ValueError: If data_type is not a string.
+        """
+        if not isinstance(self.name, str):
+            raise ValueError("Column name must be type string")
+        if not isinstance(self.data_type, str):
+            raise ValueError("Column data_type must be type string")
+
+
+class AlloyDBEngine:
+    """A class for managing connections to an AlloyDB database."""
+
+    _connector: Optional[AsyncConnector] = None
+    _default_loop: Optional[asyncio.AbstractEventLoop] = None
+    _default_thread: Optional[Thread] = None
+    __create_key = object()
+
+    def __init__(
+        self,
+        key: object,
+        pool: AsyncEngine,
+        loop: Optional[asyncio.AbstractEventLoop],
+        thread: Optional[Thread],
+    ) -> None:
+        """AlloyDBEngine constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            pool (AsyncEngine): Async engine connection pool.
+            loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
+            thread (Optional[Thread]): Thread used to create the engine async.
+
+        Raises:
+            Exception: If the constructor is called directly by the user.
+ """ + + if key != AlloyDBEngine.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self._pool = pool + self._loop = loop + self._thread = thread + + @classmethod + def __start_background_loop( + cls, + project_id: str, + region: str, + cluster: str, + instance: str, + database: str, + user: Optional[str] = None, + password: Optional[str] = None, + ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, + iam_account_email: Optional[str] = None, + ) -> Future: + # Running a loop in a background thread allows us to support + # async methods from non-async environments + if cls._default_loop is None: + cls._default_loop = asyncio.new_event_loop() + cls._default_thread = Thread( + target=cls._default_loop.run_forever, daemon=True + ) + cls._default_thread.start() + coro = cls._create( + project_id, + region, + cluster, + instance, + database, + ip_type, + user, + password, + loop=cls._default_loop, + thread=cls._default_thread, + iam_account_email=iam_account_email, + ) + return asyncio.run_coroutine_threadsafe(coro, cls._default_loop) + + @classmethod + def from_instance( + cls: Type[AlloyDBEngine], + project_id: str, + region: str, + cluster: str, + instance: str, + database: str, + user: Optional[str] = None, + password: Optional[str] = None, + ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, + iam_account_email: Optional[str] = None, + ) -> AlloyDBEngine: + """Create an AlloyDBEngine from an AlloyDB instance. + + Args: + project_id (str): GCP project ID. + region (str): Cloud AlloyDB instance region. + cluster (str): Cloud AlloyDB cluster name. + instance (str): Cloud AlloyDB instance name. + database (str): Database name. + user (Optional[str]): Cloud AlloyDB user name. Defaults to None. + password (Optional[str]): Cloud AlloyDB user password. Defaults to None. + ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. + iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. + + Returns: + AlloyDBEngine: A newly created AlloyDBEngine instance. + """ + future = cls.__start_background_loop( + project_id, + region, + cluster, + instance, + database, + user, + password, + ip_type, + iam_account_email=iam_account_email, + ) + return future.result() + + @classmethod + async def _create( + cls: Type[AlloyDBEngine], + project_id: str, + region: str, + cluster: str, + instance: str, + database: str, + ip_type: Union[str, IPTypes], + user: Optional[str] = None, + password: Optional[str] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + thread: Optional[Thread] = None, + iam_account_email: Optional[str] = None, + ) -> AlloyDBEngine: + """Create an AlloyDBEngine from an AlloyDB instance. + + Args: + project_id (str): GCP project ID. + region (str): Cloud AlloyDB instance region. + cluster (str): Cloud AlloyDB cluster name. + instance (str): Cloud AlloyDB instance name. + database (str): Database name. + ip_type (Union[str, IPTypes]): IP address type. Defaults to IPTypes.PUBLIC. + user (Optional[str]): Cloud AlloyDB user name. Defaults to None. + password (Optional[str]): Cloud AlloyDB user password. Defaults to None. + loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine. + thread (Optional[Thread]): Thread used to create the engine async. + iam_account_email (Optional[str]): IAM service account email. + + Raises: + ValueError: Raises error if only one of 'user' or 'password' is specified. 
+
+        Returns:
+            AlloyDBEngine: A newly created AlloyDBEngine instance.
+        """
+        # Error if only one of user or password is set; must be both or neither.
+        if bool(user) ^ bool(password):
+            raise ValueError(
+                "Only one of 'user' or 'password' was specified. Either "
+                "both should be specified to use basic user/password "
+                "authentication or neither for IAM DB authentication."
+            )
+
+        if cls._connector is None:
+            cls._connector = AsyncConnector(
+                user_agent=USER_AGENT, refresh_strategy=RefreshStrategy.LAZY
+            )
+
+        # If user and password are given, use basic auth.
+        if user and password:
+            enable_iam_auth = False
+            db_user = user
+        # Otherwise use automatic IAM database authentication.
+        else:
+            enable_iam_auth = True
+            if iam_account_email:
+                db_user = iam_account_email
+            else:
+                # Get application default credentials.
+                credentials, _ = google.auth.default(
+                    scopes=["https://www.googleapis.com/auth/userinfo.email"]
+                )
+                db_user = await _get_iam_principal_email(credentials)
+
+        # Connection factory used for SQLAlchemy's 'async_creator' argument.
+        async def getconn() -> asyncpg.Connection:
+            conn = await cls._connector.connect(  # type: ignore
+                f"projects/{project_id}/locations/{region}/clusters/{cluster}/instances/{instance}",
+                "asyncpg",
+                user=db_user,
+                password=password,
+                db=database,
+                enable_iam_auth=enable_iam_auth,
+                ip_type=ip_type,
+            )
+            return conn
+
+        engine = create_async_engine(
+            "postgresql+asyncpg://",
+            async_creator=getconn,
+        )
+        return cls(cls.__create_key, engine, loop, thread)
+
+    @classmethod
+    async def afrom_instance(
+        cls: Type[AlloyDBEngine],
+        project_id: str,
+        region: str,
+        cluster: str,
+        instance: str,
+        database: str,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
+        iam_account_email: Optional[str] = None,
+    ) -> AlloyDBEngine:
+        """Create an AlloyDBEngine from an AlloyDB instance.
+
+        Args:
+            project_id (str): GCP project ID.
+            region (str): Cloud AlloyDB instance region.
+            cluster (str): Cloud AlloyDB cluster name.
+            instance (str): Cloud AlloyDB instance name.
+            database (str): Cloud AlloyDB database name.
+            user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None.
+            password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
+            ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
+            iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
+
+        Returns:
+            AlloyDBEngine: A newly created AlloyDBEngine instance.
+        """
+        future = cls.__start_background_loop(
+            project_id,
+            region,
+            cluster,
+            instance,
+            database,
+            user,
+            password,
+            ip_type,
+            iam_account_email=iam_account_email,
+        )
+        return await asyncio.wrap_future(future)
+
+    @classmethod
+    def from_engine(
+        cls: Type[AlloyDBEngine],
+        engine: AsyncEngine,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> AlloyDBEngine:
+        """Create an AlloyDBEngine instance from an AsyncEngine."""
+        return cls(cls.__create_key, engine, loop, None)
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        url: Union[str, URL],
+        **kwargs: Any,
+    ) -> AlloyDBEngine:
+        """Create an AlloyDBEngine instance from connection arguments.
+
+        Args:
+            url (Union[str, URL]): The URL used to connect to the database; the
+                driver must be 'postgresql+asyncpg'.
+            **kwargs (Any): Additional arguments passed to create_async_engine.
+
+        Raises:
+            ValueError: If the driver in the URL is not 'postgresql+asyncpg'.
+
+        Returns:
+            AlloyDBEngine: A newly created AlloyDBEngine instance.
+        """
+        # Running a loop in a background thread allows us to support
+        # async methods from non-async environments.
+        if cls._default_loop is None:
+            cls._default_loop = asyncio.new_event_loop()
+            cls._default_thread = Thread(
+                target=cls._default_loop.run_forever, daemon=True
+            )
+            cls._default_thread.start()
+
+        driver = "postgresql+asyncpg"
+        if (isinstance(url, str) and not url.startswith(driver)) or (
+            isinstance(url, URL) and url.drivername != driver
+        ):
+            raise ValueError("Driver must be type 'postgresql+asyncpg'")
+
+        engine = create_async_engine(url, **kwargs)
+        return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread)
+
+    async def _run_as_async(self, coro: Awaitable[T]) -> T:
+        """Run an async coroutine asynchronously."""
+        # If a loop has not been provided, attempt to run in the current thread.
+        if not self._loop:
+            return await coro
+        # Otherwise, run in the background thread.
+        return await asyncio.wrap_future(
+            asyncio.run_coroutine_threadsafe(coro, self._loop)
+        )
+
+    def _run_as_sync(self, coro: Awaitable[T]) -> T:
+        """Run an async coroutine synchronously."""
+        if not self._loop:
+            raise Exception(
+                "Engine was initialized without a background loop and cannot call sync methods."
+            )
+        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
+
+    async def close(self) -> None:
+        """Dispose of the connection pool."""
+        await self._pool.dispose()
+
+    async def _ainit_checkpoint_table(self, schema_name: str = "public") -> None:
+        """Create AlloyDB tables to save checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store the checkpoint tables.
+                Default: "public".
+
+        Returns:
+            None
+        """
+        create_checkpoints_table = f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoints(
+            thread_id TEXT NOT NULL,
+            checkpoint_ns TEXT NOT NULL DEFAULT '',
+            checkpoint_id TEXT NOT NULL,
+            parent_checkpoint_id TEXT,
+            v INTEGER NOT NULL,
+            checkpoint JSONB NOT NULL,
+            metadata JSONB NOT NULL DEFAULT '{{}}',
+            channel TEXT NOT NULL,
+            version TEXT NOT NULL,
+            type TEXT,
+            blob BYTEA,
+            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
+        );"""
+
+        create_checkpoint_writes_table = f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_writes (
+            thread_id TEXT NOT NULL,
+            checkpoint_ns TEXT NOT NULL DEFAULT '',
+            checkpoint_id TEXT NOT NULL,
+            task_id TEXT NOT NULL,
+            idx INTEGER NOT NULL,
+            channel TEXT NOT NULL,
+            type TEXT,
+            blob BYTEA NOT NULL,
+            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
+        );"""
+
+        async with self._pool.connect() as conn:
+            await conn.execute(text(create_checkpoints_table))
+            await conn.execute(text(create_checkpoint_writes_table))
+            await conn.commit()
+
+    async def ainit_checkpoint_table(self, schema_name: str = "public") -> None:
+        """Create AlloyDB tables to store checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store checkpoint tables.
+                Default: "public".
+
+        Returns:
+            None
+        """
+        await self._run_as_async(self._ainit_checkpoint_table(schema_name))
+
+    def init_checkpoint_table(self, schema_name: str = "public") -> None:
+        """Create AlloyDB tables to store checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store checkpoint tables.
+                Default: "public".
+
+        Returns:
+            None
+        """
+        self._run_as_sync(self._ainit_checkpoint_table(schema_name))
+
+    async def _aload_table_schema(
+        self, table_name: str, schema_name: str = "public"
+    ) -> Table:
+        """Load the table schema from an existing table in a PostgreSQL database.
+
+        Args:
+            table_name: The name of the table to load the table schema from.
+            schema_name: The name of the database schema where the table resides.
+                Default: "public".
+
+        Returns:
+            (sqlalchemy.Table): The loaded table, including its table schema information.
+        """
+        metadata = MetaData()
+        async with self._pool.connect() as conn:
+            try:
+                await conn.run_sync(
+                    metadata.reflect, schema=schema_name, only=[table_name]
+                )
+            except InvalidRequestError as e:
+                raise ValueError(
+                    f"Table, '{schema_name}'.'{table_name}', does not exist: " + str(e)
+                )
+
+        return metadata.tables[f"{schema_name}.{table_name}"]
diff --git a/src/langgraph_google_alloydb_pg/py.typed b/src/langgraph_google_alloydb_pg/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/src/langgraph_google_alloydb_pg/version.py b/src/langgraph_google_alloydb_pg/version.py
new file mode 100644
index 00000000..c1c8212d
--- /dev/null
+++ b/src/langgraph_google_alloydb_pg/version.py
@@ -0,0 +1,15 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "0.1.0"
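Usage sketch (not part of the patch): constructing the engine from a plain DSN via `from_engine_args`, which requires the asyncpg driver prefix shown in the validation above. The DSN is a placeholder.

```python
from langgraph_google_alloydb_pg import AlloyDBEngine

# Placeholder DSN; the driver prefix must be 'postgresql+asyncpg'.
engine = AlloyDBEngine.from_engine_args(
    "postgresql+asyncpg://user:password@127.0.0.1:5432/my-database"
)
engine.init_checkpoint_table()  # runs on the engine's background loop
```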
diff --git a/tests/test_async_checkpoint.py b/tests/test_async_checkpoint.py
new file mode 100644
index 00000000..538394f5
--- /dev/null
+++ b/tests/test_async_checkpoint.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest_asyncio
+from sqlalchemy import text
+
+from langgraph_google_alloydb_pg import AlloyDBEngine
+from langgraph_google_alloydb_pg.async_checkpoint import AsyncAlloyDBSaver
+
+project_id = os.environ["PROJECT_ID"]
+region = os.environ["REGION"]
+cluster_id = os.environ["CLUSTER_ID"]
+instance_id = os.environ["INSTANCE_ID"]
+db_name = os.environ["DATABASE_ID"]
+
+
+async def aexecute(engine: AlloyDBEngine, query: str) -> None:
+    async with engine._pool.connect() as conn:
+        await conn.execute(text(query))
+        await conn.commit()
+
+
+@pytest_asyncio.fixture
+async def async_engine():
+    async_engine = await AlloyDBEngine.afrom_instance(
+        project_id=project_id,
+        region=region,
+        cluster=cluster_id,
+        instance=instance_id,
+        database=db_name,
+    )
+    await async_engine._ainit_checkpoint_table()
+    yield async_engine
+    checkpoints_query = "DROP TABLE IF EXISTS checkpoints"
+    await aexecute(async_engine, checkpoints_query)
+    checkpoint_writes_query = "DROP TABLE IF EXISTS checkpoint_writes"
+    await aexecute(async_engine, checkpoint_writes_query)
+    await async_engine.close()
+
diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py
new file mode 100644
index 00000000..c38dc3b1
--- /dev/null
+++ b/tests/test_checkpoint.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
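Test sketch (not part of the patch): the fixture above creates and drops the checkpoint tables but the diff ships no test bodies yet. A hypothetical smoke test one could add under the same fixture, assuming pytest-asyncio's `asyncio` marker is enabled:

```python
import pytest


@pytest.mark.asyncio
async def test_checkpoint_tables_exist(async_engine):
    # Reflection succeeds only if the fixture created both tables.
    await async_engine._aload_table_schema("checkpoints")
    await async_engine._aload_table_schema("checkpoint_writes")
```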