feat: Add Checkpointer base #282
base: langgraph
@@ -0,0 +1,23 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .engine import AlloyDBEngine
from .checkpoint import AlloyDBSaver
from .version import __version__

__all__ = [
    "AlloyDBEngine",
    "AlloyDBSaver",
    "__version__",
]
@@ -0,0 +1,352 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple

import asyncpg  # type: ignore
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine

from langchain_core.runnables import RunnableConfig

from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    get_checkpoint_id,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol

from .engine import AlloyDBEngine

MetadataInput = Optional[dict[str, Any]]

class AsyncAlloyDBSaver(BaseCheckpointSaver[str]):
    """Checkpoint saver that stores checkpoints in an AlloyDB for PostgreSQL database."""

    __create_key = object()

    def __init__(
        self,
        key: object,
        pool: AsyncEngine,
        schema_name: str = "public",
        serde: Optional[SerializerProtocol] = None,
    ) -> None:
        super().__init__(serde=serde)
        if key != AsyncAlloyDBSaver.__create_key:
            raise Exception(
                "Only create class through 'create' or 'create_sync' methods."
            )
        self.pool = pool
        self.schema_name = schema_name
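
Note for readers: the `__create_key` sentinel makes `__init__` effectively private, so a saver can only be obtained through the factory methods. A minimal illustration of the guard (a sketch, not part of this diff):

    # AsyncAlloyDBSaver(object(), pool)                # raises: caller cannot supply the private key
    # saver = await AsyncAlloyDBSaver.create(engine)   # supported path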

    @classmethod
    async def create(
        cls,
        engine: AlloyDBEngine,
        schema_name: str = "public",
        serde: Optional[SerializerProtocol] = None,
    ) -> "AsyncAlloyDBSaver":
        """Create a new AsyncAlloyDBSaver instance.

        Args:
            engine (AlloyDBEngine): AlloyDB engine to use.
            schema_name (str): The schema name where the tables are located (default: "public").
            serde (Optional[SerializerProtocol]): Serializer for encoding/decoding checkpoints (default: None).

        Raises:
            IndexError: If the tables provided do not contain the required schema.

        Returns:
            AsyncAlloyDBSaver: A newly created instance of AsyncAlloyDBSaver.
        """
        checkpoints_table_schema = await engine._aload_table_schema(
            "checkpoints", schema_name
        )
        checkpoints_column_names = checkpoints_table_schema.columns.keys()

        checkpoints_required_columns = [
            "thread_id",
            "checkpoint_ns",
            "checkpoint_id",
            "parent_checkpoint_id",
            "v",
            "type",
            "checkpoint",
            "metadata",
        ]

        if not all(x in checkpoints_column_names for x in checkpoints_required_columns):
            raise IndexError(
                f"Table '{schema_name}.checkpoints' has incorrect schema. Got "
                f"column names '{checkpoints_column_names}' but required column names "
                f"'{checkpoints_required_columns}'.\nPlease create the table with the following schema:"
                f"\nCREATE TABLE {schema_name}.checkpoints ("
                "\n    thread_id TEXT NOT NULL,"
                "\n    checkpoint_ns TEXT NOT NULL,"
                "\n    checkpoint_id UUID NOT NULL,"
                "\n    parent_checkpoint_id UUID,"
                "\n    v INT NOT NULL,"
                "\n    type TEXT NOT NULL,"
                "\n    checkpoint JSONB NOT NULL,"
                "\n    metadata JSONB"
                "\n);"
            )

        checkpoint_writes_table_schema = await engine._aload_table_schema(
            "checkpoint_writes", schema_name
        )
        checkpoint_writes_column_names = checkpoint_writes_table_schema.columns.keys()

        checkpoint_writes_columns = [
            "thread_id",
            "checkpoint_ns",
            "checkpoint_id",
            "task_id",
            "idx",
            "channel",
            "type",
            "blob",
        ]

        if not all(x in checkpoint_writes_column_names for x in checkpoint_writes_columns):
            raise IndexError(
                f"Table '{schema_name}.checkpoint_writes' has incorrect schema. Got "
                f"column names '{checkpoint_writes_column_names}' but required column names "
                f"'{checkpoint_writes_columns}'.\nPlease create the table with the following schema:"
                f"\nCREATE TABLE {schema_name}.checkpoint_writes ("
                "\n    thread_id TEXT NOT NULL,"
                "\n    checkpoint_ns TEXT NOT NULL,"
                "\n    checkpoint_id UUID NOT NULL,"
                "\n    task_id UUID NOT NULL,"
                "\n    idx INT NOT NULL,"
                "\n    channel TEXT NOT NULL,"
                "\n    type TEXT NOT NULL,"
                "\n    blob JSONB NOT NULL"
                "\n);"
            )
        return cls(cls.__create_key, engine._pool, schema_name, serde)
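
For orientation, a hedged usage sketch of the factory. How an AlloyDBEngine is constructed is not shown in this diff, so the engine setup below is a placeholder assumption; only the `create` call reflects the code above:

    import asyncio

    async def main() -> None:
        engine = ...  # an AlloyDBEngine wired to your database (setup not shown in this PR)
        # Requires the checkpoints and checkpoint_writes tables described above;
        # otherwise create() raises IndexError carrying the expected DDL.
        saver = await AsyncAlloyDBSaver.create(engine, schema_name="public")

    asyncio.run(main())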

    async def alist(
        self,
        config: Optional[RunnableConfig],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Asynchronously list checkpoints that match the given criteria.

        Args:
            config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
            filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
            before (Optional[RunnableConfig]): List checkpoints created before this configuration.
            limit (Optional[int]): Maximum number of checkpoints to return.

        Returns:
            AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError
        yield
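
Note that the unreachable `yield` after the `raise` is deliberate: it turns `alist` into an async generator, so calling it returns an iterator and the `NotImplementedError` surfaces on the first iteration step rather than the call handing back a bare coroutine. In sketch form:

    # async for checkpoint_tuple in saver.alist(None):   # raises NotImplementedError here
    #     ...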

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Asynchronously fetch a checkpoint tuple using the given configuration.

        Args:
            config (RunnableConfig): Configuration specifying which checkpoint to retrieve.

        Returns:
            Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError
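
One plausible query shape for a future implementation, sketched from the checkpoints DDL that `create` validates; this is an assumption, not part of this PR, and it presumes checkpoint ids sort by creation time:

    # Hypothetical helper, not the committed implementation.
    async def _fetch_latest_checkpoint_row(conn, schema_name, thread_id, checkpoint_ns):
        query = f'''SELECT thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id,
                           type, checkpoint, metadata
                    FROM "{schema_name}".checkpoints
                    WHERE thread_id = :thread_id AND checkpoint_ns = :checkpoint_ns
                    ORDER BY checkpoint_id DESC
                    LIMIT 1'''
        result = await conn.execute(
            text(query),
            {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns},
        )
        return result.fetchone()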

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Asynchronously store a checkpoint with its configuration and metadata.

        Args:
            config (RunnableConfig): Configuration for the checkpoint.
            checkpoint (Checkpoint): The checkpoint to store.
            metadata (CheckpointMetadata): Additional metadata for the checkpoint.
            new_versions (ChannelVersions): New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.
        """
        configurable = config["configurable"].copy()
        thread_id = configurable.pop("thread_id")
        checkpoint_ns = configurable.pop("checkpoint_ns")
        checkpoint_id = configurable.pop(
            "checkpoint_id", configurable.pop("thread_ts", None)
        )

        copy = checkpoint.copy()
        next_config: RunnableConfig = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint["id"],
            }
        }

        query = f"""INSERT INTO "{self.schema_name}".checkpoints(thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata, channel, version, type, blob)
                    VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :checkpoint, :metadata, :channel, :version, :type, :blob);
                """

        async with self.pool.connect() as conn:
            await conn.execute(
                text(query),
                {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                    "parent_checkpoint_id": config.get("checkpoint_id"),

Review comment: I need to check on this value.

                    "checkpoint": json.dumps(copy),
                    "metadata": json.dumps(dict(metadata)),
                    "channel": copy.pop("channel_values"),
                    "version": new_versions,

Review comment: I need to ask how new_versions is handled.

                    "type": next_config["configurable"]["type"],
                    "blob": json.dumps(next_config["configurable"]["blob"]),

Review comment on lines +235 to +236: I don't think we need type and blob here since that is captured in the writes.

                },
            )
            await conn.commit()

        return next_config

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Asynchronously store intermediate writes linked to a checkpoint.

        Args:
            config (RunnableConfig): Configuration of the related checkpoint.
            writes (Sequence[Tuple[str, Any]]): List of writes to store.
            task_id (str): Identifier for the task creating the writes.
        """
        query = f"""INSERT INTO "{self.schema_name}".checkpoint_writes(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
                    VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob)
                """
        upsert = ""
        async with self.pool.connect() as conn:
            await conn.execute(
                text(query),
                {
                    "thread_id": config["configurable"]["thread_id"],
                    "checkpoint_ns": config["configurable"]["checkpoint_ns"],
                    "checkpoint_id": config["configurable"]["checkpoint_id"],
                    "task_id": task_id,
                    "idx": idx,

Review comment: idx is converted from writes[0] with the IDX map; see https://github.com/langchain-ai/langgraph/blob/506539ac9df277a113423d40999518a0ad2bdb3f/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py#L299

                    "channel": write[0],
                    "type": write[1],
                    "blob": json.dumps(write[2]),

Review comment on lines +273 to +275: I believe this is right, but note to self to check closer.

                },
            )
            await conn.commit()
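
Following the reviewer's pointer above, a hedged sketch of how the linked checkpoint-postgres saver derives these parameters: it loops over `writes`, takes `idx` from WRITES_IDX_MAP keyed by channel name (falling back to the list position), and serializes each value with the saver's serde. An adaptation under those assumptions, not the committed code:

    # Hypothetical loop, adapted from the linked implementation; the blob
    # serialization may still need adjusting for this table's JSONB column.
    for idx, (channel, value) in enumerate(writes):
        type_, blob = self.serde.dumps_typed(value)
        params = {
            "thread_id": config["configurable"]["thread_id"],
            "checkpoint_ns": config["configurable"]["checkpoint_ns"],
            "checkpoint_id": config["configurable"]["checkpoint_id"],
            "task_id": task_id,
            # Special channels map to fixed indices; others keep list order.
            "idx": WRITES_IDX_MAP.get(channel, idx),
            "channel": channel,
            "type": type_,
            "blob": blob,
        }
        await conn.execute(text(query), params)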

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
        )

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Fetch a checkpoint tuple using the given configuration.

        Args:
            config (RunnableConfig): Configuration specifying which checkpoint to retrieve.

        Returns:
            Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.

        Raises:
            NotImplementedError: Sync methods are not supported by AsyncAlloyDBSaver.
        """
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
        )

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint with its configuration and metadata.

        Args:
            config (RunnableConfig): Configuration for the checkpoint.
            checkpoint (Checkpoint): The checkpoint to store.
            metadata (CheckpointMetadata): Additional metadata for the checkpoint.
            new_versions (ChannelVersions): New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.

        Raises:
            NotImplementedError: This method is implemented in AsyncAlloyDBSaver as aput.
        """
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
        )

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Store intermediate writes linked to a checkpoint.

        Args:
            config (RunnableConfig): Configuration of the related checkpoint.
            writes (Sequence[Tuple[str, Any]]): List of writes to store.
            task_id (str): Identifier for the task creating the writes.

        Raises:
            NotImplementedError: This method is implemented in AsyncAlloyDBSaver as aput_writes.
        """
        raise NotImplementedError(
            "Sync methods are not implemented for AsyncAlloyDBSaver. Use the AlloyDBSaver interface instead."
        )

Review comment: Let's use a variable for the table names "checkpoints" and "checkpoint_writes" in the engine class and import them here, so we make sure they stay consistent.
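
A sketch of what that suggestion could look like; all names here are assumptions, not existing code:

    # engine.py (hypothetical constant names):
    CHECKPOINTS_TABLE = "checkpoints"
    CHECKPOINT_WRITES_TABLE = "checkpoint_writes"

    # The checkpoint module would then import the constants so the schema checks
    # in create() and the INSERT statements cannot drift apart:
    #     from .engine import CHECKPOINTS_TABLE, CHECKPOINT_WRITES_TABLE
    #     await engine._aload_table_schema(CHECKPOINTS_TABLE, schema_name)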