feat: Add Checkpointer base #282

Open · wants to merge 3 commits into base: langgraph
23 changes: 23 additions & 0 deletions src/langgraph_google_alloydb_pg/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .checkpoint import AlloyDBSaver
from .engine import AlloyDBEngine
from .version import __version__

__all__ = [
"AlloyDBEngine",
"AlloyDBSaver",
"__version__",
]
352 changes: 352 additions & 0 deletions src/langgraph_google_alloydb_pg/async_checkpoint.py
@@ -0,0 +1,352 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple

import asyncpg  # type: ignore
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    get_checkpoint_id,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol

from .engine import AlloyDBEngine

MetadataInput = Optional[Dict[str, Any]]


class AsyncAlloyDBSaver(BaseCheckpointSaver[str]):
"""Checkpoint stored in an AlloyDB for PostgreSQL database."""

__create_key = object()
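    # __create_key above is a sentinel that gates construction: the create()
    # factory passes it back to __init__, so direct instantiation raises.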

def __init__(
self,
key: object,
pool: AsyncEngine,
schema_name: str = "public",
serde: Optional[SerializerProtocol] = None
) -> None:
super().__init__(serde=serde)
if key != AsyncAlloyDBSaver.__create_key:
raise Exception(
"only create class through 'create' or 'create_sync' methods"
)
self.pool = pool
self.schema_name = schema_name

@classmethod
async def create(
cls,
engine: AlloyDBEngine,
schema_name: str = "public",
serde: Optional[SerializerProtocol] = None
) -> "AsyncAlloyDBSaver":
"""Create a new AsyncAlloyDBSaver instance.

Args:
engine (AlloyDBEngine): AlloyDB engine to use.
schema_name (str): The schema name where the table is located (default: "public").
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints (default: None).

Raises:
            IndexError: If an existing table does not match the required schema.

Returns:
AsyncAlloyDBSaver: A newly created instance of AsyncAlloyDBSaver.
"""

checkpoints_table_schema = await engine._aload_table_schema("checkpoints", schema_name)
Collaborator: Let's use a variable for the table names "checkpoints" and "checkpoint_writes" in the engine class and import them here, so we make sure they stay consistent.
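A minimal sketch of that suggestion (the constant names are hypothetical; the PR does not define them yet):

    # in engine.py
    CHECKPOINTS_TABLE = "checkpoints"
    CHECKPOINT_WRITES_TABLE = "checkpoint_writes"

    # in async_checkpoint.py
    from .engine import CHECKPOINTS_TABLE, CHECKPOINT_WRITES_TABLE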

checkpoints_column_names = checkpoints_table_schema.columns.keys()

checkpoints_required_columns = ["thread_id",
"checkpoint_ns",
"checkpoint_id",
"parent_checkpoint_id",
"v",
"type",
"checkpoint",
"metadata"]

if not (all(x in checkpoints_column_names for x in checkpoints_required_columns)):
raise IndexError(
f"Table checkpoints.'{schema_name}' has incorrect schema. Got "
f"column names '{checkpoints_column_names}' but required column names "
f"'{checkpoints_required_columns}'.\nPlease create table with following schema:"
f"\nCREATE TABLE {schema_name}.checkpoints ("
"\n thread_id TEXT NOT NULL,"
"\n checkpoint_ns TEXT NOT NULL,"
"\n checkpoint_id UUID NOT NULL,"
"\n parent_checkpoint_id UUID,"
"\n v INT NOT NULL,"
"\n type TEXT NOT NULL,"
"\n checkpoint JSONB NOT NULL,"
"\n metadata JSONB"
"\n);"
)

checkpoint_writes_table_schema = await engine._aload_table_schema("checkpoint_writes", schema_name)
checkpoint_writes_column_names = checkpoint_writes_table_schema.columns.keys()

checkpoint_writes_columns = ["thread_id",
"checkpoint_ns",
"checkpoint_id",
"task_id",
"idx",
"channel",
"type",
"blob"]

if not (all(x in checkpoint_writes_column_names for x in checkpoint_writes_columns)):
raise IndexError(
f"Table checkpoint_writes.'{schema_name}' has incorrect schema. Got "
f"column names '{checkpoint_writes_column_names}' but required column names "
f"'{checkpoint_writes_columns}'.\nPlease create table with following schema:"
f"\nCREATE TABLE {schema_name}.checkpoint_writes ("
"\n thread_id TEXT NOT NULL,"
"\n checkpoint_ns TEXT NOT NULL,"
"\n checkpoint_id UUID NOT NULL,"
"\n task_id UUID NOT NULL,"
"\n idx INT NOT NULL,"
"\n channel TEXT NOT NULL,"
"\n type TEXT NOT NULL,"
"\n blob JSONB NOT NULL"
"\n);"
)
return cls(cls.__create_key, engine._pool, schema_name, serde)
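A usage sketch, assuming AlloyDBEngine exposes the same afrom_instance factory as langchain-google-alloydb-pg (all connection parameters are placeholders):

    engine = await AlloyDBEngine.afrom_instance(
        project_id="my-project",
        region="us-central1",
        cluster="my-cluster",
        instance="my-instance",
        database="my-database",
    )
    saver = await AsyncAlloyDBSaver.create(engine)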

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
"""Asynchronously list checkpoints that match the given criteria.

Args:
config (Optional[RunnableConfig]): Base configuration for filtering checkpoints.
filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata.
before (Optional[RunnableConfig]): List checkpoints created before this configuration.
limit (Optional[int]): Maximum number of checkpoints to return.

Returns:
AsyncIterator[CheckpointTuple]: Async iterator of matching checkpoint tuples.

Raises:
            NotImplementedError: Not yet implemented in this saver.
"""
raise NotImplementedError
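        # The unreachable yield below turns this stub into an async generator,
        # matching the AsyncIterator[CheckpointTuple] return annotation.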
yield

async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Asynchronously fetch a checkpoint tuple using the given configuration.

Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.

Returns:
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.

Raises:
            NotImplementedError: Not yet implemented in this saver.
"""
raise NotImplementedError

async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Asynchronously store a checkpoint with its configuration and metadata.

Args:
config (RunnableConfig): Configuration for the checkpoint.
checkpoint (Checkpoint): The checkpoint to store.
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.

Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
"""
configurable = config["configurable"].copy()
thread_id = configurable.pop("thread_id")
checkpoint_ns = configurable.pop("checkpoint_ns")
checkpoint_id = configurable.pop(
"checkpoint_id", configurable.pop("thread_ts", None)
)

copy = checkpoint.copy()
next_config: RunnableConfig = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}

query = f"""INSERT INTO "{self.schema_name}".checkpoints(thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata, channel, version, type, blob)
VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :parent_checkpoint_id, :checkpoint, :metadata, :channel, :version, :type, :blob);
"""

        async with self.pool.connect() as conn:
            await conn.execute(
                text(query),
                {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    # The row is keyed by the new checkpoint's own id; the id popped
                    # from the incoming config (if any) is its parent.
                    "checkpoint_id": checkpoint["id"],
                    "parent_checkpoint_id": checkpoint_id,
Collaborator: I need to check on this value.
"checkpoint": json.dumps(copy),
"metadata": json.dumps(dict(metadata)),
"channel": copy.pop("channel_values"),
"version": new_versions,
Collaborator: I need to ask how new_versions is handled.
"type": next_config["configurable"]["type"],
"blob": json.dumps(next_config["configurable"]["blob"]),
Collaborator (on lines +235 to +236): I don't think we need type and blob here since that is captured in the writes.
},
)
await conn.commit()

return next_config
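For orientation, the config shapes involved look roughly like this (ids illustrative):

    # incoming config; its checkpoint_id, if present, is the parent
    config = {"configurable": {"thread_id": "t1", "checkpoint_ns": "", "checkpoint_id": "<parent-id>"}}
    # returned config points at the checkpoint just stored
    next_config = {"configurable": {"thread_id": "t1", "checkpoint_ns": "", "checkpoint_id": checkpoint["id"]}}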


async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Asynchronously store intermediate writes linked to a checkpoint.

Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (List[Tuple[str, Any]]): List of writes to store.
task_id (str): Identifier for the task creating the writes.

Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
query = f"""INSERT INTO "{self.schema_name}".checkpoint_writes(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
VALUES (:thread_id, :checkpoint_ns, :checkpoint_id, :task_id, :idx, :channel, :type, :blob)
"""
upsert = ""
async with self.pool.connect() as conn:
await conn.execute(
text(query),
{
"thread_id": config["configurable"]["thread_id"],
"checkpoint_ns": config["configurable"]["checkpoint_ns"],
"checkpoint_id": config["configurable"]["checkpoint_id"],
"task_id": task_id,
"idx": idx,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"channel": write[0],
"type": write[1],
"blob": json.dumps(write[2]),
Comment on lines +273 to +275
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is write but note to self to check closer

},
)
await conn.commit()
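A call sketch (assuming a saver created via AsyncAlloyDBSaver.create; values illustrative):

    await saver.aput_writes(
        config={"configurable": {"thread_id": "t1", "checkpoint_ns": "", "checkpoint_id": "c1"}},
        writes=[("messages", {"role": "ai", "content": "hello"})],
        task_id="task-1",
    )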


def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[Dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[CheckpointTuple]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
)

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Fetch a checkpoint tuple using the given configuration.

Args:
config (RunnableConfig): Configuration specifying which checkpoint to retrieve.

Returns:
Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.

Raises:
            NotImplementedError: Sync methods are not implemented for AsyncAlloyDBSaver; use the async `aget_tuple` method instead.
"""
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
)

def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Store a checkpoint with its configuration and metadata.

Args:
config (RunnableConfig): Configuration for the checkpoint.
checkpoint (Checkpoint): The checkpoint to store.
metadata (CheckpointMetadata): Additional metadata for the checkpoint.
new_versions (ChannelVersions): New channel versions as of this write.

Returns:
RunnableConfig: Updated configuration after storing the checkpoint.

Raises:
            NotImplementedError: This sync method is not implemented; use the async `aput` method instead.
"""
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
)

def put_writes(
self,
config: RunnableConfig,
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.

Args:
config (RunnableConfig): Configuration of the related checkpoint.
writes (List[Tuple[str, Any]]): List of writes to store.
task_id (str): Identifier for the task creating the writes.

Raises:
            NotImplementedError: This sync method is not implemented; use the async `aput_writes` method instead.
"""
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBSaver. Use AlloyDBSaver interface instead."
)