From 142046c085d70a782f69aa048006a9d8ea6272a7 Mon Sep 17 00:00:00 2001 From: Fernando Omar Salazar Ortiz Date: Thu, 5 Dec 2024 02:39:14 -0600 Subject: [PATCH 1/3] fix(checkpoint): Address PR comments from https://github.com/googleapis/langchain-google-cloud-sql-pg-python/pull/235#pullrequestreview-2480107632 --- src/langgraph_google_alloydb_pg/__init__.py | 23 + .../async_checkpoint.py | 109 ++++ src/langgraph_google_alloydb_pg/checkpoint.py | 174 ++++++ src/langgraph_google_alloydb_pg/engine.py | 521 ++++++++++++++++++ src/langgraph_google_alloydb_pg/py.typed | 0 src/langgraph_google_alloydb_pg/version.py | 15 + tests/test_async_checkpoint.py | 63 +++ tests/test_checkpoint.py | 14 + 8 files changed, 919 insertions(+) create mode 100644 src/langgraph_google_alloydb_pg/__init__.py create mode 100644 src/langgraph_google_alloydb_pg/async_checkpoint.py create mode 100644 src/langgraph_google_alloydb_pg/checkpoint.py create mode 100644 src/langgraph_google_alloydb_pg/engine.py create mode 100644 src/langgraph_google_alloydb_pg/py.typed create mode 100644 src/langgraph_google_alloydb_pg/version.py create mode 100644 tests/test_async_checkpoint.py create mode 100644 tests/test_checkpoint.py diff --git a/src/langgraph_google_alloydb_pg/__init__.py b/src/langgraph_google_alloydb_pg/__init__.py new file mode 100644 index 00000000..1306df4d --- /dev/null +++ b/src/langgraph_google_alloydb_pg/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .checkpoint import AlloyDBSaver +from .async_checkpoint import AsyncAlloyDBSaver +from .version import __version__ + +__all__ = [ + "AlloyDBSaver", + "AsyncAlloyDBSaver", + "__version__", +] \ No newline at end of file diff --git a/src/langgraph_google_alloydb_pg/async_checkpoint.py b/src/langgraph_google_alloydb_pg/async_checkpoint.py new file mode 100644 index 00000000..c0515a83 --- /dev/null +++ b/src/langgraph_google_alloydb_pg/async_checkpoint.py @@ -0,0 +1,109 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +from contextlib import asynccontextmanager + +import json +from typing import List, Sequence, Any, AsyncIterator, Iterator, Optional + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from langchain_core.runnables import RunnableConfig + +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple +) +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer +from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol + +from langgraph.checkpoint.serde.base import SerializerProtocol + +MetadataInput = Optional[dict[str, Any]] + +from .engine import AlloyDBEngine + + +class AsyncAlloyDBSaver(BaseCheckpointSaver[str]): + """Checkpoint stored in an AlloyDB for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + pool: AsyncEngine, + serde: Optional[SerializerProtocol] = None + ) -> None: + super().__init__(serde=serde) + if key != AsyncAlloyDBSaver.__create_key: + raise Exception( + "only create class through 'create' or 'create_sync' methods" + ) + self.pool = pool + self.serde = serde + + @classmethod + async def create( + cls, + engine: AlloyDBEngine, + serde: Optional[SerializerProtocol] = None + ) -> "AsyncAlloyDBSaver": + pass + + + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + pass + + async def aget_tuple(self): + pass + + async def aput(self): + pass + + async def aput_writes(self): + pass + + def list(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead." + ) + + def get_tuple(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead." + ) + + def put(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead." + ) + + def put_writes(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead." + ) \ No newline at end of file diff --git a/src/langgraph_google_alloydb_pg/checkpoint.py b/src/langgraph_google_alloydb_pg/checkpoint.py new file mode 100644 index 00000000..7b92a671 --- /dev/null +++ b/src/langgraph_google_alloydb_pg/checkpoint.py @@ -0,0 +1,174 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections.abc import Iterator, Sequence +from typing import Any, Optional + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple +) +from langgraph.checkpoint.serde.base import SerializerProtocol + +from .async_checkpoint import AsyncAlloyDBSaver +from .engine import AlloyDBEngine + +class AlloyDBSaver(BaseCheckpointSaver[str]): + + __create_key = object() + + def __init__( + self, + key: object, + engine: AlloyDBEngine, + checkpoint: AsyncAlloyDBSaver, + serde: Optional[SerializerProtocol] = None + ) -> None: + super().__init__(serde=serde) + if key != AlloyDBSaver.__create_key: + raise Exception( + "only create class through 'create' or 'create_sync' methods" + ) + self._engine = engine + self.__checkpoint = checkpoint + + @classmethod + async def create( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + serde: Optional[SerializerProtocol] = None + ) -> "AlloyDBSaver": + """Create a new AlloyDBSaver instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + session_id (str): Retrieve the table content with this session ID. + table_name (str): Table name that stores the chat message history. + schema_name (str): The schema name where the table is located (default: "public"). + serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + AlloyDBSaver: A newly created instance of AlloyDBSaver. + """ + coro = AsyncAlloyDBSaver.create( + engine, table_name, schema_name, serde + ) + checkpoint = engine._run_as_async(coro) + return cls(cls.__create_key, engine, table_name, schema_name, serde) + + @classmethod + def create_sync( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + serde: Optional[SerializerProtocol] = None + ) -> "AlloyDBSaver": + """Create a new AlloyDBSaver instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + table_name (str): Table name that stores the chat message history. + schema_name (str): The schema name where the table is located (default: "public"). + serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + AlloyDBChatMessageHistory: A newly created instance of AlloyDBSaver. 
+ """ + coro = AsyncAlloyDBSaver.create( + engine, serde + ) + checkpoint = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, table_name, schema_name, serde) + + async def alist(self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None + ) -> Iterator[CheckpointTuple]: + '''List checkpoints from AlloyDB ''' + await self._engine._run_as_async(self.__checkpoint.alist(config, filter, before, limit)) + + def list(self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None + ) -> Iterator[CheckpointTuple]: + '''List checkpoints from AlloyDB ''' + self._engine._run_as_sync(self.__checkpoint.alist(config, filter, before, limit)) + + async def aget_tuple( + self, + config: RunnableConfig + ) -> Optional[CheckpointTuple]: + '''Get a checkpoint tuple from AlloyDB''' + await self._engine._run_as_sync(self.__checkpoint.aget_tuple(config)) + + def get_tuple( + self, + config: RunnableConfig + ) -> Optional[CheckpointTuple]: + '''Get a checkpoint tuple from AlloyDB''' + self._engine._run_as_sync(self.__checkpoint.aget_tuple(config)) + + async def aput(self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions + ) -> RunnableConfig: + '''Save a checkpoint to AlloyDB''' + await self._engine._run_as_async(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) + + def put(self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions + ) -> RunnableConfig: + '''Save a checkpoint to AlloyDB''' + self._engine._run_as_sync(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) + + async def aput_writes(self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str + ) -> None: + '''Store intermediate writes linked to a checkpoint''' + await self._engine._run_as_sync(self.__checkpoint.aput_writes(config, writes, task_id)) + + def put_writes(self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str + ) -> None: + '''Store intermediate writes linked to a checkpoint''' + self._engine._run_as_sync(self.__checkpoint.aput_writes(config, writes, task_id)) + \ No newline at end of file diff --git a/src/langgraph_google_alloydb_pg/engine.py b/src/langgraph_google_alloydb_pg/engine.py new file mode 100644 index 00000000..fa89312d --- /dev/null +++ b/src/langgraph_google_alloydb_pg/engine.py @@ -0,0 +1,521 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import asyncio
+from concurrent.futures import Future
+from dataclasses import dataclass
+from threading import Thread
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Dict,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
+
+import aiohttp
+import google.auth  # type: ignore
+import google.auth.transport.requests  # type: ignore
+from google.cloud.alloydb.connector import AsyncConnector, IPTypes, RefreshStrategy
+from sqlalchemy import MetaData, RowMapping, Table, text
+from sqlalchemy.engine import URL
+from sqlalchemy.exc import InvalidRequestError
+from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+
+from langgraph.checkpoint.postgres.base import (
+    SELECT_SQL,
+    MIGRATIONS,
+    UPSERT_CHECKPOINT_BLOBS_SQL,
+    UPSERT_CHECKPOINTS_SQL,
+    UPSERT_CHECKPOINT_WRITES_SQL,
+    INSERT_CHECKPOINT_WRITES_SQL
+)
+
+from .version import __version__
+
+if TYPE_CHECKING:
+    import asyncpg  # type: ignore
+    import google.auth.credentials  # type: ignore
+
+T = TypeVar("T")
+
+USER_AGENT = "langgraph_google_alloydb_pg/" + __version__
+
+
+async def _get_iam_principal_email(
+    credentials: google.auth.credentials.Credentials,
+) -> str:
+    """Get email address associated with current authenticated IAM principal.
+
+    Email will be used for automatic IAM database authentication to AlloyDB.
+
+    Args:
+        credentials (google.auth.credentials.Credentials):
+            The credentials object to use in finding the associated IAM
+            principal email address.
+
+    Returns:
+        email (str):
+            The email address associated with the current authenticated IAM
+            principal.
+    """
+    # refresh credentials if they are not valid
+    if not credentials.valid:
+        request = google.auth.transport.requests.Request()
+        credentials.refresh(request)
+    if hasattr(credentials, "_service_account_email"):
+        return credentials._service_account_email.replace(".gserviceaccount.com", "")
+    # call OAuth2 api to get IAM principal email associated with OAuth2 token
+    url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
+    async with aiohttp.ClientSession() as client:
+        response = await client.get(url, raise_for_status=True)
+        response_json: Dict = await response.json()
+        email = response_json.get("email")
+    if email is None:
+        raise ValueError(
+            "Failed to automatically obtain authenticated IAM principal's "
+            "email address using environment's ADC credentials!"
+        )
+    return email.replace(".gserviceaccount.com", "")
+
+
+@dataclass
+class Column:
+    name: str
+    data_type: str
+    nullable: bool = True
+
+    def __post_init__(self) -> None:
+        """Check if initialization parameters are valid.
+
+        Raises:
+            ValueError: If Column name is not string.
+            ValueError: If data_type is not type string.
+        """
+
+        if not isinstance(self.name, str):
+            raise ValueError("Column name must be type string")
+        if not isinstance(self.data_type, str):
+            raise ValueError("Column data_type must be type string")
+
+
+class AlloyDBEngine:
+    """A class for managing connections to an AlloyDB database."""
+
+    _connector: Optional[AsyncConnector] = None
+    _default_loop: Optional[asyncio.AbstractEventLoop] = None
+    _default_thread: Optional[Thread] = None
+    __create_key = object()
+
+    def __init__(
+        self,
+        key: object,
+        pool: AsyncEngine,
+        loop: Optional[asyncio.AbstractEventLoop],
+        thread: Optional[Thread],
+    ) -> None:
+        """AlloyDBEngine constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            pool (AsyncEngine): Async engine connection pool.
+ loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine. + thread (Optional[Thread]): Thread used to create the engine async. + + Raises: + Exception: If the constructor is called directly by the user. + """ + + if key != AlloyDBEngine.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self._pool = pool + self._loop = loop + self._thread = thread + + @classmethod + def __start_background_loop( + cls, + project_id: str, + region: str, + cluster: str, + instance: str, + database: str, + user: Optional[str] = None, + password: Optional[str] = None, + ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, + iam_account_email: Optional[str] = None, + ) -> Future: + # Running a loop in a background thread allows us to support + # async methods from non-async environments + if cls._default_loop is None: + cls._default_loop = asyncio.new_event_loop() + cls._default_thread = Thread( + target=cls._default_loop.run_forever, daemon=True + ) + cls._default_thread.start() + coro = cls._create( + project_id, + region, + cluster, + instance, + database, + ip_type, + user, + password, + loop=cls._default_loop, + thread=cls._default_thread, + iam_account_email=iam_account_email, + ) + return asyncio.run_coroutine_threadsafe(coro, cls._default_loop) + + @classmethod + def from_instance( + cls: Type[AlloyDBEngine], + project_id: str, + region: str, + cluster: str, + instance: str, + database: str, + user: Optional[str] = None, + password: Optional[str] = None, + ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, + iam_account_email: Optional[str] = None, + ) -> AlloyDBEngine: + """Create an AlloyDBEngine from an AlloyDB instance. + + Args: + project_id (str): GCP project ID. + region (str): Cloud AlloyDB instance region. + cluster (str): Cloud AlloyDB cluster name. + instance (str): Cloud AlloyDB instance name. + database (str): Database name. + user (Optional[str]): Cloud AlloyDB user name. Defaults to None. + password (Optional[str]): Cloud AlloyDB user password. Defaults to None. + ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. + iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. + + Returns: + AlloyDBEngine: A newly created AlloyDBEngine instance. + """ + future = cls.__start_background_loop( + project_id, + region, + cluster, + instance, + database, + user, + password, + ip_type, + iam_account_email=iam_account_email, + ) + return future.result() + + @classmethod + async def _create( + cls: Type[AlloyDBEngine], + project_id: str, + region: str, + cluster: str, + instance: str, + database: str, + ip_type: Union[str, IPTypes], + user: Optional[str] = None, + password: Optional[str] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + thread: Optional[Thread] = None, + iam_account_email: Optional[str] = None, + ) -> AlloyDBEngine: + """Create an AlloyDBEngine from an AlloyDB instance. + + Args: + project_id (str): GCP project ID. + region (str): Cloud AlloyDB instance region. + cluster (str): Cloud AlloyDB cluster name. + instance (str): Cloud AlloyDB instance name. + database (str): Database name. + ip_type (Union[str, IPTypes]): IP address type. Defaults to IPTypes.PUBLIC. + user (Optional[str]): Cloud AlloyDB user name. Defaults to None. + password (Optional[str]): Cloud AlloyDB user password. Defaults to None. + loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine. 
+ thread (Optional[Thread]): Thread used to create the engine async. + iam_account_email (Optional[str]): IAM service account email. + + Raises: + ValueError: Raises error if only one of 'user' or 'password' is specified. + + Returns: + AlloyDBEngine: A newly created AlloyDBEngine instance. + """ + # error if only one of user or password is set, must be both or neither + if bool(user) ^ bool(password): + raise ValueError( + "Only one of 'user' or 'password' were specified. Either " + "both should be specified to use basic user/password " + "authentication or neither for IAM DB authentication." + ) + + if cls._connector is None: + cls._connector = AsyncConnector( + user_agent=USER_AGENT, refresh_strategy=RefreshStrategy.LAZY + ) + + # if user and password are given, use basic auth + if user and password: + enable_iam_auth = False + db_user = user + # otherwise use automatic IAM database authentication + else: + enable_iam_auth = True + if iam_account_email: + db_user = iam_account_email + else: + # get application default credentials + credentials, _ = google.auth.default( + scopes=["https://www.googleapis.com/auth/userinfo.email"] + ) + db_user = await _get_iam_principal_email(credentials) + + # anonymous function to be used for SQLAlchemy 'creator' argument + async def getconn() -> asyncpg.Connection: + conn = await cls._connector.connect( # type: ignore + f"projects/{project_id}/locations/{region}/clusters/{cluster}/instances/{instance}", + "asyncpg", + user=db_user, + password=password, + db=database, + enable_iam_auth=enable_iam_auth, + ip_type=ip_type, + ) + return conn + + engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + return cls(cls.__create_key, engine, loop, thread) + + @classmethod + async def afrom_instance( + cls: Type[AlloyDBEngine], + project_id: str, + region: str, + cluster: str, + instance: str, + database: str, + user: Optional[str] = None, + password: Optional[str] = None, + ip_type: Union[str, IPTypes] = IPTypes.PUBLIC, + iam_account_email: Optional[str] = None, + ) -> AlloyDBEngine: + """Create an AlloyDBEngine from an AlloyDB instance. + + Args: + project_id (str): GCP project ID. + region (str): Cloud AlloyDB instance region. + cluster (str): Cloud AlloyDB cluster name. + instance (str): Cloud AlloyDB instance name. + database (str): Cloud AlloyDB database name. + user (Optional[str], optional): Cloud AlloyDB user name. Defaults to None. + password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None. + ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC. + iam_account_email (Optional[str], optional): IAM service account email. Defaults to None. + + Returns: + AlloyDBEngine: A newly created AlloyDBEngine instance. + """ + future = cls.__start_background_loop( + project_id, + region, + cluster, + instance, + database, + user, + password, + ip_type, + iam_account_email=iam_account_email, + ) + return await asyncio.wrap_future(future) + + @classmethod + def from_engine( + cls: Type[AlloyDBEngine], + engine: AsyncEngine, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> AlloyDBEngine: + """Create an AlloyDBEngine instance from an AsyncEngine.""" + return cls(cls.__create_key, engine, loop, None) + + @classmethod + def from_engine_args( + cls, + url: Union[str | URL], + **kwargs: Any, + ) -> AlloyDBEngine: + """Create an AlloyDBEngine instance from arguments + + Args: + url (Optional[str]): the URL used to connect to a database. Use url or set other arguments. 
+ + Raises: + ValueError: If not all database url arguments are specified + + Returns: + AlloyDBEngine + """ + # Running a loop in a background thread allows us to support + # async methods from non-async environments + if cls._default_loop is None: + cls._default_loop = asyncio.new_event_loop() + cls._default_thread = Thread( + target=cls._default_loop.run_forever, daemon=True + ) + cls._default_thread.start() + + driver = "postgresql+asyncpg" + if (isinstance(url, str) and not url.startswith(driver)) or ( + isinstance(url, URL) and url.drivername != driver + ): + raise ValueError("Driver must be type 'postgresql+asyncpg'") + + engine = create_async_engine(url, **kwargs) + return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread) + + async def _run_as_async(self, coro: Awaitable[T]) -> T: + """Run an async coroutine asynchronously""" + # If a loop has not been provided, attempt to run in current thread + if not self._loop: + return await coro + # Otherwise, run in the background thread + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self._loop) + ) + + def _run_as_sync(self, coro: Awaitable[T]) -> T: + """Run an async coroutine synchronously""" + if not self._loop: + raise Exception( + "Engine was initialized without a background loop and cannot call sync methods." + ) + return asyncio.run_coroutine_threadsafe(coro, self._loop).result() + + async def _ainit_checkpoint_table( + self, schema_name: str = "public" + ) -> None: + """ + Create AlloyDB tables to save checkpoints. + + Args: + schema_name (str): The schema name to store the checkpoint tables. + Default: "public". + + Returns: + None + """ + MIGRATIONS = [ + f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_migrations ( + v INTEGER PRIMARY KEY + );""", + f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoints ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + parent_checkpoint_id TEXT, + type TEXT, + checkpoint JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{{}}', + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) + );""", + f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_blobs ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + channel TEXT NOT NULL, + version TEXT NOT NULL, + type TEXT NOT NULL, + blob BYTEA, + PRIMARY KEY (thread_id, checkpoint_ns, channel, version) + );""", + f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_writes ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + task_id TEXT NOT NULL, + idx INTEGER NOT NULL, + channel TEXT NOT NULL, + type TEXT, + blob BYTEA NOT NULL, + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) + );""", + f"""ALTER TABLE "{schema_name}".checkpoint_blobs ALTER COLUMN blob DROP not null;""", + ] + async with self._pool.connect() as conn: + create_table_query = MIGRATIONS[0] + result = await conn.execute(text(create_table_query)) + row = await result.fetchone( + text("SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1") + ) + if row is None: + version = -1 + else: + version = row["v"] + for v, migration in zip( + range(version + 1, len(MIGRATIONS)), + MIGRATIONS[version + 1:] + ): + await conn.execute(text(migration)) + query = f"INSERT INTO checkpoint_migrations (v) VALUES ({v})" + await conn.execute(text(query)) + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.commit() + + async def 
ainit_checkpoint_table(
+        self, schema_name: str = "public"
+    ) -> None:
+        """Create AlloyDB tables to store checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store checkpoint tables.
+                Default: "public".
+
+        Returns:
+            None
+        """
+        await self._run_as_async(
+            self._ainit_checkpoint_table(
+                schema_name,
+            )
+        )
+
+    def init_checkpoint_table(
+        self, schema_name: str = "public"
+    ) -> None:
+        """Create AlloyDB tables to store checkpoints.
+
+        Args:
+            schema_name (str): The schema name to store checkpoint tables.
+                Default: "public".
+
+        Returns:
+            None
+        """
+        self._run_as_sync(self._ainit_checkpoint_table(schema_name))
diff --git a/src/langgraph_google_alloydb_pg/py.typed b/src/langgraph_google_alloydb_pg/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/src/langgraph_google_alloydb_pg/version.py b/src/langgraph_google_alloydb_pg/version.py
new file mode 100644
index 00000000..c1c8212d
--- /dev/null
+++ b/src/langgraph_google_alloydb_pg/version.py
@@ -0,0 +1,15 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "0.1.0"
diff --git a/tests/test_async_checkpoint.py b/tests/test_async_checkpoint.py
new file mode 100644
index 00000000..ecacf327
--- /dev/null
+++ b/tests/test_async_checkpoint.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +import uuid + +import pytest +import pytest_asyncio + +from sqlalchemy import text + + +project_id = os.environ["PROJECT_ID"] +region = os.environ["REGION"] +cluster_id = os.environ["CLUSTER_ID"] +instance_id = os.environ["INSTANCE_ID"] +db_name = os.environ["DATABASE_ID"] +table_name = "message_store" + str(uuid.uuid4()) +table_name_async = "message_store" + str(uuid.uuid4()) + +from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg.async_chat_message_history import ( + AsyncAlloyDBChatMessageHistory, +) + +from ..src.langgraph_google_alloydb_pg.engine import AlloyDBEngine +from ..src.langgraph_google_alloydb_pg.checkpoint import AsyncAlloyDBSaver + + +@pytest_asyncio.fixture +async def async_engine(): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=project_id, + region=region, + instance=instance_id, + database=db_name, + ) + + await async_engine.setup() + await async_engine.close() + +@pytest.mark.asyncio +async def test_alloydb_checkpoint_async( + async_engine: AlloyDBEngine +) -> None: + pass + +@pytest.mark.asyncio +async def test_alloydb_checkpoint_sync( + async_engine: AlloyDBEngine +) -> None: + pass \ No newline at end of file diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py new file mode 100644 index 00000000..c38dc3b1 --- /dev/null +++ b/tests/test_checkpoint.py @@ -0,0 +1,14 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + From 2dfeff6070c038a1fec3400abfdce9c546de0e65 Mon Sep 17 00:00:00 2001 From: Fernando Omar Salazar Ortiz Date: Fri, 6 Dec 2024 13:17:17 -0600 Subject: [PATCH 2/3] fix: address comments --- src/langgraph_google_alloydb_pg/__init__.py | 2 -- src/langgraph_google_alloydb_pg/async_checkpoint.py | 1 - src/langgraph_google_alloydb_pg/checkpoint.py | 6 +++--- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/langgraph_google_alloydb_pg/__init__.py b/src/langgraph_google_alloydb_pg/__init__.py index 1306df4d..839e5027 100644 --- a/src/langgraph_google_alloydb_pg/__init__.py +++ b/src/langgraph_google_alloydb_pg/__init__.py @@ -13,11 +13,9 @@ # limitations under the License. 
 from .checkpoint import AlloyDBSaver
-from .async_checkpoint import AsyncAlloyDBSaver
 from .version import __version__
 
 __all__ = [
     "AlloyDBSaver",
-    "AsyncAlloyDBSaver",
     "__version__",
 ]
\ No newline at end of file
diff --git a/src/langgraph_google_alloydb_pg/async_checkpoint.py b/src/langgraph_google_alloydb_pg/async_checkpoint.py
index c0515a83..9394c7c7 100644
--- a/src/langgraph_google_alloydb_pg/async_checkpoint.py
+++ b/src/langgraph_google_alloydb_pg/async_checkpoint.py
@@ -57,7 +57,6 @@ def __init__(
                 "only create class through 'create' or 'create_sync' methods"
             )
         self.pool = pool
-        self.serde = serde
 
     @classmethod
     async def create(
diff --git a/src/langgraph_google_alloydb_pg/checkpoint.py b/src/langgraph_google_alloydb_pg/checkpoint.py
index 7b92a671..2d1f3d58 100644
--- a/src/langgraph_google_alloydb_pg/checkpoint.py
+++ b/src/langgraph_google_alloydb_pg/checkpoint.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Iterator, Sequence
+from collections.abc import Iterator, AsyncIterator, Sequence
 from typing import Any, Optional
 
 from langchain_core.runnables import RunnableConfig
@@ -96,7 +96,7 @@ def create_sync(
             IndexError: If the table provided does not contain required schema.
 
         Returns:
-            AlloyDBChatMessageHistory: A newly created instance of AlloyDBSaver.
+            AlloyDBSaver: A newly created instance of AlloyDBSaver.
         """
         coro = AsyncAlloyDBSaver.create(
             engine, serde
         )
         checkpoint = engine._run_as_sync(coro)
         return cls(cls.__create_key, engine, table_name, schema_name, serde)
@@ -110,7 +110,7 @@ async def alist(self,
                     filter: Optional[dict[str, Any]] = None,
                     before: Optional[RunnableConfig] = None,
                     limit: Optional[int] = None
-    ) -> Iterator[CheckpointTuple]:
+    ) -> AsyncIterator[CheckpointTuple]:
         '''List checkpoints from AlloyDB '''
         await self._engine._run_as_async(self.__checkpoint.alist(config, filter, before, limit))

From 3d6afa0931018fe80b190f715be9fc23a3af23c5 Mon Sep 17 00:00:00 2001
From: Fernando Omar Salazar Ortiz
Date: Tue, 17 Dec 2024 00:59:17 -0600
Subject: [PATCH 3/3] feat(checkpoint): implement initial table schema

---
 src/langgraph_google_alloydb_pg/__init__.py   |   2 +
 .../async_checkpoint.py                       | 280 ++++++++++++++++--
 src/langgraph_google_alloydb_pg/checkpoint.py |  95 +++---
 src/langgraph_google_alloydb_pg/engine.py     | 139 +++++----
 tests/test_async_checkpoint.py                |  40 +--
 5 files changed, 402 insertions(+), 154 deletions(-)

diff --git a/src/langgraph_google_alloydb_pg/__init__.py b/src/langgraph_google_alloydb_pg/__init__.py
index 839e5027..dd8b05c0 100644
--- a/src/langgraph_google_alloydb_pg/__init__.py
+++ b/src/langgraph_google_alloydb_pg/__init__.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .engine import AlloyDBEngine
 from .checkpoint import AlloyDBSaver
 from .version import __version__
 
 __all__ = [
+    "AlloyDBEngine",
     "AlloyDBSaver",
     "__version__",
 ]
\ No newline at end of file
diff --git a/src/langgraph_google_alloydb_pg/async_checkpoint.py b/src/langgraph_google_alloydb_pg/async_checkpoint.py
index 9394c7c7..54e7f84d 100644
--- a/src/langgraph_google_alloydb_pg/async_checkpoint.py
+++ b/src/langgraph_google_alloydb_pg/async_checkpoint.py
@@ -13,10 +13,12 @@
 # limitations under the License.
import asyncio +import asyncpg # type: ignore + from contextlib import asynccontextmanager import json -from typing import List, Sequence, Any, AsyncIterator, Iterator, Optional +from typing import List, Sequence, Any, AsyncIterator, Iterator, Optional, Dict, Tuple from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncEngine @@ -24,11 +26,13 @@ from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, BaseCheckpointSaver, ChannelVersions, Checkpoint, CheckpointMetadata, - CheckpointTuple + CheckpointTuple, + get_checkpoint_id ) from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol @@ -49,6 +53,7 @@ def __init__( self, key: object, pool: AsyncEngine, + schema_name: str = "public", serde: Optional[SerializerProtocol] = None ) -> None: super().__init__(serde=serde) @@ -57,52 +62,291 @@ def __init__( "only create class through 'create' or 'create_sync' methods" ) self.pool = pool + self.schema_name = schema_name @classmethod async def create( cls, engine: AlloyDBEngine, + schema_name: str = "public", serde: Optional[SerializerProtocol] = None ) -> "AsyncAlloyDBSaver": - pass - - + """Create a new AsyncAlloyDBSaver instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + schema_name (str): The schema name where the table is located (default: "public"). + serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None). + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + AsyncAlloyDBSaver: A newly created instance of AsyncAlloyDBSaver. + """ + + checkpoints_table_schema = await engine._aload_table_schema("checkpoints", schema_name) + checkpoints_column_names = checkpoints_table_schema.columns.keys() + checkpoints_required_columns = ["thread_id", + "checkpoint_ns", + "checkpoint_id", + "parent_checkpoint_id", + "v", + "type", + "checkpoint", + "metadata"] + + if not (all(x in checkpoints_column_names for x in checkpoints_required_columns)): + raise IndexError( + f"Table checkpoints.'{schema_name}' has incorrect schema. Got " + f"column names '{checkpoints_column_names}' but required column names " + f"'{checkpoints_required_columns}'.\nPlease create table with following schema:" + f"\nCREATE TABLE {schema_name}.checkpoints (" + "\n thread_id TEXT NOT NULL," + "\n checkpoint_ns TEXT NOT NULL," + "\n checkpoint_id UUID NOT NULL," + "\n parent_checkpoint_id UUID," + "\n v INT NOT NULL," + "\n type TEXT NOT NULL," + "\n checkpoint JSONB NOT NULL," + "\n metadata JSONB" + "\n);" + ) + + checkpoint_writes_table_schema = await engine._aload_table_schema("checkpoint_writes", schema_name) + checkpoint_writes_column_names = checkpoint_writes_table_schema.columns.keys() + + checkpoint_writes_columns = ["thread_id", + "checkpoint_ns", + "checkpoint_id", + "task_id", + "idx", + "channel", + "type", + "blob"] + + if not (all(x in checkpoint_writes_column_names for x in checkpoint_writes_columns)): + raise IndexError( + f"Table checkpoint_writes.'{schema_name}' has incorrect schema. 
Got " + f"column names '{checkpoint_writes_column_names}' but required column names " + f"'{checkpoint_writes_columns}'.\nPlease create table with following schema:" + f"\nCREATE TABLE {schema_name}.checkpoint_writes (" + "\n thread_id TEXT NOT NULL," + "\n checkpoint_ns TEXT NOT NULL," + "\n checkpoint_id UUID NOT NULL," + "\n task_id UUID NOT NULL," + "\n idx INT NOT NULL," + "\n channel TEXT NOT NULL," + "\n type TEXT NOT NULL," + "\n blob JSONB NOT NULL" + "\n);" + ) + return cls(cls.__create_key, engine._pool, schema_name, serde) + async def alist( self, config: Optional[RunnableConfig], - *, - filter: Optional[dict[str, Any]] = None, + filter: Optional[Dict[str, Any]] = None, before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[CheckpointTuple]: - pass + """Asynchronously list checkpoints that match the given criteria. + + Args: + config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. + before (Optional[RunnableConfig]): List checkpoints created before this configuration. + limit (Optional[int]): Maximum number of checkpoints to return. + + Returns: + AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples. + + Raises: + NotImplementedError: Implement this method in your custom checkpoint saver. + """ + raise NotImplementedError + yield - async def aget_tuple(self): - pass + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Asynchronously fetch a checkpoint tuple using the given configuration. + + Args: + config (RunnableConfig): Configuration specifying which checkpoint to retrieve. + + Returns: + Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found. + + Raises: + NotImplementedError: Implement this method in your custom checkpoint saver. + """ + raise NotImplementedError - async def aput(self): - pass + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a checkpoint with its configuration and metadata. + + Args: + config (RunnableConfig): Configuration for the checkpoint. + checkpoint (Checkpoint): The checkpoint to store. + metadata (CheckpointMetadata): Additional metadata for the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. 
+        """
+        configurable = config["configurable"].copy()
+        thread_id = configurable.pop("thread_id")
+        checkpoint_ns = configurable.pop("checkpoint_ns")
+        # the checkpoint id carried in on the config (if any) identifies the
+        # parent of the checkpoint being written
+        checkpoint_id = configurable.pop(
+            "checkpoint_id", configurable.pop("thread_ts", None)
+        )
+
+        copy = checkpoint.copy()
+        # pop the channel values before dumping the checkpoint body and
+        # serialize them separately into the (type, blob) columns
+        channel_values = copy.pop("channel_values", {})
+        type_, blob = self.serde.dumps_typed(channel_values)
+        next_config: RunnableConfig = {
+            "configurable": {
+                "thread_id": thread_id,
+                "checkpoint_ns": checkpoint_ns,
+                "checkpoint_id": checkpoint["id"],
+            }
+        }
+
+        query = f"""INSERT INTO "{self.schema_name}".checkpoints(thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, v, checkpoint, metadata, channel, version, type, blob)
+                VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :v, :checkpoint, :metadata, :channel, :version, :type, :blob);
+        """
+
+        async with self.pool.connect() as conn:
+            await conn.execute(
+                text(query),
+                {
+                    "thread_id": thread_id,
+                    "checkpoint_ns": checkpoint_ns,
+                    "checkpoint_id": checkpoint["id"],
+                    "parent_checkpoint_id": checkpoint_id,
+                    # "v" is NOT NULL in the checkpoints table schema
+                    "v": checkpoint["v"],
+                    "checkpoint": json.dumps(copy),
+                    "metadata": json.dumps(dict(metadata)),
+                    "channel": json.dumps(sorted(channel_values)),
+                    "version": json.dumps(new_versions),
+                    "type": type_,
+                    "blob": blob,
+                },
+            )
+            await conn.commit()
+
+        return next_config
+
 
-    async def aput_writes(self):
-        pass
+    async def aput_writes(
+        self,
+        config: RunnableConfig,
+        writes: Sequence[Tuple[str, Any]],
+        task_id: str,
+    ) -> None:
+        """Asynchronously store intermediate writes linked to a checkpoint.
+
+        Args:
+            config (RunnableConfig): Configuration of the related checkpoint.
+            writes (List[Tuple[str, Any]]): List of writes to store.
+            task_id (str): Identifier for the task creating the writes.
+        """
+        query = f"""INSERT INTO "{self.schema_name}".checkpoint_writes(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
+            VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob)
+        """
+        async with self.pool.connect() as conn:
+            for idx, (channel, value) in enumerate(writes):
+                # special channels get a stable index from WRITES_IDX_MAP;
+                # other writes keep their position in the batch
+                type_, blob = self.serde.dumps_typed(value)
+                await conn.execute(
+                    text(query),
+                    {
+                        "thread_id": config["configurable"]["thread_id"],
+                        "checkpoint_ns": config["configurable"]["checkpoint_ns"],
+                        "checkpoint_id": config["configurable"]["checkpoint_id"],
+                        "task_id": task_id,
+                        "idx": WRITES_IDX_MAP.get(channel, idx),
+                        "channel": channel,
+                        "type": type_,
+                        "blob": blob,
+                    },
+                )
+            await conn.commit()
+
 
-    def list(self) -> None:
+    def list(
+        self,
+        config: Optional[RunnableConfig],
+        *,
+        filter: Optional[Dict[str, Any]] = None,
+        before: Optional[RunnableConfig] = None,
+        limit: Optional[int] = None,
+    ) -> Iterator[CheckpointTuple]:
         raise NotImplementedError(
             "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
         )
 
-    def get_tuple(self) -> None:
+    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
+        """Fetch a checkpoint tuple using the given configuration.
+
+        Args:
+            config (RunnableConfig): Configuration specifying which checkpoint to retrieve.
+
+        Returns:
+            Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
+
+        Raises:
+            NotImplementedError: Sync methods are not implemented for AsyncAlloyDBSaver; use AlloyDBSaver instead.
+        """
         raise NotImplementedError(
             "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
         )
 
-    def put(self) -> None:
+    def put(
+        self,
+        config: RunnableConfig,
+        checkpoint: Checkpoint,
+        metadata: CheckpointMetadata,
+        new_versions: ChannelVersions,
+    ) -> RunnableConfig:
+        """Store a checkpoint with its configuration and metadata.
+
+        Args:
+            config (RunnableConfig): Configuration for the checkpoint.
+            checkpoint (Checkpoint): The checkpoint to store.
+            metadata (CheckpointMetadata): Additional metadata for the checkpoint.
+            new_versions (ChannelVersions): New channel versions as of this write.
+
+        Returns:
+            RunnableConfig: Updated configuration after storing the checkpoint.
+
+        Raises:
+            NotImplementedError: Sync methods are not implemented for AsyncAlloyDBSaver; use AlloyDBSaver instead.
+        """
         raise NotImplementedError(
             "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
         )
 
-    def put_writes(self) -> None:
+    def put_writes(
+        self,
+        config: RunnableConfig,
+        writes: Sequence[Tuple[str, Any]],
+        task_id: str,
+    ) -> None:
+        """Store intermediate writes linked to a checkpoint.
+
+        Args:
+            config (RunnableConfig): Configuration of the related checkpoint.
+            writes (List[Tuple[str, Any]]): List of writes to store.
+            task_id (str): Identifier for the task creating the writes.
+
+        Raises:
+            NotImplementedError: Sync methods are not implemented for AsyncAlloyDBSaver; use AlloyDBSaver instead.
+        """
         raise NotImplementedError(
             "Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
         )
\ No newline at end of file
diff --git a/src/langgraph_google_alloydb_pg/checkpoint.py b/src/langgraph_google_alloydb_pg/checkpoint.py
index 2d1f3d58..ebe9b7f7 100644
--- a/src/langgraph_google_alloydb_pg/checkpoint.py
+++ b/src/langgraph_google_alloydb_pg/checkpoint.py
@@ -29,6 +29,7 @@ from .engine import AlloyDBEngine
 
 class AlloyDBSaver(BaseCheckpointSaver[str]):
+    """Checkpoint stored in an AlloyDB for PostgreSQL database."""
 
     __create_key = object()
 
@@ -51,7 +52,6 @@ def __init__(
     async def create(
         cls,
         engine: AlloyDBEngine,
-        table_name: str,
         schema_name: str = "public",
         serde: Optional[SerializerProtocol] = None
     ) -> "AlloyDBSaver":
@@ -59,8 +59,6 @@ async def create(
 
         Args:
             engine (AlloyDBEngine): AlloyDB engine to use.
-            session_id (str): Retrieve the table content with this session ID.
-            table_name (str): Table name that stores the chat message history.
             schema_name (str): The schema name where the table is located (default: "public").
             serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).
 
@@ -71,16 +69,15 @@ def create_sync(
             AlloyDBSaver: A newly created instance of AlloyDBSaver.
         """
         coro = AsyncAlloyDBSaver.create(
-            engine, table_name, schema_name, serde
+            engine, schema_name, serde
         )
-        checkpoint = engine._run_as_async(coro)
-        return cls(cls.__create_key, engine, table_name, schema_name, serde)
+        checkpoint = await engine._run_as_async(coro)
+        return cls(cls.__create_key, engine, checkpoint)
 
     @classmethod
     def create_sync(
         cls,
         engine: AlloyDBEngine,
-        table_name: str,
         schema_name: str = "public",
         serde: Optional[SerializerProtocol] = None
     ) -> "AlloyDBSaver":
@@ -88,7 +85,6 @@ def __init__(
 
         Args:
             engine (AlloyDBEngine): AlloyDB engine to use.
-            table_name (str): Table name that stores the chat message history.
             schema_name (str): The schema name where the table is located (default: "public").
             serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).
 
@@ -99,75 +95,80 @@ def create_sync(
             IndexError: If the table provided does not contain required schema.
 
         Returns:
             AlloyDBSaver: A newly created instance of AlloyDBSaver.
""" coro = AsyncAlloyDBSaver.create( - engine, serde + engine, schema_name, serde ) checkpoint = engine._run_as_sync(coro) - return cls(cls.__create_key, engine, table_name, schema_name, serde) + return cls(cls.__create_key, engine, checkpoint) - async def alist(self, - config: Optional[RunnableConfig], - *, - filter: Optional[dict[str, Any]] = None, - before: Optional[RunnableConfig] = None, - limit: Optional[int] = None + + async def alist( + self, + config: Optional[RunnableConfig], + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None ) -> AsyncIterator[CheckpointTuple]: '''List checkpoints from AlloyDB ''' - await self._engine._run_as_async(self.__checkpoint.alist(config, filter, before, limit)) + yield await self._engine._run_as_async(self.__checkpoint.alist(config, filter, before, limit)) - def list(self, - config: Optional[RunnableConfig], - *, - filter: Optional[dict[str, Any]] = None, - before: Optional[RunnableConfig] = None, - limit: Optional[int] = None + def list( + self, + config: Optional[RunnableConfig], + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None ) -> Iterator[CheckpointTuple]: '''List checkpoints from AlloyDB ''' - self._engine._run_as_sync(self.__checkpoint.alist(config, filter, before, limit)) + return self._engine._run_as_sync(self.__checkpoint.alist(config, filter, before, limit)) async def aget_tuple( self, config: RunnableConfig ) -> Optional[CheckpointTuple]: '''Get a checkpoint tuple from AlloyDB''' - await self._engine._run_as_sync(self.__checkpoint.aget_tuple(config)) + return await self._engine._run_as_async(self.__checkpoint.aget_tuple(config)) def get_tuple( self, config: RunnableConfig ) -> Optional[CheckpointTuple]: '''Get a checkpoint tuple from AlloyDB''' - self._engine._run_as_sync(self.__checkpoint.aget_tuple(config)) + return self._engine._run_as_sync(self.__checkpoint.aget_tuple(config)) - async def aput(self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions ) -> RunnableConfig: '''Save a checkpoint to AlloyDB''' - await self._engine._run_as_async(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) + return await self._engine._run_as_async(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) - def put(self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions ) -> RunnableConfig: '''Save a checkpoint to AlloyDB''' - self._engine._run_as_sync(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) + return self._engine._run_as_sync(self.__checkpoint.aput(config, checkpoint, metadata, new_versions)) - async def aput_writes(self, - config: RunnableConfig, - writes: Sequence[tuple[str, Any]], - task_id: str + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str ) -> None: '''Store intermediate writes linked to a checkpoint''' - await self._engine._run_as_sync(self.__checkpoint.aput_writes(config, writes, task_id)) + await self._engine._run_as_async(self.__checkpoint.aput_writes(config, writes, task_id)) - def put_writes(self, 
- config: RunnableConfig, - writes: Sequence[tuple[str, Any]], - task_id: str + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str ) -> None: '''Store intermediate writes linked to a checkpoint''' self._engine._run_as_sync(self.__checkpoint.aput_writes(config, writes, task_id)) diff --git a/src/langgraph_google_alloydb_pg/engine.py b/src/langgraph_google_alloydb_pg/engine.py index fa89312d..2e93f37f 100644 --- a/src/langgraph_google_alloydb_pg/engine.py +++ b/src/langgraph_google_alloydb_pg/engine.py @@ -34,20 +34,10 @@ import google.auth # type: ignore import google.auth.transport.requests # type: ignore from google.cloud.alloydb.connector import AsyncConnector, IPTypes, RefreshStrategy -from sqlalchemy import MetaData, RowMapping, Table, text +from sqlalchemy import MetaData, Table, text from sqlalchemy.engine import URL from sqlalchemy.exc import InvalidRequestError from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine - -from langgraph.checkpoint.postgres.base import ( - SELECT_SQL, - MIGRATIONS, - UPSERT_CHECKPOINT_BLOBS_SQL, - UPSERT_CHECKPOINTS_SQL, - UPSERT_CHECKPOINT_WRITES_SQL, - INSERT_CHECKPOINT_WRITES_SQL -) - from .version import __version__ if TYPE_CHECKING: @@ -417,6 +407,10 @@ def _run_as_sync(self, coro: Awaitable[T]) -> T: "Engine was initialized without a background loop and cannot call sync methods." ) return asyncio.run_coroutine_threadsafe(coro, self._loop).result() + + async def close(self) -> None: + """Dispose of connection pool""" + await self._pool.dispose() async def _ainit_checkpoint_table( self, schema_name: str = "public" @@ -431,61 +425,36 @@ async def _ainit_checkpoint_table( Returns: None """ - MIGRATIONS = [ - f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_migrations ( - v INTEGER PRIMARY KEY - );""", - f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoints ( - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, - parent_checkpoint_id TEXT, - type TEXT, - checkpoint JSONB NOT NULL, - metadata JSONB NOT NULL DEFAULT '{{}}', - PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) - );""", - f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_blobs ( - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - channel TEXT NOT NULL, - version TEXT NOT NULL, - type TEXT NOT NULL, - blob BYTEA, - PRIMARY KEY (thread_id, checkpoint_ns, channel, version) - );""", - f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_writes ( - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, - task_id TEXT NOT NULL, - idx INTEGER NOT NULL, - channel TEXT NOT NULL, - type TEXT, - blob BYTEA NOT NULL, - PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) - );""", - f"""ALTER TABLE "{schema_name}".checkpoint_blobs ALTER COLUMN blob DROP not null;""", - ] - async with self._pool.connect() as conn: - create_table_query = MIGRATIONS[0] - result = await conn.execute(text(create_table_query)) - row = await result.fetchone( - text("SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1") - ) - if row is None: - version = -1 - else: - version = row["v"] - for v, migration in zip( - range(version + 1, len(MIGRATIONS)), - MIGRATIONS[version + 1:] - ): - await conn.execute(text(migration)) - query = f"INSERT INTO checkpoint_migrations (v) VALUES ({v})" - await conn.execute(text(query)) + create_checkpoints_table = f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoints( + 
thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + parent_checkpoint_id TEXT, + v INTEGER NOT NULL, + checkpoint JSONB NOT NULL, + metadata JSONB NOT NULL DEFAULT '{{}}', + channel TEXT NOT NULL, + version TEXT NOT NULL, + type TEXT, + blob BYTEA, + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) + );""" + + create_checkpoint_writes_table = f"""CREATE TABLE IF NOT EXISTS "{schema_name}".checkpoint_writes ( + thread_id TEXT NOT NULL, + checkpoint_ns TEXT NOT NULL DEFAULT '', + checkpoint_id TEXT NOT NULL, + task_id TEXT NOT NULL, + idx INTEGER NOT NULL, + channel TEXT NOT NULL, + type TEXT, + blob BYTEA NOT NULL, + PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) + );""" + async with self._pool.connect() as conn: - await conn.execute(text(create_table_query)) + await conn.execute(text(create_checkpoints_table)) + await conn.execute(text(create_checkpoint_writes_table)) await conn.commit() async def ainit_checkpoint_table( @@ -519,3 +488,43 @@ def init_checkpoint_table( None """ self._run_as_sync(self._ainit_checkpoint_table(schema_name)) + + async def _aload_table_schema( + self, table_name: str, schema_name: str = "public" + ) -> Table: + """ + Load table schema from an existing table in a PgSQL database, potentially from a specific database schema. + + Args: + table_name: The name of the table to load the table schema from. + schema_name: The name of the database schema where the table resides. + Default: "public". + + Returns: + (sqlalchemy.Table): The loaded table, including its table schema information. + """ + metadata = MetaData() + async with self._pool.connect() as conn: + try: + await conn.run_sync( + metadata.reflect, schema=schema_name, only=[table_name] + ) + except InvalidRequestError as e: + raise ValueError( + f"Table, '{schema_name}'.'{table_name}', does not exist: " + str(e) + ) + + table = Table(table_name, metadata, schema=schema_name) + # Extract the schema information + schema = [] + for column in table.columns: + schema.append( + { + "name": column.name, + "type": column.type.python_type, + "max_length": getattr(column.type, "length", None), + "nullable": not column.nullable, + } + ) + + return metadata.tables[f"{schema_name}.{table_name}"] diff --git a/tests/test_async_checkpoint.py b/tests/test_async_checkpoint.py index ecacf327..538394f5 100644 --- a/tests/test_async_checkpoint.py +++ b/tests/test_async_checkpoint.py @@ -20,44 +20,36 @@ from sqlalchemy import text - project_id = os.environ["PROJECT_ID"] region = os.environ["REGION"] cluster_id = os.environ["CLUSTER_ID"] instance_id = os.environ["INSTANCE_ID"] db_name = os.environ["DATABASE_ID"] -table_name = "message_store" + str(uuid.uuid4()) -table_name_async = "message_store" + str(uuid.uuid4()) - -from langchain_google_alloydb_pg import AlloyDBEngine -from langchain_google_alloydb_pg.async_chat_message_history import ( - AsyncAlloyDBChatMessageHistory, -) -from ..src.langgraph_google_alloydb_pg.engine import AlloyDBEngine -from ..src.langgraph_google_alloydb_pg.checkpoint import AsyncAlloyDBSaver +from langgraph_google_alloydb_pg import AlloyDBEngine +from langgraph_google_alloydb_pg.async_checkpoint import AsyncAlloyDBSaver +async def aexecute(engine: AlloyDBEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + @pytest_asyncio.fixture async def async_engine(): async_engine = await AlloyDBEngine.afrom_instance( project_id=project_id, region=region, + 
cluster=cluster_id, instance=instance_id, database=db_name, ) - - await async_engine.setup() + await async_engine._ainit_checkpoint_table() + yield async_engine + checkpoints_query = "DROP TABLE IF EXISTS checkpoints" + await aexecute(async_engine, checkpoints_query) + checkpoint_writes_query = "DROP TABLE IF EXISTS checkpoint_writes" + await aexecute(async_engine, checkpoint_writes_query) await async_engine.close() - -@pytest.mark.asyncio -async def test_alloydb_checkpoint_async( - async_engine: AlloyDBEngine -) -> None: - pass - -@pytest.mark.asyncio -async def test_alloydb_checkpoint_sync( - async_engine: AlloyDBEngine -) -> None: - pass \ No newline at end of file +
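
Usage sketch for the API introduced in this series. This is not part of the patches; the project, region, cluster, instance, and database values are placeholders, and the calls follow the signatures defined above:

    from langgraph_google_alloydb_pg import AlloyDBEngine, AlloyDBSaver

    # Build the engine (this also starts the background event loop) and
    # create the checkpoint tables once per database.
    engine = AlloyDBEngine.from_instance(
        project_id="my-project",    # placeholder
        region="us-central1",       # placeholder
        cluster="my-cluster",       # placeholder
        instance="my-instance",     # placeholder
        database="my-database",     # placeholder
    )
    engine.init_checkpoint_table()

    # AlloyDBSaver.create_sync validates the tables and builds the saver;
    # LangGraph accepts it as a checkpointer when compiling a graph.
    checkpointer = AlloyDBSaver.create_sync(engine)
    # graph = builder.compile(checkpointer=checkpointer)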
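
The sync-over-async design in engine.py (`__start_background_loop`, `_run_as_async`, `_run_as_sync`) reduces to one pattern: a dedicated event loop on a daemon thread, with coroutines submitted through `run_coroutine_threadsafe`. A minimal self-contained sketch of that pattern:

    import asyncio
    from threading import Thread
    from typing import Awaitable, TypeVar

    T = TypeVar("T")

    class LoopBridge:
        """Owns a background event loop so sync code can drive async coroutines."""

        def __init__(self) -> None:
            self._loop = asyncio.new_event_loop()
            # Daemon thread: the loop dies with the process instead of blocking exit.
            Thread(target=self._loop.run_forever, daemon=True).start()

        def run_as_sync(self, coro: Awaitable[T]) -> T:
            # Submit to the background loop and block until the result is ready.
            return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

        async def run_as_async(self, coro: Awaitable[T]) -> T:
            # Submit to the background loop, but await from the caller's loop.
            return await asyncio.wrap_future(
                asyncio.run_coroutine_threadsafe(coro, self._loop)
            )

    async def _demo() -> str:
        await asyncio.sleep(0.01)
        return "done"

    bridge = LoopBridge()
    print(bridge.run_as_sync(_demo()))  # callable from plain sync code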
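
`_aload_table_schema` uses SQLAlchemy reflection to check column names before a saver is constructed. The same check, reduced to a standalone helper (the connection URL is a placeholder):

    import asyncio

    from sqlalchemy import MetaData
    from sqlalchemy.ext.asyncio import create_async_engine

    async def column_names(url: str, table: str, schema: str = "public") -> list[str]:
        engine = create_async_engine(url)
        metadata = MetaData()
        async with engine.connect() as conn:
            # reflect() is synchronous under the hood, so hop over via run_sync
            await conn.run_sync(metadata.reflect, schema=schema, only=[table])
        await engine.dispose()
        return list(metadata.tables[f"{schema}.{table}"].columns.keys())

    # asyncio.run(column_names("postgresql+asyncpg://user:pw@host/db", "checkpoints"))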
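
The checkpoint_writes table stores each value as a (type, blob) pair because LangGraph serializers dump a value to a type tag plus a byte payload, and both halves are needed to decode it again. With the JsonPlusSerializer already imported in async_checkpoint.py:

    from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

    serde = JsonPlusSerializer()

    # dumps_typed returns (type_tag, payload_bytes); loads_typed inverts it.
    type_, blob = serde.dumps_typed({"messages": ["hello", "world"]})
    assert serde.loads_typed((type_, blob)) == {"messages": ["hello", "world"]}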